|
import streamlit as st |
|
from chat_client import chat |
|
import time |
|
import os |
|
from dotenv import load_dotenv |
|
from sentence_transformers import SentenceTransformer |
|
load_dotenv() |
|
|
|
# Display name offered in the sidebar -> model identifier handed to chat().
CHAT_BOTS = {
    "Mixtral 8x7B v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
}

# First [request, reply] pair; init_state() seeds the chat history with it.
SYSTEM_PROMPT = [
    "Sei BonsiAI e mi aiuterai nelle mie richieste (Parla in ITALIANO)",
    "Esatto, sono BonsiAI. Di cosa hai bisogno?",
]

# Another [request, reply] pair; not referenced by the code visible here.
IDENTITY_CHANGE = [
    "Sei BonsiAI da ora in poi!",
    "Certo farò del mio meglio",
]
|
|
|
# Page title and icon for the app.
st.set_page_config(
    page_title="BonsiAI",
    page_icon="🤖",
)
|
|
|
def gen_augmented_prompt(prompt, top_k):
    """Wrap the user's request in the context/request prompt template.

    No retrieval happens in this function as written: the context is always
    empty, no sources are produced, and ``top_k`` is accepted but unused.

    Args:
        prompt: The user's request, inserted verbatim into the template.
        top_k: Number of documents to search for (currently ignored).

    Returns:
        A ``(generated_prompt, links)`` tuple: the templated prompt text and
        the list of source links (currently always empty).
    """
    context = ""
    # A list rather than a string: the caller initialises ``links`` as a list
    # and the sources are later iterated one link at a time.
    links = []
    generated_prompt = (
        "\n"
        f"A PARTIRE DAL SEGUENTE CONTESTO: {context},\n"
        "\n"
        "----\n"
        f"RISPONDI ALLA SEGUENTE RICHIESTA: {prompt}\n"
    )
    return generated_prompt, links
|
|
|
def init_state():
    """Give every session-state key the app reads an initial value.

    Keys that are already present in ``st.session_state`` are left untouched.
    """
    # Built on each call so the list defaults are fresh objects.
    defaults = {
        "messages": [],
        "temp": 0.8,
        "history": [SYSTEM_PROMPT],
        "top_k": 5,
        # Spelling kept as-is: the sidebar writes this same key.
        "repetion_penalty": 1,
        "rag_enabled": True,
        "chat_bot": "Mixtral 8x7B v0.1",
        # Read by generate_chat_stream() but previously never seeded here;
        # 512 matches the default of the output-length slider.
        "max_tokens": 512,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
|
|
|
def sidebar():
    """Draw the sidebar controls and copy their values into session state."""

    def retrieval_settings():
        """Controls for the document search."""
        st.markdown("# Impostazioni Documenti")
        rag_on = st.toggle("Cerca nel DB Vettoriale", value=True)
        st.session_state.rag_enabled = rag_on
        # The slider is greyed out while the search toggle is off.
        st.session_state.top_k = st.slider(
            label="Documenti da ricercare",
            min_value=1,
            max_value=20,
            value=4,
            disabled=not rag_on,
        )
        st.markdown("---")

    def model_settings():
        """Controls for model choice and generation parameters."""
        st.markdown("# Impostazioni Modello")
        model_names = list(CHAT_BOTS)
        st.session_state.chat_bot = st.sidebar.radio("Seleziona Modello:", model_names)
        st.session_state.temp = st.slider(
            label="Creatività",
            min_value=0.0,
            max_value=1.0,
            step=0.1,
            value=0.9,
        )
        st.session_state.max_tokens = st.slider(
            label="Lunghezza Output",
            min_value=64,
            max_value=2048,
            step=32,
            value=512,
        )
        st.session_state.repetion_penalty = st.slider(
            label="Penalità Ripetizione",
            min_value=0.0,
            max_value=1.0,
            step=0.1,
            value=1.0,
        )

    with st.sidebar:
        retrieval_settings()
        model_settings()
        st.markdown("> **Creato da [Matteo Script] 🔗**")
|
|
|
def header():
    """Show the page title and a collapsible description of the app."""
    st.title("BonsiAI")
    # One entry per line of the description shown in the info box.
    about_lines = [
        "BonsiAI Chat è un ChatBot personalizzato basato su un database vettoriale, funziona secondo il principio della Generazione potenziata da Recupero (RAG).",
        "La sua funzione principale ruota attorno alla gestione di un ampio repository di documenti BonsiAI e fornisce agli utenti risposte in linea con le loro domande.",
        "Questo approccio garantisce una risposta più precisa sulla base della richiesta degli utenti.",
    ]
    with st.expander("Cos'è BonsiAI?"):
        st.info("\n".join(about_lines))
|
|
|
def chat_box():
    """Re-draw every message stored in the session, each under its own role."""
    for entry in st.session_state.messages:
        bubble = st.chat_message(entry["role"])
        with bubble:
            st.markdown(entry["content"])
|
|
|
def feedback_buttons():
    """Show the satisfied / disappointed buttons until one of them is clicked.

    Visibility is stored in ``st.session_state.feedback_visible``. The
    previous implementation assigned a plain local variable inside the click
    callback, which had no effect on the enclosing function, so the buttons
    could never be hidden. A caller that wants the buttons back (for example
    after a new answer) has to set ``feedback_visible`` to ``True`` again.
    """

    def click_handler():
        # Session state is used so the flag is still set on the next run.
        st.session_state.feedback_visible = False

    # A missing key means no feedback has been given yet.
    if not st.session_state.get("feedback_visible", True):
        return

    col1, col2 = st.columns(2)
    with col1:
        st.button("👍 Soddisfatto", on_click=click_handler, type="primary")
    with col2:
        st.button("👎 Deluso", on_click=click_handler, type="secondary")
|
|
|
def generate_chat_stream(prompt):
    """Start generation for ``prompt`` and return the stream plus its sources.

    When the search toggle is on, the prompt is first passed through
    gen_augmented_prompt(); otherwise it is sent unchanged and the list of
    sources stays empty.

    Args:
        prompt: The user's request.

    Returns:
        A ``(chat_stream, links)`` tuple: whatever chat() returns, and the
        sources reported by gen_augmented_prompt() (``[]`` when search is off).
    """
    state = st.session_state
    links = []

    if state.rag_enabled:
        with st.spinner("Ricerca nei documenti...."):
            prompt, links = gen_augmented_prompt(prompt=prompt, top_k=state.top_k)

    with st.spinner("Generazione in corso..."):
        chat_stream = chat(
            prompt,
            state.history,
            chat_client=CHAT_BOTS[state.chat_bot],
            temperature=state.temp,
            max_new_tokens=state.max_tokens,
        )

    return chat_stream, links
|
|
|
def stream_handler(chat_stream, placeholder, prompt=None):
    """Render a streamed reply into ``placeholder`` and show simple statistics.

    Args:
        chat_stream: Iterable of chunks; each chunk's ``token.text`` is read.
        placeholder: Element whose ``markdown`` is rewritten as text arrives.
        prompt: Request text included in the size estimate. When omitted, the
            module-level ``prompt`` is used if it exists, which is the name
            the body used to read implicitly; otherwise an empty string.

    Returns:
        The concatenated reply text, without the ``</s>`` markers.
    """
    if prompt is None:
        # Backward-compatible fallback to the script-level variable.
        prompt = globals().get("prompt") or ""

    # perf_counter is meant for measuring intervals.
    start_time = time.perf_counter()
    full_response = ""

    for chunk in chat_stream:
        piece = chunk.token.text
        if piece != "</s>":
            full_response += piece
            # Trailing block character acts as a cursor while streaming.
            placeholder.markdown(full_response + "▌")
    placeholder.markdown(full_response)

    elapsed_time = time.perf_counter() - start_time
    response_words = len(full_response.split())
    # An empty or instantaneous stream can measure zero elapsed time; avoid
    # the ZeroDivisionError that floor division would raise.
    if elapsed_time > 0:
        tokens_per_second = response_words // elapsed_time
    else:
        tokens_per_second = 0.0
    # Estimate from whitespace-separated words of request and reply, x 1.25.
    len_response = (len(prompt.split()) + response_words) * 1.25

    # Three columns keep the original layout; the third one stays empty.
    col1, col2, _ = st.columns(3)
    with col1:
        st.write(f"**{tokens_per_second} token/secondi**")
    with col2:
        st.write(f"**{int(len_response)} tokens generati**")

    return full_response
|
|
|
def show_source(links):
    """List each source in its own info box inside a collapsible section.

    Args:
        links: Iterable of sources; each item is rendered with an f-string.
    """
    sources_panel = st.expander("Mostra fonti")
    with sources_panel:
        for link in links:
            st.info(f"{link}")
|
|
|
# Page skeleton: state defaults first, then sidebar, header and past messages.
init_state()
sidebar()
header()
chat_box()

# ``prompt`` stays a module-level name: stream_handler() reads it.
prompt = st.chat_input("Chatta con BonsiAI...")
if prompt:
    # Show the new request and record it for later redraws.
    user_bubble = st.chat_message("user")
    user_bubble.markdown(prompt)
    user_entry = {"role": "user", "content": prompt}
    st.session_state.messages.append(user_entry)

    chat_stream, links = generate_chat_stream(prompt)

    with st.chat_message("assistant"):
        placeholder = st.empty()
        full_response = stream_handler(chat_stream, placeholder)
        if st.session_state.rag_enabled:
            show_source(links)

    # Record the exchange: ``history`` is passed to chat(), ``messages`` is
    # what chat_box() redraws.
    st.session_state.history.append([prompt, full_response])
    assistant_entry = {"role": "assistant", "content": full_response}
    st.session_state.messages.append(assistant_entry)