srush HF staff committed on
Commit
eac8316
·
1 Parent(s): 47fab39

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # + tags=["hide_inp"]
2
# Markdown description of this example: a heading, a one-line summary and an
# "Open In Colab" badge linking to the notebook version of the example.
desc = """
### Table

Example of extracting tables from a textual document. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/table.ipynb)

"""
8
+ # -
9
+
10
+ # $
11
+ import pandas as pd
12
+ from minichain import prompt, Mock, show, OpenAIStream
13
+ import minichain
14
+ import json
15
+ import gradio as gr
16
+ import requests
17
+
18
# Download the dataset once at import time. A timeout keeps start-up from
# hanging forever on a stalled connection, and raise_for_status() surfaces an
# HTTP error directly instead of a confusing JSON decoding failure.
_rotowire_response = requests.get(
    "https://raw.githubusercontent.com/srush/text2table/main/data.json",
    timeout=30,
)
_rotowire_response.raise_for_status()
rotowire = _rotowire_response.json()

# Human-readable statistic name -> short column code. The keys are the row
# labels used by to_df(); the (name, code) pairs are passed to the prompt
# template as ``player_keys``.
names = {
    '3-pointer percentage': 'FG3_PCT',
    '3-pointers attempted': 'FG3A',
    '3-pointers made': 'FG3M',
    'Assists': 'AST',
    'Blocks': 'BLK',
    'Field goal percentage': 'FG_PCT',
    'Field goals attempted': 'FGA',
    'Field goals made': 'FGM',
    'Free throw percentage': 'FT_PCT',
    'Free throws attempted': 'FTA',
    'Free throws made': 'FTM',
    'Minutes played': 'MIN',
    'Personal fouls': 'PF',
    'Points': 'PTS',
    'Rebounds': 'REB',
    'Rebounds (Defensive)': 'DREB',
    'Rebounds (Offensive)': 'OREB',
    'Steals': 'STL',
    'Turnovers': 'TO'
}
40
+ # Convert an example to dataframe
41
def to_df(d):
    """Convert one example into a transposed, all-string DataFrame.

    Args:
        d: Mapping from statistic name to either ``None`` or a mapping of
            ``{player: value}``.

    Returns:
        A DataFrame whose rows are ``"player"`` followed by every key of the
        module-level ``names`` table, with one column per player (sorted by
        player name). Statistics missing for a player are rendered as ``"_"``.
    """
    # Entries whose value is None carry no per-player data. Skip them here so
    # that they are treated exactly like a statistic that is absent from ``d``.
    lookup = {stat: dict(per_player)
              for stat, per_player in d.items()
              if per_player is not None}
    players = {player for per_player in lookup.values() for player in per_player}

    rows = [
        {"player": player,
         **{stat: lookup.get(stat, {}).get(player, "_") for stat in names}}
        for player in players
    ]

    # Passing the columns explicitly keeps the "player" column present even
    # when there are no rows, so sort_values() below cannot raise KeyError.
    frame = pd.DataFrame(rows, columns=["player", *names])
    frame = frame.astype("str").sort_values(axis=0, by="player", ignore_index=True)
    return frame.transpose()
47
+
48
+
49
+ # Make few shot examples
50
# Number of dataset entries (taken from the front of ``rotowire``) that are
# turned into few-shot demonstrations for the prompt.
few_shot_examples = 2

# Each demonstration pairs the passage text with its table rendered as
# tab-separated text, one row per player.
examples = [
    {
        "input": rotowire[i][1],
        "output": to_df(rotowire[i][0][1]).transpose().set_index("player").to_csv(sep="\t"),
    }
    for i in range(few_shot_examples)
]
55
+
56
def _render_html_table(tsv):
    """Render tab/newline separated text as one HTML ``<table>`` string."""
    from html import escape

    # Escape first so that text produced by the model cannot inject markup;
    # tabs and newlines are untouched by escape() and become cell/row breaks.
    cells = escape(tsv, quote=False)
    cells = cells.replace("\t", "</td><td>").replace("\n", "</td></tr><tr><td>")
    return "<table><tr><td>" + cells + "</td></tr></table>"


@prompt(OpenAIStream(),
        template_file="table.pmpt.txt",
        block_output=gr.HTML,
        stream=True)
def extract(model, passage, typ):
    """Stream the extracted table as progressively more complete HTML.

    Args:
        model: Object whose ``stream(...)`` method yields text tokens.
            (Presumably supplied by the ``prompt`` decorator -- confirm.)
        passage: Text to extract the table from.
        typ: Value passed to the template under the ``type`` key.

    Yields:
        The HTML table for everything received so far after each token, and
        the final table once more when the stream ends.
    """
    out = ""
    # Start from an empty table so the final yield is well defined even when
    # the stream produces no tokens at all.
    html = _render_html_table(out)
    for token in model.stream(dict(player_keys=names.items(), examples=examples, passage=passage, type=typ)):
        out += token
        html = _render_html_table(out)
        yield html
    yield html
68
+
69
+
70
+
71
def run(query):
    """Run the ``extract`` prompt on ``query`` with the table type "Player"."""
    table_type = "Player"
    return extract(query, table_type)
73
+
74
+ # $
75
+
76
# The demo displays only the part of table.py between the first two "$"
# markers. Read it inside a ``with`` block so the file handle is closed.
# NOTE(review): this reads "table.py" although the commit adds app.py --
# confirm that table.py exists next to this file.
with open("table.py", "r") as _source_file:
    _demo_code = _source_file.read().split("$")[1].strip().strip("#").strip()

gradio = show(run,
              examples=[rotowire[i][1] for i in range(50, 55)],
              subprompts=[extract],
              code=_demo_code,
              out_type="markdown"
              )

if __name__ == "__main__":
    gradio.queue().launch()