TabPFN commited on
Commit
d213847
·
1 Parent(s): 0f0db0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -40
app.py CHANGED
@@ -4,22 +4,25 @@ sys.path.insert(0, tabpfn_path) # our submodule of the TabPFN repo (at 045c84002
4
  from TabPFN.scripts.transformer_prediction_interface import TabPFNClassifier
5
 
6
  import numpy as np
 
7
  import pandas as pd
8
  import torch
9
  import gradio as gr
10
  import openml
11
-
 
 
12
 
13
  def compute(table: np.array):
14
- vfunc = np.vectorize(lambda s: len(s))
15
  non_empty_row_mask = (vfunc(table).sum(1) != 0)
16
  table = table[non_empty_row_mask]
17
- empty_mask = table == ''
18
  empty_inds = np.where(empty_mask)
19
  if not len(empty_inds[0]):
20
- return "**Please leave at least one field blank for prediction.**", None
21
  if not np.all(empty_inds[1][0] == empty_inds[1]):
22
- return "**Please only leave fields of one column blank for prediction.**", None
23
  y_column = empty_inds[1][0]
24
  eval_lines = empty_inds[0]
25
 
@@ -32,66 +35,117 @@ def compute(table: np.array):
32
 
33
  y_train = train_table[:, y_column]
34
  except ValueError:
35
- return "**Please only add numbers (to the inputs) or leave fields empty.**", None
36
 
37
  classifier = TabPFNClassifier(base_path=tabpfn_path, device='cpu')
38
  classifier.fit(x_train, y_train)
39
  y_eval, p_eval = classifier.predict(x_eval, return_winning_probability=True)
40
 
41
  # print(file, type(file))
42
- out_table = table.copy().astype(str)
43
- out_table[eval_lines, y_column] = [f"{y_e} (p={p_e:.2f})" for y_e, p_e in zip(y_eval, p_eval)]
44
- return None, out_table
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
- def upload_file(file):
 
48
  if file.name.endswith('.arff'):
49
  dataset = openml.datasets.OpenMLDataset('t', 'test', data_file=file.name)
50
  X_, _, categorical_indicator_, attribute_names_ = dataset.get_data(
51
  dataset_format="array"
52
  )
53
  df = pd.DataFrame(X_, columns=attribute_names_)
54
- return df
55
  elif file.name.endswith('.csv') or file.name.endswith('.data'):
56
- df = pd.read_csv(file.name, header=None)
57
- df.columns = np.arange(len(df.columns))
58
- return df
59
-
60
-
61
- example = \
62
- [
63
- [1, 2, 1],
64
- [2, 1, 1],
65
- [1, 1, 1],
66
- [2, 2, 2],
67
- [3, 4, 2],
68
- [3, 2, 2],
69
- [2, 3, '']
70
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  with gr.Blocks() as demo:
73
- gr.Markdown("""This demo allows you to play with the **TabPFN**.
74
- The TabPFN will classify the values for all empty cells in the label column.
75
- Please, provide everything but the label column as numeric values.
76
- You can also upload datasets to fill the table automatically.
77
- """)
78
- inp_table = gr.DataFrame(type='numpy', value=example, headers=[''] * 3)
79
- upload_file('iris.csv')
80
 
81
- btn = gr.Button("Predict Empty Table Cells")
82
- btn.click(fn=compute, inputs=inp_table, outputs=[out_text, out_table])
 
 
 
 
83
 
84
- out_text = gr.Markdown()
85
- out_table = gr.DataFrame()
86
 
87
  examples = gr.Examples(examples=['iris.csv', 'balance-scale.arff'],
88
  inputs=[inp_file],
89
  outputs=[inp_table],
90
  fn=upload_file,
91
  cache_examples=True)
92
- inp_file = gr.File(
93
- label='Drop either a .csv (without header, only numeric values for all but the labels) or a .arff file.')
94
 
95
  inp_file.change(fn=upload_file, inputs=inp_file, outputs=inp_table)
96
 
97
- demo.launch()
 
4
  from TabPFN.scripts.transformer_prediction_interface import TabPFNClassifier
5
 
6
  import numpy as np
7
+ from pathlib import Path
8
  import pandas as pd
9
  import torch
10
  import gradio as gr
11
  import openml
12
+ import os
13
+ import matplotlib.pyplot as plt
14
+ from matplotlib.colors import ListedColormap
15
 
16
  def compute(table: np.array):
17
+ vfunc = np.vectorize(lambda s: len(str(s)))
18
  non_empty_row_mask = (vfunc(table).sum(1) != 0)
19
  table = table[non_empty_row_mask]
20
+ empty_mask = table == '(predict)'
21
  empty_inds = np.where(empty_mask)
22
  if not len(empty_inds[0]):
23
+ return "⚠️ **ERROR: Please leave at least one field blank for prediction.**", None, None
24
  if not np.all(empty_inds[1][0] == empty_inds[1]):
25
+ return "⚠️ **Please only leave fields of one column blank for prediction.**", None, None
26
  y_column = empty_inds[1][0]
27
  eval_lines = empty_inds[0]
28
 
 
35
 
36
  y_train = train_table[:, y_column]
37
  except ValueError:
38
+ return "⚠️ **Please only add numbers (to the inputs) or leave fields empty.**", None, None
39
 
40
  classifier = TabPFNClassifier(base_path=tabpfn_path, device='cpu')
41
  classifier.fit(x_train, y_train)
42
  y_eval, p_eval = classifier.predict(x_eval, return_winning_probability=True)
43
 
44
  # print(file, type(file))
45
+ out_table = pd.DataFrame(table.copy().astype(str))
46
+ out_table.iloc[eval_lines, y_column] = [f"{y_e} (p={p_e:.2f})" for y_e, p_e in zip(y_eval, p_eval)]
47
+ out_table = out_table.iloc[eval_lines, :]
48
+ out_table.columns = headers
49
+
50
+ # PLOTTING
51
+ fig = plt.figure(figsize=(10,10))
52
+ ax = fig.add_subplot(111)
53
+ cm = plt.cm.RdBu
54
+ cm_bright = ListedColormap(["#FF0000", "#0000FF"])
55
+
56
+ # Plot the training points
57
+ vfunc = np.vectorize(lambda x : np.where(classifier.classes_ == x)[0])
58
+ y_train_index = vfunc(y_train)
59
+ y_train_index = y_train_index == 0
60
+ y_train = y_train_index
61
+ #x_train = x_train[y_train_index <= 1]
62
+ #y_train = y_train[y_train_index <= 1]
63
+ #y_train_index = y_train_index[y_train_index <= 1]
64
+
65
+ ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train_index, cmap=cm_bright)
66
+
67
+ classifier = TabPFNClassifier(device='cpu', base_path='/home/hollmann/',
68
+ model_string=model_string, N_ensemble_configurations=1
69
+ , no_preprocess_mode=False, i=i, feature_shift_decoder=True, multiclass_decoder='permutation')
70
+ classifier.fit(x_train[:, 0:2], y_train)
71
+
72
+ DecisionBoundaryDisplay.from_estimator(
73
+ classifier, x_train[:, 0:2], alpha=0.6, ax=ax, eps=2.0, grid_resolution=100, response_method="predict_proba"
74
+ )
75
+ plt.xlabel(headers[0])
76
+ plt.ylabel(headers[1])
77
+
78
+ return None, out_table, fig
79
 
80
 
81
+ def upload_file(file, remove_entries=10):
82
+ global headers
83
  if file.name.endswith('.arff'):
84
  dataset = openml.datasets.OpenMLDataset('t', 'test', data_file=file.name)
85
  X_, _, categorical_indicator_, attribute_names_ = dataset.get_data(
86
  dataset_format="array"
87
  )
88
  df = pd.DataFrame(X_, columns=attribute_names_)
89
+ headers = df.columns
90
  elif file.name.endswith('.csv') or file.name.endswith('.data'):
91
+ df = pd.read_csv(file.name, header='infer')
92
+ headers = df.columns
93
+ #df.columns = np.arange(len(df.columns))
94
+
95
+ df.iloc[0:remove_entries, -1] = '(predict)'
96
+ return df
97
+
98
+
99
+ def update_table(table):
100
+ global headers
101
+ table = pd.DataFrame(table)
102
+ vfunc = np.vectorize(lambda s: len(str(s)))
103
+ non_empty_row_mask = (vfunc(table).sum(1) != 0)
104
+ table = table[non_empty_row_mask]
105
+ empty_mask = table == ''
106
+ empty_inds = np.where(empty_mask)
107
+ if not len(empty_inds[0]):
108
+ return table
109
+
110
+ y_column = empty_inds[1][0]
111
+ eval_lines = empty_inds[0]
112
+
113
+ table.iloc[eval_lines, y_column] = '(predict)'
114
+ table.columns = headers
115
+
116
+ return table
117
+
118
+ headers = []
119
+
120
+ gr.Markdown("""This demo allows you to play with the **TabPFN**.
121
+ The TabPFN will classify the values for all empty cells in the label column.
122
+ Please, provide everything but the label column as numeric values.
123
+ You can also upload datasets to fill the table automatically.
124
+ """)
125
 
126
  with gr.Blocks() as demo:
127
+ with gr.Tab("Enter Input Data"):
128
+ inp_file = gr.File(
129
+ label='Drop either a .csv (without header, only numeric values for all but the labels) or a .arff file.')
130
+
131
+ inp_table = gr.DataFrame(type='numpy', value=upload_file(Path('iris.csv'), remove_entries=10), headers=[''] * 3)
132
+ inp_table.change(fn=update_table, inputs=inp_table, outputs=inp_table)
 
133
 
134
+ with gr.Tab("Run Predictions"):
135
+
136
+ btn = gr.Button("Start")
137
+ out_text = gr.Markdown()
138
+ out_table = gr.DataFrame()
139
+ out_plot = gr.Plot()
140
 
141
+ btn.click(fn=compute, inputs=inp_table, outputs=[out_text, out_table, out_plot])
 
142
 
143
  examples = gr.Examples(examples=['iris.csv', 'balance-scale.arff'],
144
  inputs=[inp_file],
145
  outputs=[inp_table],
146
  fn=upload_file,
147
  cache_examples=True)
 
 
148
 
149
  inp_file.change(fn=upload_file, inputs=inp_file, outputs=inp_table)
150
 
151
+ demo.launch()