File size: 9,189 Bytes
d60dac5
9d25320
cebfd3c
 
 
 
95868f7
146cded
 
 
cebfd3c
b057e78
cebfd3c
0bf87c6
 
8c0bd0e
 
865ad29
95868f7
 
b057e78
3bc1a3c
95868f7
 
b057e78
3bc1a3c
95868f7
 
3bc1a3c
 
2a5c937
cebfd3c
e5e8dc0
e506679
9d25320
 
 
d60dac5
9d25320
 
99ff44d
9d25320
58f00c1
2006c2b
9d25320
 
2006c2b
9d25320
 
cebfd3c
9d25320
 
e506679
a6058bc
 
 
 
d388e51
 
 
 
a6058bc
64dbb55
9d25320
55590dd
a6058bc
 
95868f7
b13c015
95868f7
b13c015
95868f7
cf69470
b13c015
a6058bc
e6f9b68
 
8055571
8b18fd0
9d25320
55590dd
8055571
a11501e
3bc1a3c
9d25320
 
 
 
3bc1a3c
2006c2b
9d25320
e5e8dc0
 
 
 
a11501e
9d25320
 
 
 
 
 
db795fb
95868f7
 
db795fb
 
 
c2b258d
db795fb
 
 
 
42f0d1d
db795fb
 
 
95868f7
db795fb
 
95868f7
 
 
db795fb
c2b258d
 
 
e5e8dc0
a62a724
 
 
 
26db031
a62a724
e6f9b68
a29dcc0
 
 
a62a724
e6f9b68
a62a724
c2b258d
2d9e230
c2b258d
 
 
 
a62a724
c2b258d
9d25320
 
c2b258d
9d25320
e5e8dc0
95868f7
a62a724
 
c2b258d
a11501e
95868f7
9d25320
a11501e
9d25320
 
 
 
 
 
 
 
 
 
 
d60dac5
9d25320
a11501e
a62a724
 
336ff2a
d60dac5
9d25320
 
 
 
d60dac5
e5e8dc0
d60dac5
 
6c6516f
9d25320
cebfd3c
 
6c6516f
2e2d510
9d25320
cebfd3c
9d25320
2dee873
e24f6eb
a6058bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import streamlit as st
from chat_client import chat
import time
import os
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
import requests
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings 

load_dotenv()
URL_APP_SCRIPT = os.getenv('URL_APP_SCRIPT')

CHAT_BOTS = {"Mixtral 8x7B v0.1" :"mistralai/Mixtral-8x7B-Instruct-v0.1",
             "Mixtral 7B v0.2" :"mistralai/Mistral-7B-Instruct-v0.2",
            "OpenAi GTP4" : "mistralai/Mistral-7B-Instruct-v0.1", 
            "Antrophic CLAUDE" : "mistralai/Mistral-7B-Instruct-v0.1"}
options_old = {
    'Email Genitori': {'systemRole': 'Tu sei un esperto scrittore di email. Attieniti allo stile che ti ho fornito nelle instruction e inserici il contenuto richiesto. Genera il testo di una mail a partire da questo contenuto, con lo stile ricevuto in precedenza: ', 
                       'systemStyle': 'Utilizza lo stile fornito come esempio e parla in ITALIANO e firmati sempre come il Signor Preside', 
                       'instruction': URL_APP_SCRIPT + '1IxE0ic0hsWrxQod2rfh4hnKNqMC-lGT4', 
                       'RAG': False},
    'Email Colleghi': {'systemRole': 'Tu sei un esperto scrittore di email. Attieniti allo stile che ti ho fornito nelle instruction e inserici il contenuto richiesto. Genera il testo di una mail a partire da questo contenuto, con lo stile ricevuto in precedenza: ', 
                       'systemStyle': 'Utilizza lo stile fornito come esempio e parla in ITALIANO e firmati sempre come il vostro collega Preside', 
                       'instruction': URL_APP_SCRIPT + '1tEMxG0zJmmyh5PlAofKDkhbi1QGMOwPH', 
                       'RAG': False},
    'Decreti': {'systemRole': 'Tu sei il mio assistente per la ricerca documentale! Ti ho fornito una lista di documenti, devi cercare quello che ti chiedo nei documenti', 
                'systemStyle': 'Sii molto formale, sintetico e parla in ITALIANO', 
                'instruction': '',
                'RAG': True}
}

st.set_page_config(page_title="ZucchettiAI", page_icon="🤖")

def init_state() :
    if "messages" not in st.session_state:
        st.session_state.messages = []

    if "temp" not in st.session_state:
        st.session_state.temp = 0.8

    if "history" not in st.session_state:
        st.session_state.history = []

    if "top_k" not in st.session_state:
        st.session_state.top_k = 5

    if "repetion_penalty" not in st.session_state :
        st.session_state.repetion_penalty = 1

    if "chat_bot" not in st.session_state :
        st.session_state.chat_bot = "Mixtral 8x7B v0.1"

    if 'loaded_data' not in st.session_state:
        st.session_state.loaded_data = False
        
    if not st.session_state.loaded_data:
        with st.spinner('Caricamento in corso...'):
            options = requests.get(URL_APP_SCRIPT).json()
            st.session_state.options = options
            st.session_state.loaded_data = True

def sidebar():
    def retrieval_settings() :
        st.markdown("# Impostazioni Prompt")
        st.session_state.selected_option_key = st.selectbox('Azione', list(st.session_state.options.keys()) + ['Personalizzata'])
        st.session_state.selected_option = st.session_state.options.get(st.session_state.selected_option_key, {})
        st.session_state.systemRole = st.session_state.selected_option.get('systemRole', '')
        st.session_state.systemRole = st.text_area("Descrizione", st.session_state.systemRole, help='Ruolo del chatbot e descrizione dell\'azione che deve svolgere')
        st.session_state.systemStyle = st.session_state.selected_option.get('systemStyle', '')
        st.session_state.systemStyle = st.text_area("Stile", st.session_state.systemStyle, help='Descrizione dello stile utilizzato per generare il testo')
        st.session_state.instruction = st.session_state.selected_option.get('instruction', '')
        #st.session_state.instruction = st.text_area("Istruzioni", st.session_state.instruction, help='Testo di riferimento sul quale il modello si basa per generare il testo')
        
        st.session_state.rag_enabled = st.session_state.selected_option.get('tipo', '')=='RAG'
        if st.session_state.rag_enabled: 
            st.session_state.top_k = st.slider(label="Documenti da considerare", min_value=1, max_value=20, value=4, disabled=not st.session_state.rag_enabled)
        st.markdown("---")
    
    def model_settings() :
        st.markdown("# Impostazioni Modello")
        st.session_state.chat_bot = st.sidebar.radio('Seleziona Modello:', [key for key, value in CHAT_BOTS.items() ])
        st.session_state.temp = st.slider(label="Creatività", min_value=0.0, max_value=1.0, step=0.1, value=0.9)
        st.session_state.max_tokens = st.slider(label="Lunghezza Output", min_value = 64, max_value=2048, step= 32, value=1024)

    with st.sidebar:
        retrieval_settings()
        model_settings()
        st.markdown("""> **Creato da [Matteo Bergamelli] 🔗**""")

def header() :
    st.title("ZucchettiAI")
    with st.expander("Cos'è ZucchettiAI?"):
        st.info("""ZucchettiAI Chat è un ChatBot personalizzato basato su un database vettoriale, funziona secondo il principio della Generazione potenziata da Recupero (RAG). 
                   La sua funzione principale ruota attorno alla gestione di un ampio repository di manuali di Zucchetti e fornisce agli utenti risposte in linea con le loro domande. 
                   Questo approccio garantisce una risposta più precisa sulla base della richiesta degli utenti.""")

def chat_box() :
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

def formattaPrompt(prompt, systemRole, systemStyle, instruction):
    input_text = f'''
    {{
      "input": {{
          "role": "system",
          "content": "{systemRole}", 
          "style": "{systemStyle}"
      }},
      "messages": [
          {{
              "role": "instructions",
              "content": "{instruction} ({systemStyle})"
          }},
          {{
              "role": "user",
              "content": "{prompt}"
          }}
      ]
    }}
    '''
    return input_text

def gen_augmented_prompt(prompt, top_k) :   
    links = ""
    embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    db = Chroma(persist_directory='./DB_Manuali', embedding_function=embedding) 
    docs = db.similarity_search(prompt, k=top_k)
    
    links = []
    context = ''
    NomeCartellaOriginariaDB = 'Documenti\\'
    for doc in docs:
        testo = doc.page_content.replace('\n\n', '\n')
        testo = testo.replace('***', '')
        testo = testo.replace('DOMANDA: ', '')
        testo = testo.replace('RISPOSTA:', '')
        context += testo + '\n\n\n'
        reference = doc.metadata["source"].replace(NomeCartellaOriginariaDB, '') + ' (Pag. ' + str(doc.metadata["page"]+1) + ')'
        links.append((reference, testo))
    generated_prompt = f"""
    A PARTIRE DAL SEGUENTE CONTESTO: {docs},

    ----
    RISPONDI ALLA SEGUENTE RICHIESTA: {prompt}
    """
    return context, links

def generate_chat_stream(prompt) :
    links = []
    prompt_originale = prompt
    if st.session_state.rag_enabled :
        with st.spinner("Ricerca nei Manuali...."):
            time.sleep(1)
            st.session_state.instruction, links = gen_augmented_prompt(prompt=prompt_originale, top_k=st.session_state.top_k)        
    prompt = formattaPrompt(prompt, st.session_state.systemRole, st.session_state.systemStyle, st.session_state.instruction) 
    print(prompt)
    with st.spinner("Generazione in corso...") :
        time.sleep(1)
        chat_stream = chat(prompt, st.session_state.history,chat_client=CHAT_BOTS[st.session_state.chat_bot] ,
                       temperature=st.session_state.temp, max_new_tokens=st.session_state.max_tokens)    
    return chat_stream, links

def stream_handler(chat_stream, placeholder) :
    start_time = time.time()
    full_response = ''
    for chunk in chat_stream :
        if chunk.token.text!='</s>' :
            full_response += chunk.token.text
            placeholder.markdown(full_response + "▌")
    placeholder.markdown(full_response)
    return full_response

def show_source(links) :
    with st.expander("Mostra fonti") :
        for link in links:
            reference, testo = link
            st.info('##### ' + reference.replace('_', ' ') + '\n\n'+ testo)

init_state()
sidebar()
header()
chat_box()

if prompt := st.chat_input("Chatta con ZucchettiAI..."):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    chat_stream, links = generate_chat_stream(prompt)

    
    with st.chat_message("assistant"):
        placeholder = st.empty()
        full_response = stream_handler(chat_stream, placeholder)
        if st.session_state.rag_enabled :
            show_source(links)

    st.session_state.messages.append({"role": "assistant", "content": full_response})
    st.success('Generazione Completata')