import streamlit as st import pandas as pd import re import ast import io import os from langchain_core.messages import HumanMessage, AIMessage, ToolMessage from pathlib import Path import uuid import warnings warnings.filterwarnings("ignore") #################################################################### ### FUNCTIONS ### #################################################################### @st.cache_data(show_spinner=True) def initializations(): st.session_state.question = "" st.session_state.file_dataset = "./data/gaia_subset.csv" st.session_state.file_evaluations = "./data/gaia_evals.csv" st.session_state.gaia = True st.session_state.file_lib = "./data/lib.md" st.session_state.file_sidebar = "./data/gaia_sidebar.txt" st.session_state.dfk = str(uuid.uuid4()) # @st.cache_data(show_spinner=True) def get_dataset(dataset_file): return pd.read_csv(dataset_file, sep='µ', engine='python') # @st.cache_data(show_spinner=True) def get_evaluations(eval_file): def set_eval(answer1, answer2): answer1 = re.sub(r'\.$', '', answer1.lower()).replace(', ', ',') answer2 = re.sub(r'\.$', '', answer2.lower()).replace(', ', ',') return answer1 == answer2 df = pd.read_csv(eval_file, sep='µ', engine='python') df = df.merge(st.session_state.df_dataset[['task_id', 'question', 'file_url', 'answer']], on='task_id', how='left') list_labels = pd.unique(df['label']) list_questions = pd.unique(df['question']) df['eval'] = df.apply(lambda r: set_eval(str(r['submitted_answer']), str(r['answer'])), axis=1) df_pivot = df.pivot(index=['task_id','question'], columns='label', values=['eval','submitted_answer','messages']) df_reset = df_pivot.reindex(columns=list_labels, level=1).reset_index() df_reset['question'] = pd.Categorical(df_reset['question'], categories=list_questions, ordered=True) df_eval = df_reset.sort_values('question') df_synth = df.pivot(index='question', columns='label', values='eval') \ .reindex(columns=list_labels) \ .reindex(pd.unique(df_eval['question'])) totaux = df_synth.sum(axis=0) df_perf = totaux.reset_index().T df_perf.columns = df_perf.iloc[0] df_perf = df_perf.iloc[1:] df_perf.loc["Nb correct"] = totaux df_perf.loc["% correct"] = totaux *100 / len(df_eval) df_perf = df_perf.iloc[1:] return df_eval, df_synth, df_perf, list_labels # @st.cache_data(show_spinner=True) def get_lib(lib_file): lib = '' if isinstance(lib_file, str): lib = Path(lib_file).read_text(encoding="utf-8") else: lib = lib_file.read().decode("utf-8") return lib # @st.cache_data(show_spinner=True) def get_sidebar(sidebar_file): if isinstance(sidebar_file, str): with open(sidebar_file, "r", encoding="utf-8") as f: lignes = f.readlines() else: stringio = io.StringIO(sidebar_file.read().decode("utf-8")) lignes = stringio.readlines() return lignes # def parse_messages_from_string(messages_str): messages = [] status = True try: messages_match = re.search(r"'messages': \[(.*)\]", messages_str, re.DOTALL) messages_content = messages_match.group(1) message_splits = re.findall(r'(HumanMessage\(.*?\)|AIMessage\(.*?\)|ToolMessage\(.*?\))(?=, HumanMessage\(|, AIMessage\(|, ToolMessage\(|$)', messages_content, re.DOTALL) for msg_str in message_splits: # Identifier le type de message if msg_str.startswith('HumanMessage'): msg_type = 'HumanMessage' elif msg_str.startswith('AIMessage'): msg_type = 'AIMessage' elif msg_str.startswith('ToolMessage'): msg_type = 'ToolMessage' else: continue # Type inconnu, passer au suivant # Extraire les arguments du constructeur args_str = msg_str[len(msg_type)+1:-1] # Supprimer 'TypeMessage(' et ')' # Convertir les arguments en dictionnaire # Remplacer les paires clé=valeur par des paires 'clé': valeur args_str = re.sub(r'(\w+)=', r'"\1":', args_str) try: args = ast.literal_eval('{' + args_str + '}') # Créer l'objet de message approprié if msg_type == 'HumanMessage': message = HumanMessage(**args) elif msg_type == 'AIMessage': message = AIMessage(**args) elif msg_type == 'ToolMessage': message = ToolMessage(**args) else: continue messages.append(message) except Exception as e: message = HumanMessage(f"*** Error parsing message: {e}") messages.append(message) message = HumanMessage(f"*** See the original list of messages below") messages.append(message) status = False print(f"Error parsing message: {e}") continue except Exception as e: print(f"Erreur lors de l'analyse du messageparse_message_from_string: {e}") finally: return messages, status # def get_details(): dfkey = st.session_state.dfk if len(st.session_state[dfkey]) > 0: if len(st.session_state[dfkey]["selection"]["rows"]): num_raw = st.session_state[dfkey]["selection"]["rows"][0] df_eval = st.session_state.df_eval st.session_state.question = df_eval.iloc[num_raw].question.squeeze() for i in range(0, len(st.session_state.list_labels)): with list_tabs[i].chat_message("ai"): if df_eval.iloc[num_raw].eval[i]: st.markdown(str(df_eval.iloc[num_raw].submitted_answer[i])+" "+ ":green-badge[:material/check: Correct]") else: st.markdown(str(df_eval.iloc[num_raw].submitted_answer[i]) + " " + ":orange-badge[⚠️ Needs review]") messages, status = parse_messages_from_string(df_eval.iloc[num_raw].messages[i]) c = st.container(border=True) c.markdown("### Message history:") c.text("\n".join(m.pretty_repr() for m in messages)) if not status: c.text(df_eval.iloc[num_raw].messages[i]) #print("\n".join(m.pretty_repr() for m in messages)) # def save_uploaded_file(uploaded_file, folder="data"): os.makedirs(folder, exist_ok=True) save_path = os.path.join(folder, uploaded_file.name) with open(save_path, "wb") as f: f.write(uploaded_file.getbuffer()) return save_path # #################################################################### ### MAIN ### #################################################################### #--- Initializations st.set_page_config(page_title='Agents evaluation',layout="wide", initial_sidebar_state="auto") initializations() if 'question' not in st.session_state: st.session_state.question = "" if 'file_dataset' not in st.session_state: st.session_state.file_dataset = "./data/gaia_subset.csv" if 'file_evaluations' not in st.session_state: st.session_state.file_evaluations = "./data/gaia_evals.csv" if 'gaia' not in st.session_state: st.session_state.gaia = True if 'file_lib' not in st.session_state: st.session_state.file_lib = "./data/lib.md" if 'file_sidebar' not in st.session_state: st.session_state.file_sidebar = "./data/gaia_sidebar.txt" if 'dfk' not in st.session_state: st.session_state.dfk = str(uuid.uuid4()) #--- Set title if st.session_state.gaia: col1, col2 = st.columns([0.4, 0.6], vertical_alignment="center") col1.image("thumbnail.jpg") col2.markdown("