antonioloison committed
Commit 4de1a2b · verified · 1 Parent(s): 19d93fe

Fix filtering crash (#9)

- fix: fix filtering (a0e77c884d25bfffafaacdb3821865d803d417ac)

Files changed (2):
  1. app.py (+59 -59)
  2. app/utils.py (+26 -14)
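
In short, as the diffs below show: the live-leaderboard callbacks update_data_1 and update_data_2 now render their data through model_handler rather than deprecated_model_handler; add_rank and add_rank_and_format accept an optional selected_columns argument so ranking is computed only over the columns that are actually displayed; and the deprecated benchmark tabs get their own deprecated_*-prefixed components together with a dedicated deprecated_get_refresh_function helper.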
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr

-from app.utils import add_rank_and_format, filter_models, get_refresh_function
+from app.utils import add_rank_and_format, filter_models, get_refresh_function, deprecated_get_refresh_function
 from data.deprecated_model_handler import DeprecatedModelHandler
 from data.model_handler import ModelHandler

@@ -109,8 +109,8 @@ def main():

     def update_data_1(metric, search_term, selected_columns):
         model_handler.get_vidore_data(metric)
-        data = deprecated_model_handler.render_df(metric, benchmark_version=1)
-        data = add_rank_and_format(data, benchmark_version=1)
+        data = model_handler.render_df(metric, benchmark_version=1)
+        data = add_rank_and_format(data, benchmark_version=1, selected_columns=selected_columns)
         data = filter_models(data, search_term)
         if selected_columns:
             data = data[["Rank", "Model", "Model Size (Million Parameters)", "Average"] + selected_columns]
@@ -193,8 +193,8 @@ def main():

     def update_data_2(metric, search_term, selected_columns):
         model_handler.get_vidore_data(metric)
-        data = deprecated_model_handler.render_df(metric, benchmark_version=2)
-        data = add_rank_and_format(data, benchmark_version=2)
+        data = model_handler.render_df(metric, benchmark_version=2)
+        data = add_rank_and_format(data, benchmark_version=2, selected_columns=selected_columns)
         data = filter_models(data, search_term)
         # data = remove_duplicates(data) # Add this line
         if selected_columns:
@@ -305,56 +305,56 @@ def main():
         Refer to the [ColPali paper](https://arxiv.org/abs/2407.01449) for details on metrics, tasks and models.
         """
     )
-    datasets_columns_1 = list(deprecated_data_benchmark_1.columns[3:])
+    deprecated_datasets_columns_1 = list(deprecated_data_benchmark_1.columns[3:])

     with gr.Row():
-        metric_dropdown_1 = gr.Dropdown(choices=METRICS, value=initial_metric, label="Select Metric")
-        research_textbox_1 = gr.Textbox(
+        deprecated_metric_dropdown_1 = gr.Dropdown(choices=METRICS, value=initial_metric, label="Select Metric")
+        deprecated_research_textbox_1 = gr.Textbox(
             placeholder="🔍 Search Models... [press enter]",
             label="Filter Models by Name",
         )
-        column_checkboxes_1 = gr.CheckboxGroup(
-            choices=datasets_columns_1, value=datasets_columns_1, label="Select Columns to Display"
+        deprecated_column_checkboxes_1 = gr.CheckboxGroup(
+            choices=deprecated_datasets_columns_1, value=deprecated_datasets_columns_1, label="Select Columns to Display"
         )

     with gr.Row():
-        datatype_1 = ["number", "markdown"] + ["number"] * (deprecated_num_datasets_1 + 1)
-        dataframe_1 = gr.Dataframe(deprecated_data_benchmark_1, datatype=datatype_1, type="pandas")
+        deprecated_datatype_1 = ["number", "markdown"] + ["number"] * (deprecated_num_datasets_1 + 1)
+        deprecated_dataframe_1 = gr.Dataframe(deprecated_data_benchmark_1, datatype=deprecated_datatype_1, type="pandas")

-    def update_data_1(metric, search_term, selected_columns):
+    def deprecated_update_data_1(metric, search_term, selected_columns):
         deprecated_model_handler.get_vidore_data(metric)
         data = deprecated_model_handler.render_df(metric, benchmark_version=1)
-        data = add_rank_and_format(data, benchmark_version=1)
+        data = add_rank_and_format(data, benchmark_version=1, selected_columns=selected_columns)
         data = filter_models(data, search_term)
         # data = remove_duplicates(data) # Add this line
         if selected_columns:
-            data = data[["Rank", "Model", "Model Size (Million Parameters)", "Average"] + selected_columns]
+            data = data[["Rank", "Model", "Average"] + selected_columns]
         return data

     with gr.Row():
-        refresh_button_1 = gr.Button("Refresh")
-        refresh_button_1.click(
-            get_refresh_function(deprecated_model_handler, benchmark_version=1),
-            inputs=[metric_dropdown_1],
-            outputs=dataframe_1,
+        deprecated_refresh_button_1 = gr.Button("Refresh")
+        deprecated_refresh_button_1.click(
+            deprecated_get_refresh_function(deprecated_model_handler, benchmark_version=1),
+            inputs=[deprecated_metric_dropdown_1],
+            outputs=deprecated_dataframe_1,
             concurrency_limit=20,
         )

     # Automatically refresh the dataframe when the dropdown value changes
-    metric_dropdown_1.change(
-        get_refresh_function(deprecated_model_handler, benchmark_version=1),
-        inputs=[metric_dropdown_1],
-        outputs=dataframe_1,
+    deprecated_metric_dropdown_1.change(
+        deprecated_get_refresh_function(deprecated_model_handler, benchmark_version=1),
+        inputs=[deprecated_metric_dropdown_1],
+        outputs=deprecated_dataframe_1,
     )
-    research_textbox_1.submit(
-        lambda metric, search_term, selected_columns: update_data_1(metric, search_term, selected_columns),
-        inputs=[metric_dropdown_1, research_textbox_1, column_checkboxes_1],
-        outputs=dataframe_1,
+    deprecated_research_textbox_1.submit(
+        lambda metric, search_term, selected_columns: deprecated_update_data_1(metric, search_term, selected_columns),
+        inputs=[deprecated_metric_dropdown_1, deprecated_research_textbox_1, deprecated_column_checkboxes_1],
+        outputs=deprecated_dataframe_1,
     )
-    column_checkboxes_1.change(
-        lambda metric, search_term, selected_columns: update_data_1(metric, search_term, selected_columns),
-        inputs=[metric_dropdown_1, research_textbox_1, column_checkboxes_1],
-        outputs=dataframe_1,
+    deprecated_column_checkboxes_1.change(
+        lambda metric, search_term, selected_columns: deprecated_update_data_1(metric, search_term, selected_columns),
+        inputs=[deprecated_metric_dropdown_1, deprecated_research_textbox_1, deprecated_column_checkboxes_1],
+        outputs=deprecated_dataframe_1,
     )

     gr.Markdown(
@@ -398,38 +398,38 @@ def main():
         Refer to the [ColPali paper](https://arxiv.org/abs/2407.01449) for details on metrics and models.
         """
     )
-    datasets_columns_2 = list(deprecated_data_benchmark_2.columns[3:])
+    deprecated_datasets_columns_2 = list(deprecated_data_benchmark_2.columns[3:])

     with gr.Row():
-        metric_dropdown_2 = gr.Dropdown(choices=METRICS, value=initial_metric, label="Select Metric")
-        research_textbox_2 = gr.Textbox(
+        deprecated_metric_dropdown_2 = gr.Dropdown(choices=METRICS, value=initial_metric, label="Select Metric")
+        deprecated_research_textbox_2 = gr.Textbox(
             placeholder="🔍 Search Models... [press enter]",
             label="Filter Models by Name",
         )
-        column_checkboxes_2 = gr.CheckboxGroup(
-            choices=datasets_columns_2, value=datasets_columns_2, label="Select Columns to Display"
+        deprecated_column_checkboxes_2 = gr.CheckboxGroup(
+            choices=deprecated_datasets_columns_2, value=deprecated_datasets_columns_2, label="Select Columns to Display"
         )

     with gr.Row():
-        datatype_2 = ["number", "markdown"] + ["number"] * (deprecated_num_datasets_2 + 1)
-        dataframe_2 = gr.Dataframe(deprecated_data_benchmark_2, datatype=datatype_2, type="pandas")
+        deprecated_datatype_2 = ["number", "markdown"] + ["number"] * (deprecated_num_datasets_2 + 1)
+        deprecated_dataframe_2 = gr.Dataframe(deprecated_data_benchmark_2, datatype=deprecated_datatype_2, type="pandas")

-    def update_data_2(metric, search_term, selected_columns):
+    def deprecated_update_data_2(metric, search_term, selected_columns):
         deprecated_model_handler.get_vidore_data(metric)
         data = deprecated_model_handler.render_df(metric, benchmark_version=2)
-        data = add_rank_and_format(data, benchmark_version=2)
+        data = add_rank_and_format(data, benchmark_version=2, selected_columns=selected_columns)
         data = filter_models(data, search_term)
         # data = remove_duplicates(data) # Add this line
         if selected_columns:
-            data = data[["Rank", "Model", "Model Size (Million Parameters)", "Average"] + selected_columns]
+            data = data[["Rank", "Model", "Average"] + selected_columns]
         return data

     with gr.Row():
-        refresh_button_2 = gr.Button("Refresh")
-        refresh_button_2.click(
-            get_refresh_function(deprecated_model_handler, benchmark_version=2),
-            inputs=[metric_dropdown_2],
-            outputs=dataframe_2,
+        deprecated_refresh_button_2 = gr.Button("Refresh")
+        deprecated_refresh_button_2.click(
+            deprecated_get_refresh_function(deprecated_model_handler, benchmark_version=2),
+            inputs=[deprecated_metric_dropdown_2],
+            outputs=deprecated_dataframe_2,
             concurrency_limit=20,
         )

@@ -442,20 +442,20 @@ def main():
     )

     # Automatically refresh the dataframe when the dropdown value changes
-    metric_dropdown_2.change(
-        get_refresh_function(deprecated_model_handler, benchmark_version=2),
-        inputs=[metric_dropdown_2],
-        outputs=dataframe_2,
+    deprecated_metric_dropdown_2.change(
+        deprecated_get_refresh_function(deprecated_model_handler, benchmark_version=2),
+        inputs=[deprecated_metric_dropdown_2],
+        outputs=deprecated_dataframe_2,
     )
-    research_textbox_2.submit(
-        lambda metric, search_term, selected_columns: update_data_2(metric, search_term, selected_columns),
-        inputs=[metric_dropdown_2, research_textbox_2, column_checkboxes_2],
-        outputs=dataframe_2,
+    deprecated_research_textbox_2.submit(
+        lambda metric, search_term, selected_columns: deprecated_update_data_2(metric, search_term, selected_columns),
+        inputs=[deprecated_metric_dropdown_2, deprecated_research_textbox_2, deprecated_column_checkboxes_2],
+        outputs=deprecated_dataframe_2,
     )
-    column_checkboxes_2.change(
-        lambda metric, search_term, selected_columns: update_data_2(metric, search_term, selected_columns),
-        inputs=[metric_dropdown_2, research_textbox_2, column_checkboxes_2],
-        outputs=dataframe_2,
+    deprecated_column_checkboxes_2.change(
+        lambda metric, search_term, selected_columns: deprecated_update_data_2(metric, search_term, selected_columns),
+        inputs=[deprecated_metric_dropdown_2, deprecated_research_textbox_2, deprecated_column_checkboxes_2],
+        outputs=deprecated_dataframe_2,
     )

     gr.Markdown(
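
For readers less familiar with the Gradio wiring above, the following minimal, self-contained sketch shows the same pattern: a metric dropdown drives a refresh closure, while the search textbox and the column checkboxes drive a fuller update function. The toy data and handlers are illustrative assumptions only; they stand in for the Space's ModelHandler and the app.utils helpers (requires gradio and pandas).

import gradio as gr
import pandas as pd

# Toy scores standing in for the leaderboard data (assumption: not the Space's real data).
SCORES = pd.DataFrame({
    "Model": ["model-a", "model-b"],
    "nDCG@5": [0.81, 0.85],
    "Recall@1": [0.62, 0.66],
})

def get_refresh_function(df):
    # Same closure pattern as in app.py: capture the data source, take only the
    # metric from the dropdown (here the metric is simply used as the sort key).
    def _refresh(metric):
        return df.sort_values(metric, ascending=False).reset_index(drop=True)
    return _refresh

def update_data(metric, search_term, selected_columns):
    # Same shape as update_data_1/2: refresh, filter by name, keep the selected columns.
    data = get_refresh_function(SCORES)(metric)
    if search_term:
        data = data[data["Model"].str.contains(search_term, case=False)]
    if selected_columns:
        data = data[["Model"] + selected_columns]
    return data

with gr.Blocks() as demo:
    metric_dropdown = gr.Dropdown(choices=["nDCG@5", "Recall@1"], value="nDCG@5", label="Select Metric")
    search_box = gr.Textbox(label="Filter Models by Name", placeholder="🔍 Search Models... [press enter]")
    column_checkboxes = gr.CheckboxGroup(
        choices=["nDCG@5", "Recall@1"], value=["nDCG@5", "Recall@1"], label="Select Columns to Display"
    )
    table = gr.Dataframe(SCORES, type="pandas")

    # One path driven only by the metric, one driven by all three inputs,
    # mirroring how app.py wires .change()/.submit() on its components.
    metric_dropdown.change(get_refresh_function(SCORES), inputs=[metric_dropdown], outputs=table)
    search_box.submit(update_data, inputs=[metric_dropdown, search_box, column_checkboxes], outputs=table)
    column_checkboxes.change(update_data, inputs=[metric_dropdown, search_box, column_checkboxes], outputs=table)

if __name__ == "__main__":
    demo.launch()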
app/utils.py CHANGED
@@ -17,20 +17,23 @@ def make_clickable_model(model_name, link=None):
     return f'<a target="_blank" style="text-decoration: underline" href="{link}">{desanitized_model_name}</a>'


-def add_rank(df, benchmark_version=1):
+def add_rank(df, benchmark_version=1, selected_columns=None):
     df.fillna(0.0, inplace=True)
-    cols_to_rank = [
-        col
-        for col in df.columns
-        if col
-        not in [
-            "Model",
-            "Model Size (Million Parameters)",
-            "Memory Usage (GB, fp32)",
-            "Embedding Dimensions",
-            "Max Tokens",
+    if selected_columns is None:
+        cols_to_rank = [
+            col
+            for col in df.columns
+            if col
+            not in [
+                "Model",
+                "Model Size (Million Parameters)",
+                "Memory Usage (GB, fp32)",
+                "Embedding Dimensions",
+                "Max Tokens",
+            ]
         ]
-    ]
+    else:
+        cols_to_rank = selected_columns

     if len(cols_to_rank) == 1:
         df.sort_values(cols_to_rank[0], ascending=False, inplace=True)
@@ -45,10 +48,10 @@ def add_rank(df, benchmark_version=1):
     return df


-def add_rank_and_format(df, benchmark_version=1):
+def add_rank_and_format(df, benchmark_version=1, selected_columns=None):
     df = df.reset_index()
     df = df.rename(columns={"index": "Model"})
-    df = add_rank(df, benchmark_version)
+    df = add_rank(df, benchmark_version, selected_columns)
     df["Model"] = df["Model"].apply(make_clickable_model)
     # df = remove_duplicates(df)
     return df
@@ -71,6 +74,15 @@ def get_refresh_function(model_handler, benchmark_version):

     return _refresh

+def deprecated_get_refresh_function(model_handler, benchmark_version):
+    def _refresh(metric):
+        model_handler.get_vidore_data(metric)
+        data_task_category = model_handler.render_df(metric, benchmark_version)
+        df = add_rank_and_format(data_task_category, benchmark_version)
+        return df
+
+    return _refresh
+

 def filter_models(data, search_term):
     if search_term:
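
For reference, here is a runnable, simplified stand-in for the patched add_rank. Only the selected_columns branch and the list of excluded metadata columns mirror the diff above; the Average/Rank computation at the end is an assumed aggregation added so the example runs, since the hunk shows only part of the real ranking logic.

import pandas as pd

# Columns that are never ranked (taken from the diff above).
NON_SCORE_COLUMNS = [
    "Model",
    "Model Size (Million Parameters)",
    "Memory Usage (GB, fp32)",
    "Embedding Dimensions",
    "Max Tokens",
]

def add_rank(df, benchmark_version=1, selected_columns=None):
    # benchmark_version is kept for signature parity with app/utils.py; unused in this sketch.
    df = df.fillna(0.0)
    if selected_columns is None:
        # No column filter from the UI: rank over every score column.
        cols_to_rank = [col for col in df.columns if col not in NON_SCORE_COLUMNS]
    else:
        # Rank only over the columns the user selected in the UI.
        cols_to_rank = selected_columns
    # Assumed aggregation for the sake of a runnable example:
    df["Average"] = df[cols_to_rank].mean(axis=1)
    df = df.sort_values("Average", ascending=False).reset_index(drop=True)
    df.insert(0, "Rank", df.index + 1)
    return df

if __name__ == "__main__":
    demo = pd.DataFrame({
        "Model": ["model-a", "model-b"],
        "DocVQA": [54.4, 57.1],
        "TabFQuAD": [89.1, None],
    })
    # Ranking restricted to the single column a user kept selected.
    print(add_rank(demo, selected_columns=["DocVQA"]))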