Spaces:
Running
Running
import supabase | |
import gradio as gr | |
from typing import Union, Optional | |
import os | |
from datetime import datetime | |
import pytz | |
from supabase_memory import SupabaseChatMessageHistory | |
from dotenv import load_dotenv | |
# Populate os.environ from a local .env file, if one exists, before any
# configuration (SUPABASE_URL, SUPABASE_KEY, *_TABLE_NAME) is read.
load_dotenv()

# Module-wide Supabase client used by every helper in this module.
# NOTE(review): os.getenv yields None for an unset variable; reporting that
# is presumably left to supabase.create_client — confirm.
supabase_client = supabase.create_client(
    os.getenv("SUPABASE_URL"),
    os.getenv("SUPABASE_KEY"),
)
def _get_user_id(request: gr.Request) -> str:
    """Return the ``users.id`` of the user authenticated on *request*.

    Rows of ``users`` whose ``name`` equals ``request.username`` are fetched
    and the first one's ``id`` is returned. An ``IndexError`` propagates when
    no such user exists.
    """
    query = (
        supabase_client.table("users")
        .select("id")
        .eq("name", request.username)
    )
    rows = query.execute().data
    first_row = rows[0]
    return first_row["id"]
def _delete_empty_sessions(user_id: Optional[str] = None) -> None:
    """Delete sessions whose ``updated_at`` is still NULL.

    In this module ``updated_at`` is only written by ``_update_session``, so a
    NULL value marks a session that was created but never touched.

    Args:
        user_id: When given, only that user's untouched sessions are deleted.
            When ``None`` (the default), untouched sessions of *all* users
            are deleted.
    """
    query = (
        supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
        .delete()
        .is_("updated_at", "null")
    )
    if user_id is not None:
        query = query.eq("user_id", user_id)
    query.execute()


def _get_session_ids(user_id: str) -> list:
    """Return the ids of *user_id*'s sessions, newest (``created_at``) first.

    The user's untouched sessions are purged first. If the user is then left
    with no sessions, a new one is created, so the result always contains at
    least one id.
    """
    # Purge only this user's empty sessions: a global purge would also delete
    # sessions other users have just opened but not used yet.
    _delete_empty_sessions(user_id)
    rows = (
        supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
        .select("id")
        .eq("user_id", user_id)
        .order("created_at", desc=True)
        .execute()
        .data
    )
    if not rows:
        return [_create_session(user_id)]
    return [row["id"] for row in rows]
def _get_latest_message_id(session_id: str) -> tuple:
    """Return ``(id, score)`` of the newest error-free AI message of a session.

    Only messages with ``chat_id == session_id``, a NULL ``error_log`` and a
    JSON ``message->>type`` of ``"ai"`` are considered. Among them, the most
    recently created one is used.

    Returns:
        ``(message_id, score)`` for that message, or ``(None, None)`` when the
        session has no such message. ``score`` may itself be ``None``;
        callers treat that as "not rated yet".
    """
    # Filters may reference columns that are not selected, so only the two
    # columns actually returned are fetched.
    rows = (
        supabase_client.table(os.environ["MESSAGES_TABLE_NAME"])
        .select("id, score")
        .eq("chat_id", session_id)
        .is_("error_log", "null")
        .eq("message->>type", "ai")
        .order("created_at", desc=True)
        .limit(1)
        .execute()
        .data
    )
    if not rows:
        return None, None
    latest = rows[0]
    return latest["id"], latest["score"]
def _get_session_messages(session_id: str) -> list:
    """Load a session's history as a list of ``(user, assistant)`` content pairs.

    Messages are read through ``SupabaseChatMessageHistory`` and paired in
    storage order: 0 with 1, 2 with 3, and so on. This assumes the history
    alternates human/AI starting with a human message (TODO confirm).

    If the history has an odd number of messages (a question whose reply was
    never stored), the final pair carries ``None`` as the reply instead of
    raising ``IndexError``.
    """
    memory = SupabaseChatMessageHistory(
        session_id=session_id,
        client=supabase_client,
        table_name=os.environ.get("MESSAGES_TABLE_NAME"),
        session_name="chat",
    )
    contents = [message.content for message in memory.messages]
    if len(contents) % 2:
        # Pad a trailing unanswered question so every entry is a full pair.
        contents.append(None)
    return list(zip(contents[0::2], contents[1::2]))
def _get_users() -> list:
    """Return the ``name`` and ``password`` columns of every row in ``users``.

    NOTE(review): passwords are read back verbatim, which suggests they are
    stored in plaintext — confirm.
    """
    result = supabase_client.table("users").select("name, password").execute()
    return result.data
def _update_session(
    session_id: str,
    metadata: Optional[dict] = None
):
    """Stamp a session's ``updated_at`` with the current UTC time.

    Args:
        session_id: ``id`` of the sessions-table row to update.
        metadata: New value for the ``metadata`` column. When ``None``, the
            column is left untouched.
    """
    now_utc = datetime.now(pytz.utc)
    payload = {"updated_at": now_utc.isoformat()}
    if metadata is not None:
        payload.update(metadata=metadata)
    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    sessions_table.update(payload).eq("id", session_id).execute()
def _score_chosen(
    session_id: str,
    score: Optional[str]
):
    """Store the user's rating of the latest AI answer and update the UI.

    When *score* is not ``None``:

    * it is written (as ``int``) to the newest error-free AI message of the
      session;
    * the session's ``updated_at`` is refreshed.

    If the session has no such message, nothing is written.

    Returns:
        Updates for ``(comment_column, input, new_chat)``. The comment column
        is shown and the input/new-chat controls are enabled only once a
        score is chosen.
    """
    print("score chosen...", score)
    allow_inputs = score is not None
    if allow_inputs:
        message_id, _ = _get_latest_message_id(session_id)
        # Without a rateable AI message there is no row to update. Writing
        # anyway would match nothing and fail when indexing the empty result.
        if message_id is not None:
            supabase_client \
                .table(os.environ.get("MESSAGES_TABLE_NAME")) \
                .update({"score": int(score)}) \
                .eq("id", message_id) \
                .execute()
            # The message was selected by chat_id == session_id, so that is
            # the session to touch; no need to read it back from the response.
            _update_session(session_id)
    return (
        gr.Column(visible=allow_inputs),  # comment_column
        gr.Textbox(
            interactive=allow_inputs,
            placeholder="Dai un voto alla risposta precedente prima di continuare la conversazione o iniziarne una nuova" if not allow_inputs else None,
        ),  # input
        gr.Button(interactive=allow_inputs),  # new_chat
    )
def _comment_submitted(
    session_id: str,
    comment: str
):
    """Attach *comment* to the newest error-free AI message of a session.

    The session's ``updated_at`` is refreshed as well. When the session has
    no such message, nothing is written.
    """
    message_id, _ = _get_latest_message_id(session_id)
    if message_id is None:
        # No AI answer to comment on; an update keyed on a None id would
        # match nothing and the result could not be indexed.
        return
    supabase_client \
        .table(os.environ.get("MESSAGES_TABLE_NAME")) \
        .update({"comment": comment}) \
        .eq("id", message_id) \
        .execute()
    # The message was selected by chat_id == session_id.
    _update_session(session_id)
def _clear_comments():
    """Reset the rating/comment widgets and re-enable the chat controls.

    Returns:
        Updates for ``(comment_column, comment, score, input, new_chat)``:

        * the comment column and score radio are hidden;
        * the comment value is cleared;
        * the input textbox is cleared and made interactive;
        * the new-chat button is enabled.
    """
    hidden_comment_column = gr.Column(visible=False)
    cleared_comment = None
    hidden_score = gr.Radio(visible=False)
    enabled_input = gr.Textbox(interactive=True, value=None)
    enabled_new_chat = gr.Button(interactive=True)
    return (
        hidden_comment_column,
        cleared_comment,
        hidden_score,
        enabled_input,
        enabled_new_chat,
    )
def _create_session(
    user_id: str,
) -> str:
    """Insert a new session owned by *user_id* and return its ``id``.

    Only ``user_id`` is written. ``updated_at`` is presumably left NULL
    (assumes no column default — confirm), so the session counts as empty
    until ``_update_session`` touches it.
    """
    new_row = {"user_id": user_id}
    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    inserted = sessions_table.insert(new_row).execute()
    return inserted.data[0]["id"]
def _new_chat(user_id: str):
    """Start a fresh chat session for *user_id* and reset the chat UI.

    The existing session ids are fetched first, which also purges
    never-updated sessions. Only then is the new session created, so that
    purge cannot remove it. The new id is placed first in the dropdown and
    selected.

    Returns:
        Updates for ``(chatbot, input, comment_column, comment, score,
        session_id)``.
    """
    existing_ids = _get_session_ids(user_id)
    fresh_id = _create_session(user_id)
    dropdown_choices = [fresh_id, *existing_ids]
    empty_history = []
    return (
        empty_history,  # chatbot
        gr.Textbox(visible=True, value=None),  # input
        gr.Column(visible=False),  # comment_column
        None,  # comment
        gr.Radio(visible=False),  # score
        gr.Dropdown(choices=dropdown_choices, value=fresh_id, interactive=True),  # session_id
    )
def _get_session_metadata(session_id: str) -> dict:
    """Return the ``metadata`` column of the session whose ``id`` is *session_id*.

    An ``IndexError`` propagates when no such session exists.
    """
    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    rows = sessions_table.select("metadata").eq("id", session_id).execute().data
    return rows[0]["metadata"]
def _session_id_selected(session_id):
    """Refresh the chat UI after a session is picked in the dropdown.

    Input is allowed when either:

    * the session has no error-free AI message yet, or
    * its latest AI message already has a score.

    Otherwise the textbox is locked and the score radio is shown, so the user
    must rate the previous answer first.

    Returns:
        Updates for ``(chatbot, input, comment_column, comment, score)``.
    """
    latest_id, latest_score = _get_latest_message_id(session_id)
    is_empty_session = latest_id is None
    voted = latest_score is not None
    allow_inputs = is_empty_session or voted
    print(f"session_id_selected..., allow_inputs: {allow_inputs}, voted: {voted}, is_empty_session: {is_empty_session}")
    history = _get_session_messages(session_id)
    if allow_inputs:
        hint = None
    else:
        hint = "Dai un voto alla risposta precedente prima di continuare la conversazione o iniziarne una nuova"
    return (
        history,  # chatbot
        gr.Textbox(interactive=allow_inputs, value=None, placeholder=hint),  # input
        gr.Column(visible=False),  # comment_column
        None,  # comment
        gr.Radio(visible=not allow_inputs),  # score
    )
def _load_interface(request: gr.Request):
    """Initialise the UI for the authenticated user when the page loads.

    Returns:
        ``(user_id, session_id dropdown update)``. The dropdown lists the
        user's sessions newest first and selects the newest one.
    """
    current_user = _get_user_id(request)
    sessions = _get_session_ids(current_user)
    print("loading interface...")
    newest_session = sessions[0]
    return current_user, gr.Dropdown(choices=sessions, value=newest_session)
def _get_link(doc:dict) -> str: | |
MAIN_URL = "https://def.finanze.it/DocTribFrontend" | |
type = doc["type"] | |
if type == "Prassi": | |
row = supabase_client.table("praxis") \ | |
.select("def_id") \ | |
.eq("id", doc["supabase_praxis_id"]) \ | |
.limit(1) \ | |
.execute().data[0] | |
return doc["title"], f"{MAIN_URL}/getPrassiDetail.do?id=%7B{row['def_id'].upper()}%7D" | |
if type == "Dottrina": | |
return doc["title"], None | |
if type == "Norma": | |
row = supabase_client.table("articles") \ | |
.select("def_id, name, metadata, norms(def_id)") \ | |
.eq("id", doc["supabase_id"]) \ | |
.limit(1) \ | |
.execute().data[0] | |
params = { | |
"ACTION": "getArticolo", | |
"id": "{" + row["norms"]["def_id"].upper() + "}", | |
"articolo": row["name"].replace(" ", "%20"), | |
"codiceOrdinamento": row["metadata"]["codOrdinamento"], | |
} | |
return doc["title"], f"{MAIN_URL}/getAttoNormativoDetail.do?"+"&".join([f"{k}={v}" for k,v in params.items()]) | |
def _add_footnote_description(answer:str, docs:list): | |
""" | |
For each markdown footnote placeholder "[^uuid]" in the answer, adds a description "[^uuid]: [title](link) of the footnote at the end of the answer. The title and link are retrieved from the docs list matching the uuid of the footnote using something like matching_doc = next( | |
(doc for doc in docs if doc["supabase_id"] == uuid), None | |
) | |
""" | |
import re | |
footnotes = re.findall(r"\[[^\]]+\]", answer) | |
for footnote in footnotes: | |
matching_doc = next( | |
(doc for doc in docs if doc["supabase_id"] == footnote.replace("[^", "").replace("]", "")), None | |
) | |
if matching_doc is not None: | |
title, link = _get_link(matching_doc) | |
if link is not None: | |
answer += f"\n{footnote}: [{title}]({link})" | |
else: | |
answer += f"\n{footnote}: {title}" | |
return answer | |
def _replace_markdown_links(answer:str, docs:list): | |
""" | |
Replaces markdown link placeholders (text)[uuid] with actual links. | |
Args: | |
answer_dict: A dictionary containing the 'answer' with markdown links and 'docs' with relevant data. | |
get_link_func: A function that takes a doc and returns the actual link. | |
Returns: | |
The modified answer string with replaced links. | |
""" | |
import re | |
# Regular expression to match markdown links with UUID placeholders | |
link_pattern = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") | |
def replace_link(match): | |
text = match.group(1) | |
uuid = match.group(2) | |
# Find the corresponding doc based on the UUID | |
matching_doc = next( | |
(doc for doc in docs if doc["supabase_id"] == uuid), None | |
) | |
if matching_doc: | |
link = _get_link(matching_doc) | |
if link is not None: | |
return f"[{text}]({link})" | |
else: | |
return text | |
else: | |
# Handle cases where the doc is not found | |
return match.group(0) # Keep the original link unchanged | |
# Replace links in the answer string | |
modified_answer = link_pattern.sub(replace_link, answer) | |
return modified_answer | |
def _create_users(n=10): | |
# Create n users with random names (random 6 hex) and passwords (random 12 hex) | |
import random | |
import string | |
def random_string(length): | |
return ''.join(random.choices(string.ascii_letters + string.digits, k=length)) | |
for i in range(n): | |
supabase_client.table("users").insert({ | |
"name": random_string(6), | |
"password": random_string(12), | |
}).execute() | |