test-interface / utils_app.py
tommasodelorenzo's picture
Upload folder using huggingface_hub
d46cc41 verified
import supabase
import gradio as gr
from typing import Union, Optional
import os
from datetime import datetime
import pytz
from supabase_memory import SupabaseChatMessageHistory
from dotenv import load_dotenv
# Pull variables from a local .env file (if present) into the environment
# before the Supabase credentials are read below.
load_dotenv()

# Module-wide Supabase client shared by every helper in this file; the URL
# and key are taken from the SUPABASE_URL / SUPABASE_KEY environment variables.
supabase_client = supabase.create_client(
    os.getenv("SUPABASE_URL"),
    os.getenv("SUPABASE_KEY"),
)
def _get_user_id(request: gr.Request) -> str:
    """Return the id of the user that issued the given Gradio request.

    Selects the "id" column of the "users" table for rows whose "name"
    equals ``request.username`` and returns the value from the first row.
    Indexing the first row fails if the query returns no rows.
    """
    users_query = (
        supabase_client
        .table("users")
        .select("id")
        .eq("name", request.username)
    )
    matching_rows = users_query.execute().data
    first_row = matching_rows[0]
    return first_row["id"]
def _delete_empty_sessions() -> None:
    """Delete every session row whose "updated_at" column is null.

    The table name comes from the SESSIONS_TABLE_NAME environment variable.

    NOTE(review): the delete filter is not restricted to a single user, so
    it applies to all rows of the table with a null "updated_at" -- confirm
    this is intended.
    """
    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    sessions_table.delete().is_("updated_at", "null").execute()
def _get_session_ids(user_id:str) -> list:
    """Return the ids of the user's sessions, ordered by "created_at" descending.

    Sessions with a null "updated_at" are deleted first (via
    ``_delete_empty_sessions``). If the user then has no session at all, a
    new one is created and returned as the only element of the list.
    """
    _delete_empty_sessions()

    sessions_query = (
        supabase_client
        .table(os.environ["SESSIONS_TABLE_NAME"])
        .select("id")
        .eq("user_id", user_id)
        .order("created_at", desc=True)
    )
    rows = sessions_query.execute().data

    if len(rows) > 0:
        return [row["id"] for row in rows]

    # Nothing stored for this user: start them off with a fresh session.
    return [_create_session(user_id)]
def _get_latest_message_id(session_id:str) -> tuple:
    """Return ``(message_id, score)`` for the latest AI message of a session.

    Only rows of the messages table (MESSAGES_TABLE_NAME) that belong to the
    session ("chat_id" equal to ``session_id``), have a null "error_log" and
    whose JSON field ``message->>type`` equals "ai" are considered; the most
    recent one by "created_at" is used.

    Returns:
        A 2-tuple ``(id, score)`` taken from that row, or ``(None, None)``
        when no such row exists. Despite the function name, callers unpack
        two values.
    """
    response = (
        supabase_client
        .table(os.environ["MESSAGES_TABLE_NAME"])
        .select("id, score, message, error_log")
        .eq("chat_id", session_id)
        .is_("error_log", "null")
        .eq("message->>type", "ai")
        .order("created_at", desc=True)
        .limit(1)
        .execute()
    )
    if not response.data:
        return None, None
    latest = response.data[0]
    return latest["id"], latest["score"]
def _get_session_messages(session_id:str) -> list:
    """Return the stored history of a session as a list of 2-tuples.

    The messages are read through ``SupabaseChatMessageHistory`` and their
    ``content`` values are grouped two by two, in stored order.
    Presumably the history alternates user / AI messages, so each tuple is
    ``(user_text, ai_text)`` -- verify against the code that writes them.

    If the history holds an odd number of messages (for instance a question
    whose answer was never stored), the last tuple is ``(content, None)``
    instead of raising ``IndexError``.
    """
    memory = SupabaseChatMessageHistory(
        session_id=session_id,
        client=supabase_client,
        table_name=os.environ.get("MESSAGES_TABLE_NAME"),
        session_name="chat",
    )
    contents = [message.content for message in memory.messages]
    if len(contents) % 2:
        # Pad so the trailing unanswered message still forms a pair.
        contents.append(None)
    return list(zip(contents[::2], contents[1::2]))
def _get_users() -> list:
    """Return the rows of the "users" table restricted to "name" and "password".

    The value returned is the ``data`` attribute of the query response.
    """
    users_table = supabase_client.table("users")
    result = users_table.select("name, password").execute()
    return result.data
def _update_session(
    session_id:str,
    metadata:Optional[dict] = None
):
    """Stamp a session row with the current UTC time and, optionally, metadata.

    Args:
        session_id: value matched against the "id" column of the sessions
            table (SESSIONS_TABLE_NAME).
        metadata: when not None, written to the "metadata" column; when
            None, that column is left out of the update.
    """
    now_utc = datetime.now(pytz.utc)
    changes = {"updated_at": now_utc.isoformat()}
    changes.update({} if metadata is None else {"metadata": metadata})

    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    sessions_table.update(changes).eq("id", session_id).execute()
def _score_chosen(
    session_id:str,
    score:Optional[str]
):
    """Store the score given to the latest AI message and unlock the inputs.

    When ``score`` is not None it is converted with ``int`` and written to
    the "score" column of the session's latest AI message, then the session
    is touched through ``_update_session``. If the session has no such
    message, nothing is written (previously this path failed while indexing
    an empty update response).

    Args:
        session_id: session whose latest AI message is scored.
        score: chosen score, or None when nothing is selected.

    Returns:
        A tuple of Gradio updates for (comment_column, input, new_chat):
        all three are enabled/visible only when a score was chosen.
    """
    print("score chosen...", score)
    allow_inputs = score is not None

    if allow_inputs:
        message_id, _ = _get_latest_message_id(session_id)
        if message_id is not None:
            supabase_client \
                .table(os.environ.get("MESSAGES_TABLE_NAME")) \
                .update({"score": int(score)}) \
                .eq("id", message_id) \
                .execute()
            # The message was looked up by chat_id == session_id, so the
            # session id is already known; no need to read it back from
            # the update response.
            _update_session(session_id)

    placeholder = (
        None
        if allow_inputs
        else "Dai un voto alla risposta precedente prima di continuare la conversazione o iniziarne una nuova"
    )
    return (
        gr.Column(visible=allow_inputs),  # comment_column
        gr.Textbox(interactive=allow_inputs, placeholder=placeholder),  # input
        gr.Button(interactive=allow_inputs),  # new_chat
    )
def _comment_submitted(
    session_id:str,
    comment:str
):
    """Attach a comment to the latest AI message of a session.

    Writes ``comment`` to the "comment" column of the session's latest AI
    message and touches the session through ``_update_session``. If the
    session has no such message, nothing is written (previously this path
    failed while indexing an empty update response).

    Args:
        session_id: session whose latest AI message receives the comment.
        comment: text to store.
    """
    message_id, _ = _get_latest_message_id(session_id)
    if message_id is None:
        return

    supabase_client \
        .table(os.environ.get("MESSAGES_TABLE_NAME")) \
        .update({"comment": comment}) \
        .eq("id", message_id) \
        .execute()
    # The message was looked up by chat_id == session_id, so the session id
    # is already known; no need to read it back from the update response.
    _update_session(session_id)
def _clear_comments():
    """Return the Gradio updates that reset the feedback widgets.

    The tuple is ordered as (comment_column, comment, score, input,
    new_chat): the comment column and score radio are hidden, the comment
    value is cleared, and the input textbox and new-chat button are made
    interactive again (the textbox value is also cleared).
    """
    comment_column = gr.Column(visible=False)
    comment = None
    score = gr.Radio(visible=False)
    user_input = gr.Textbox(interactive=True, value=None)
    new_chat = gr.Button(interactive=True)
    return comment_column, comment, score, user_input, new_chat
def _create_session(
    user_id:str,
) -> str:
    """Insert a new session row for the user and return its "id".

    Only "user_id" is supplied in the insert; the id is read from the first
    row of the insert response.
    """
    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    new_row = {"user_id": user_id}
    inserted_rows = sessions_table.insert(new_row).execute().data
    return inserted_rows[0]["id"]
def _new_chat(user_id:str):
    """Start a new session for the user and return the UI reset updates.

    The existing session ids are fetched first, then a new session is
    created and placed at the head of the dropdown choices.

    Returns:
        A tuple ordered as (chatbot, input, comment_column, comment, score,
        session_id).
    """
    existing_ids = _get_session_ids(user_id)
    new_id = _create_session(user_id)

    chatbot = []
    user_input = gr.Textbox(visible=True, value=None)
    comment_column = gr.Column(visible=False)
    comment = None
    score = gr.Radio(visible=False)
    session_dropdown = gr.Dropdown(
        choices=[new_id, *existing_ids],
        value=new_id,
        interactive=True,
    )
    return chatbot, user_input, comment_column, comment, score, session_dropdown
def _get_session_metadata(session_id:str) -> dict:
    """Return the "metadata" column of the session with the given id.

    The value is read from the first row returned by the query; indexing
    fails if no row matches.
    """
    metadata_query = (
        supabase_client
        .table(os.environ["SESSIONS_TABLE_NAME"])
        .select("metadata")
        .eq("id", session_id)
    )
    session_row = metadata_query.execute().data[0]
    return session_row["metadata"]
def _session_id_selected(session_id):
    """Return the UI updates for a newly selected session.

    Typing is allowed when the session has no AI message yet or when its
    latest AI message already has a score; otherwise the input is locked
    and the score radio is shown.

    Returns:
        A tuple ordered as (chatbot, input, comment_column, comment, score).
    """
    message_id, score = _get_latest_message_id(session_id)
    is_empty_session = message_id is None
    voted = score is not None
    allow_inputs = voted or is_empty_session
    print(f"session_id_selected..., allow_inputs: {allow_inputs}, voted: {voted}, is_empty_session: {is_empty_session}")

    if allow_inputs:
        placeholder = None
    else:
        placeholder = "Dai un voto alla risposta precedente prima di continuare la conversazione o iniziarne una nuova"

    history = _get_session_messages(session_id)
    user_input = gr.Textbox(
        interactive=allow_inputs,
        value=None,
        placeholder=placeholder,
    )
    comment_column = gr.Column(visible=False)
    score_radio = gr.Radio(visible=not allow_inputs)
    return history, user_input, comment_column, None, score_radio
def _load_interface(request: gr.Request):
    """Resolve the requesting user and pre-populate the session dropdown.

    Returns:
        A tuple (user_id, session_id dropdown update); the dropdown lists
        the user's session ids and selects the first one.
    """
    user_id = _get_user_id(request)
    session_ids = _get_session_ids(user_id)
    print("loading interface...")

    session_dropdown = gr.Dropdown(choices=session_ids, value=session_ids[0])
    return user_id, session_dropdown
def _get_link(doc:dict) -> tuple:
    """Return ``(title, link)`` for a retrieved document.

    The link is built from ``doc["type"]``:

    * "Prassi": the "def_id" of the "praxis" row with id
      ``doc["supabase_praxis_id"]`` is used for a ``getPrassiDetail.do`` URL.
    * "Norma": the "articles" row with id ``doc["supabase_id"]`` (and the
      "def_id" of its related norm) is used for a
      ``getAttoNormativoDetail.do`` URL.
    * "Dottrina" and any other type: no link is available and the second
      element is None.

    Args:
        doc: mapping with at least "type" and "title", plus the id key
            required by its type.

    Returns:
        A 2-tuple ``(doc["title"], url_or_None)``. Unknown types used to
        fall through and return a bare None, which broke callers that
        unpack two values.
    """
    MAIN_URL = "https://def.finanze.it/DocTribFrontend"
    doc_type = doc["type"]

    if doc_type == "Prassi":
        row = supabase_client.table("praxis") \
            .select("def_id") \
            .eq("id", doc["supabase_praxis_id"]) \
            .limit(1) \
            .execute().data[0]
        # %7B / %7D are the URL-encoded curly braces around the id.
        return doc["title"], f"{MAIN_URL}/getPrassiDetail.do?id=%7B{row['def_id'].upper()}%7D"

    if doc_type == "Norma":
        row = supabase_client.table("articles") \
            .select("def_id, name, metadata, norms(def_id)") \
            .eq("id", doc["supabase_id"]) \
            .limit(1) \
            .execute().data[0]
        params = {
            "ACTION": "getArticolo",
            "id": "{" + row["norms"]["def_id"].upper() + "}",
            "articolo": row["name"].replace(" ", "%20"),
            "codiceOrdinamento": row["metadata"]["codOrdinamento"],
        }
        query = "&".join(f"{key}={value}" for key, value in params.items())
        return doc["title"], f"{MAIN_URL}/getAttoNormativoDetail.do?{query}"

    # "Dottrina" and unrecognised types have no linkable target.
    return doc["title"], None
def _add_footnote_description(answer:str, docs:list, get_link_func=None):
    """Append a markdown footnote definition for each ``[^uuid]`` in the answer.

    For every distinct footnote placeholder ``[^uuid]`` found in ``answer``
    (in order of first appearance), the doc whose "supabase_id" equals the
    uuid is looked up in ``docs``. If one is found, a line
    ``[^uuid]: [title](link)`` -- or ``[^uuid]: title`` when there is no
    link -- is appended to the answer. Placeholders without a matching doc
    are left without a definition.

    A placeholder referenced several times gets a single definition
    (previously one definition was appended per occurrence), and docs
    lacking a "supabase_id" key are skipped instead of raising ``KeyError``.

    Args:
        answer: markdown text containing the placeholders.
        docs: list of doc mappings.
        get_link_func: optional callable taking a doc and returning
            ``(title, link_or_None)``; defaults to ``_get_link``.

    Returns:
        The answer with the footnote definitions appended.
    """
    import re

    definitions = []
    seen_ids = set()
    for doc_id in re.findall(r"\[\^([^\]]+)\]", answer):
        if doc_id in seen_ids:
            continue
        seen_ids.add(doc_id)

        matching_doc = next(
            (doc for doc in docs if doc.get("supabase_id") == doc_id), None
        )
        if matching_doc is None:
            continue

        # Resolve the default lazily so it is only needed on a match.
        title, link = (get_link_func or _get_link)(matching_doc)
        if link is not None:
            definitions.append(f"\n[^{doc_id}]: [{title}]({link})")
        else:
            definitions.append(f"\n[^{doc_id}]: {title}")
    return answer + "".join(definitions)
def _replace_markdown_links(answer:str, docs:list, get_link_func=None):
    """Replace markdown link placeholders ``[text](uuid)`` with actual links.

    Each markdown link whose target equals the "supabase_id" of a doc in
    ``docs`` is rewritten: to ``[text](url)`` when a URL is available for
    the doc, or to the bare ``text`` when it is not. Links whose target
    matches no doc are left unchanged.

    Args:
        answer: markdown text containing the placeholders.
        docs: list of doc mappings; docs without a "supabase_id" key are
            ignored.
        get_link_func: optional callable taking a doc and returning
            ``(title, link_or_None)``; defaults to ``_get_link``.

    Returns:
        The modified answer string with replaced links.
    """
    import re

    # Markdown links: group 1 is the link text, group 2 the target (uuid).
    link_pattern = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")

    def replace_link(match):
        text, doc_id = match.group(1), match.group(2)
        matching_doc = next(
            (doc for doc in docs if doc.get("supabase_id") == doc_id), None
        )
        if matching_doc is None:
            # Not one of our placeholders: keep the original link unchanged.
            return match.group(0)

        # The resolver returns (title, link); only the link is needed here.
        # Resolve the default lazily so it is only needed on a match.
        _, link = (get_link_func or _get_link)(matching_doc)
        if link is not None:
            return f"[{text}]({link})"
        return text

    return link_pattern.sub(replace_link, answer)
def _create_users(n=10):
    """Insert ``n`` users with random credentials into the "users" table.

    Each user gets a random 6-character name and a random 12-character
    password drawn from ASCII letters and digits. The characters are chosen
    with the ``secrets`` module because the value is used as a password.

    NOTE(review): the generated password is inserted as-is into the
    "password" column -- confirm whether it should be hashed.

    Args:
        n: number of users to create.
    """
    import secrets
    import string

    alphabet = string.ascii_letters + string.digits

    def random_string(length):
        return "".join(secrets.choice(alphabet) for _ in range(length))

    for _ in range(n):
        supabase_client.table("users").insert({
            "name": random_string(6),
            "password": random_string(12),
        }).execute()