|
import numpy as np
import pandas as pd
import plotly.graph_objects as go

from src.load_data import load_dataframe
|
|
|
|
|
# Default styling shared by the radar-chart helpers below: fill color of the
# polygon, color of its outline, and opacity of the trace.
fillcolor = "#FFD21E"
line_color = "#FF9D00"
opacity = 0.75

# Default metric columns drawn on the radar axes, in display order.
categories = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]
|
|
|
def plot_radar_chart_index(
    dataframe: pd.DataFrame,
    index: int,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
    opacity: float = opacity,
) -> "go.Figure":
    """
    Plot the metrics of one row of the dataframe as a radar chart.

    The row is selected by label with ``dataframe.loc[index]``; with a default
    RangeIndex this is the index-th row. The selected values are multiplied by
    100, rounded to two decimals and drawn on a radial axis fixed to [0, 100]
    (so the stored scores are presumably fractions in [0, 1] -- confirm against
    the data loader).

    Arguments:
        dataframe: a pandas DataFrame holding one column per metric plus a
            "model_name" column
        index: the index label of the row we want to plot
        categories: the list of the metrics (column names) to plot
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph
        opacity: opacity of the trace, between 0 and 1

    Returns:
        A plotly Figure containing a single closed Scatterpolar trace named
        after the row's "model_name" value.
    """
    # Work on a private list so any sequence of column names is accepted and
    # the caller's object (or the shared default) is never modified.
    metric_names = list(categories)

    # Cast to float *before* scaling: the row may come back with object dtype,
    # and multiplying an object array first would string-repeat textual values
    # instead of scaling them.
    scores = dataframe.loc[index, metric_names].to_numpy(dtype=float) * 100
    scores = scores.round(decimals=2)

    # Repeat the first point at the end so the outline of the polygon closes.
    scores = np.append(scores, scores[0])
    theta = [*metric_names, metric_names[0]]

    model_name = dataframe.loc[index, "model_name"]

    fig = go.Figure()
    fig.add_trace(
        go.Scatterpolar(
            r=scores,
            theta=theta,
            fill='toself',
            fillcolor=fillcolor,
            opacity=opacity,
            line=dict(color=line_color),
            name=model_name,
        )
    )
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.],
            ),
        ),
        showlegend=False,
    )

    return fig
|
|
|
def plot_radar_chart_name(
    dataframe: pd.DataFrame,
    model_name: str,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
    opacity: float = opacity,
) -> "go.Figure":
    """
    Plot the metrics of the model named model_name as a radar chart.

    The first row whose "model_name" column equals ``model_name`` is used. Its
    values are multiplied by 100, rounded to two decimals and drawn on a radial
    axis fixed to [0, 100] (so the stored scores are presumably fractions in
    [0, 1] -- confirm against the data loader).

    Arguments:
        dataframe: a pandas DataFrame holding one column per metric plus a
            "model_name" column
        model_name: a string stating the name of the model
        categories: the list of the metrics (column names) to plot
        fillcolor: a string specifying the color to fill the area
        line_color: a string specifying the color of the lines in the graph
        opacity: opacity of the trace, between 0 and 1

    Returns:
        A plotly Figure containing a single closed Scatterpolar trace named
        model_name.

    Raises:
        IndexError: if no row of the dataframe has this model_name.
    """
    # Work on a private list so any sequence of column names is accepted and
    # the caller's object (or the shared default) is never modified.
    metric_names = list(categories)

    matches = dataframe.loc[dataframe["model_name"] == model_name, metric_names]
    if len(matches) == 0:
        raise IndexError(
            f"no row with model_name == {model_name!r} in the dataframe"
        )

    # Take exactly one row as a 1-D vector. A boolean-mask selection is 2-D,
    # and np.append without an axis would flatten every matching row into the
    # radius list, leaving it out of step with theta.
    # Cast to float before scaling so textual values are parsed, not repeated.
    scores = matches.iloc[0].to_numpy(dtype=float) * 100
    scores = scores.round(decimals=2)

    # Repeat the first point at the end so the outline of the polygon closes.
    scores = np.append(scores, scores[0])
    theta = [*metric_names, metric_names[0]]

    fig = go.Figure()
    fig.add_trace(
        go.Scatterpolar(
            r=scores,
            theta=theta,
            fill='toself',
            fillcolor=fillcolor,
            opacity=opacity,
            line=dict(color=line_color),
            name=model_name,
        )
    )
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100.],
            ),
        ),
        showlegend=False,
    )

    return fig