import streamlit as st
from chat_client import chat
import time
import os
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
#from langchain_community.vectorstores import Chroma
#from langchain_community.embeddings import HuggingFaceEmbeddings

load_dotenv()

# Display name -> Hugging Face model id
CHAT_BOTS = {"Mixtral 8x7B v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1"}
# Seed exchange (user turn, assistant turn) that primes the BonsiAI persona
SYSTEM_PROMPT = ["Sei BonsiAI e mi aiuterai nelle mie richieste (Parla in ITALIANO)", "Esatto, sono BonsiAI. Di cosa hai bisogno?"]
IDENTITY_CHANGE = ["Sei BonsiAI da ora in poi!", "Certo farò del mio meglio"]
options = {
    'Email Genitori': {'text': 'Scrivi il testo per una mail XXXX su questo stile.', 'description': 'Descrizione aggiuntiva per Email Genitori'},
    'Email Colleghi': {'text': 'Scrivi il testo per una mail XXXX su questo stile.', 'description': 'Descrizione aggiuntiva per Email Colleghi'},
    'Decreti': {'text': 'Cerca testo dei decreti!', 'description': 'Descrizione aggiuntiva per Decreti'}
}
#persist_directory1 = './DB_Decreti'
#embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
#db = Chroma(persist_directory=persist_directory1, embedding_function=embedding)
#NumeroDocumenti = 10
#query = 'Come funziona la generazione delle PDA'
#result = db.similarity_search(query, k=NumeroDocumenti)
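# A minimal sketch, assuming the Chroma setup above were re-enabled, of how the
# retrieval results could feed gen_augmented_prompt below. The "source" metadata
# key is an assumption about how the vector DB was populated:
# results = db.similarity_search(prompt, k=top_k)
# context = "\n".join(doc.page_content for doc in results)
# links = [doc.metadata.get("source", "") for doc in results]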
st.set_page_config(page_title="BonsiAI", page_icon="🤖")

def gen_augmented_prompt(prompt, top_k):
    # Retrieval is currently disabled (see the commented-out Chroma setup above),
    # so context and links stay empty for now.
    context = ""
    links = []
    generated_prompt = f"""
    A PARTIRE DAL SEGUENTE CONTESTO: {context},
    ----
    RISPONDI ALLA SEGUENTE RICHIESTA: {prompt}
    """
    return generated_prompt, links
def init_state():
    # Initialize all session-state keys on first run.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "temp" not in st.session_state:
        st.session_state.temp = 0.8
    if "history" not in st.session_state:
        st.session_state.history = [SYSTEM_PROMPT]
    if "top_k" not in st.session_state:
        st.session_state.top_k = 5
    if "repetition_penalty" not in st.session_state:
        st.session_state.repetition_penalty = 1
    if "rag_enabled" not in st.session_state:
        st.session_state.rag_enabled = True
    if "chat_bot" not in st.session_state:
        st.session_state.chat_bot = "Mixtral 8x7B v0.1"
def sidebar():
    def retrieval_settings():
        st.markdown("# Impostazioni Azioni")
        st.session_state.selected_option_key = st.selectbox('Azione', list(options.keys()) + ['+ Aggiungi'])
        st.session_state.selected_option = options.get(st.session_state.selected_option_key, {})
        st.session_state.selected_option_text = st.session_state.selected_option.get('text', '')
        st.session_state.option_text = st.text_area("Testo Azione", st.session_state.selected_option_text)
        st.session_state.selected_option_description = st.session_state.selected_option.get('description', '')
        if st.session_state.selected_option_key == 'Decreti':
            st.session_state.rag_enabled = st.toggle("Cerca nel DB Vettoriale", value=True)
            st.session_state.top_k = st.slider(label="Documenti da ricercare", min_value=1, max_value=20, value=4, disabled=not st.session_state.rag_enabled)
        st.markdown("---")

    def model_settings():
        st.markdown("# Impostazioni Modello")
        st.session_state.chat_bot = st.sidebar.radio('Seleziona Modello:', list(CHAT_BOTS.keys()))
        st.session_state.temp = st.slider(label="Creatività", min_value=0.0, max_value=1.0, step=0.1, value=0.9)
        st.session_state.max_tokens = st.slider(label="Lunghezza Output", min_value=64, max_value=2048, step=32, value=512)

    with st.sidebar:
        retrieval_settings()
        model_settings()
        st.markdown("""> **Creato da [Matteo Script] 🔗**""")
def header():
    st.title("BonsiAI")
    with st.expander("Cos'è BonsiAI?"):
        st.info("""BonsiAI Chat è un ChatBot personalizzato basato su un database vettoriale, funziona secondo il principio della Generazione potenziata da Recupero (RAG).
                La sua funzione principale ruota attorno alla gestione di un ampio repository di documenti BonsiAI e fornisce agli utenti risposte in linea con le loro domande.
                Questo approccio garantisce una risposta più precisa sulla base della richiesta degli utenti.""")
def chat_box():
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
def formattaPrompt(prompt, systemRole, systemStyle, instruction):
    # Wraps the user prompt, system role/style and instruction in a JSON-like template.
    input_text = f'''
    {{
        "input": {{
            "role": "system",
            "content": "{systemRole}",
            "style": "{systemStyle}"
        }},
        "messages": [
            {{
                "role": "instructions",
                "content": "{instruction} ({systemStyle})"
            }},
            {{
                "role": "user",
                "content": "{prompt}"
            }}
        ]
    }}
    '''
    return input_text
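# Note: formattaPrompt is defined but never called below. A hypothetical call,
# with made-up role/style values, could look like:
# wrapped = formattaPrompt(prompt, systemRole="Sei BonsiAI", systemStyle="formale",
#                          instruction=st.session_state.option_text)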
def generate_chat_stream(prompt):
    links = []
    if st.session_state.rag_enabled:
        with st.spinner("Ricerca nei documenti...."):
            time.sleep(2)  # brief pause, presumably so the spinner stays visible
            prompt, links = gen_augmented_prompt(prompt=prompt, top_k=st.session_state.top_k)
    with st.spinner("Generazione in corso..."):
        time.sleep(2)
        chat_stream = chat(prompt, st.session_state.history, chat_client=CHAT_BOTS[st.session_state.chat_bot],
                           temperature=st.session_state.temp, max_new_tokens=st.session_state.max_tokens)
    return chat_stream, links
def stream_handler(chat_stream, placeholder):
    start_time = time.time()
    full_response = ''
    # Accumulate streamed tokens, skipping the end-of-sequence marker.
    for chunk in chat_stream:
        if chunk.token.text != '</s>':
            full_response += chunk.token.text
            placeholder.markdown(full_response + "▌")
    placeholder.markdown(full_response)
    elapsed_time = time.time() - start_time
    total_tokens_processed = len(full_response.split())
    tokens_per_second = total_tokens_processed // elapsed_time
    # Rough estimate: ~1.25 tokens per whitespace-separated word. `prompt` here is
    # the module-level variable set by st.chat_input below.
    len_response = (len(prompt.split()) + len(full_response.split())) * 1.25
    col1, col2, col3 = st.columns(3)
    with col1:
        st.write(f"**{elapsed_time:.1f} secondi**")
    with col2:
        st.write(f"**{int(len_response)} tokens generati**")
    with col3:
        st.write(f"**{int(tokens_per_second)} token/secondi**")
    return full_response
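# The word-count heuristic above only approximates token counts. A sketch of an
# exact count, assuming the `transformers` library is available, would be:
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
# exact_tokens = len(tokenizer.encode(full_response))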
def show_source(links):
    with st.expander("Mostra fonti"):
        for link in links:
            st.info(f"{link}")
init_state()
sidebar()
header()
chat_box()

if prompt := st.chat_input("Chatta con BonsiAI..."):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    chat_stream, links = generate_chat_stream(prompt)
    with st.chat_message("assistant"):
        placeholder = st.empty()
        full_response = stream_handler(chat_stream, placeholder)
        if st.session_state.rag_enabled:
            show_source(links)
    st.session_state.history.append([prompt, full_response])
    st.session_state.messages.append({"role": "assistant", "content": full_response})
    st.success('Generazione Completata')