import torch
import os
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
#from interface import GemmaLLMInterface
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import Settings, ServiceContext, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, PromptTemplate, load_index_from_storage, StorageContext
from llama_index.core.node_parser import SentenceSplitter
import spaces
from huggingface_hub import login
from llama_index.core.memory import ChatMemoryBuffer
from typing import Iterator, List, Any
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.llms import ChatMessage, MessageRole, CompletionResponse
from IPython.display import Markdown, display
from langchain_huggingface import HuggingFaceEmbeddings

#from llama_index import LangchainEmbedding, ServiceContext
#from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceInferenceAPI, HuggingFaceLLM
from dotenv import load_dotenv

import logging
import sys


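# Send log output to stdout so it is visible in the Space logs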
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))


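# Authenticate with the Hugging Face Hub using the HUGGINGFACE_TOKEN environment variable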
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)


"""huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto", 
    torch_dtype= torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True)

tokenizer= AutoTokenizer.from_pretrained("google/gemma-2b-it")
model.tokenizer = tokenizer
model.eval()"""


system_prompt="""
You are a Q&A assistant. Your goal is to answer questions as
accurately as possible based on the instructions and context provided.
"""

load_dotenv()


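# Disable tokenizers parallelism to avoid fork-related warnings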
os.environ['TOKENIZERS_PARALLELISM'] = 'false'


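# LLM used to answer queries: Llama-2-7b-chat loaded locally through LlamaIndex's HuggingFaceLLM wrapper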
llm = HuggingFaceLLM(
    context_window=4096,
    max_new_tokens=256,
    generate_kwargs={"temperature": 0.1, "do_sample": True},
    system_prompt=system_prompt,
    tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
    model_name="meta-llama/Llama-2-7b-chat-hf",
    device_map="auto",
    # load the model weights in float16 to reduce memory usage
    model_kwargs={"torch_dtype": torch.float16}
)

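# Embedding model used to embed both the document chunks and incoming queries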
embed_model= HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

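# Register the LLM, embedding model, and context limits globally with LlamaIndex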
Settings.llm = llm
Settings.embed_model = embed_model
#Settings.node_parser = SentenceSplitter(chunk_size=512, chunk_overlap=20, paragraph_separator="\n\n")
Settings.num_output = 512
Settings.context_window = 3900


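# Load every file under ./data and split it into 512-token chunks with 20 tokens of overlap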
documents = SimpleDirectoryReader('./data').load_data()

nodes = SentenceSplitter(chunk_size=512, chunk_overlap=20, paragraph_separator="\n\n").get_nodes_from_documents(documents)
# The vector store index itself is built from these nodes inside handle_query()


# what models will be used by LlamaIndex:
#Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
#Settings.embed_model = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
#Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
#Settings.llm = GemmaLLMInterface()


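# Topic keyword -> source file mapping, used by the (currently commented-out) routing logic below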
documents_paths = {
    'blockchain': 'data/blockchainprova.txt',
    'metaverse': 'data/metaverseprova.txt',
    'payment': 'data/paymentprova.txt'
}


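# Simple module-level session state shared across requests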
global session_state
session_state = {"index": False,
                 "documents_loaded": False, 
                 "document_db": None, 
                 "original_message": None, 
                 "clarification": False}

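# Directory for persisting the index; only the commented-out persistence code below references it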
PERSIST_DIR = "./db"
os.makedirs(PERSIST_DIR, exist_ok=True)


ISTR = "In italiano, chiedi molto brevemente se la domanda si riferisce agli 'Osservatori Blockchain', 'Osservatori Payment' oppure 'Osservatori Metaverse'."

############################---------------------------------

# Get the parser
"""parser = SentenceSplitter.from_defaults(
                chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
            )
def build_index(path: str):
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=[path]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes
    index = VectorStoreIndex(nodes)
    
    #storage_context = StorageContext.from_defaults()
    #index.storage_context.persist(persist_dir=PERSIST_DIR)
    
    return index"""



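# Reserve a GPU for this call on ZeroGPU Spaces; duration is the requested time in seconds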
@spaces.GPU(duration=15)
def handle_query(query_str: str, 
                 chat_history: list[tuple[str, str]]) -> Iterator[str]:
  

    #index= build_index("data/blockchainprova.txt")
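    # Rebuild the vector index from the pre-split nodes on every request (nothing is persisted on the active code path)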
    index = VectorStoreIndex(nodes, show_progress=True)
    
    
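    # Convert the Gradio (user, assistant) history tuples into LlamaIndex ChatMessage objects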
    conversation: List[ChatMessage] = []
    for user, assistant in chat_history:
        conversation.extend([
            ChatMessage(role=MessageRole.USER, content=user),
            ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
        ])
    
    """if not session_state["index"]:
        
        matched_path = None
        words = query_str.lower()
        for key, path in documents_paths.items():
            if key in words:
                matched_path = path
                break
        if matched_path:
            index = build_index(matched_path)
            gr.Info("index costruito con la path sulla base della query")
            session_state["index"] = True
            
        else:  # ASK FOR CLARIFICATION
          
            conversation.append(ChatMessage(role=MessageRole.SYSTEM, content=ISTR))
            
            index = build_index("data/blockchainprova.txt")
            gr.Info("index costruito con richiesta di chiarimento")
            
                  
    else:
        
        index = build_index(matched_path)
        #storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        #index = load_index_from_storage(storage_context)
        gr.Info("index is true")"""    

    try:
        
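        # Keep roughly the last 1500 tokens of conversation as chat memory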
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
        
        """chat_engine = index.as_chat_engine(
        chat_mode="condense_plus_context",
        memory=memory,
        similarity_top_k=3, 
        response_mode= "tree_summarize", #Good for summarization purposes

        context_prompt = (
        "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
        " Quando un utente ti chiede informazioni su di te o sul tuo creatore puoi dire che sei un assistente ricercatore creato dagli Osservatori Digitali e fornire gli argomenti di cui sei esperto."
        " Ecco i documenti rilevanti per il contesto:\n"
        "{context_str}"
        "\nIstruzione: Usa la cronologia della chat, o il contesto sopra, per interagire e aiutare l'utente a rispondere alla sua domanda."
          ),
        verbose=False,
        )"""
        
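        # "context" chat mode: retrieve the top-3 most similar chunks and inject them into the prompt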
        print("chat engine..")
        gr.Info("chat engine..")
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            similarity_top_k=3,
            memory=memory,
            context_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
                " Usa la cronologia della chat, o il contesto fornito, per interagire e aiutare l'utente a rispondere alla sua domanda."
            ),
        )
        
        """retriever = index.as_retriever(similarity_top_k=3)
        # Let's test it out
        relevant_chunks = relevant_chunks = retriever.retrieve(query_str)
        print(f"Found: {len(relevant_chunks)} relevant chunks")
        for idx, chunk in enumerate(relevant_chunks):
            info_message += f"{idx + 1}) {chunk.text[:64]}...\n"
            print(info_message)
            gr.Info(info_message)"""
            
        
        #prompts_dict = chat_engine.get_prompts()
        #display_prompt_dict(prompts_dict)

        
        #chat_engine.reset()
        outputs = []
        #response = query_engine.query(query_str) 
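        # Stream the answer so Gradio can render partial output as it is generated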
        response = chat_engine.stream_chat(query_str, chat_history=conversation)

        sources = []  # Use a list to collect multiple sources if present
        #response = chat_engine.chat(query_str)

        for token in response.response_gen:
            # Strip the "assistant:" prefix when present, otherwise keep the token as-is
            if token.startswith("assistant:"):
                token = token[len("assistant:"):]
            outputs.append(token)
            print(f"Generated token: {token}")
            yield "".join(outputs)
            #yield CompletionResponse(text=''.join(outputs), delta=token)

        """if sources:
            sources_str = ", ".join(sources)
            outputs.append(f"Fonti utilizzate: {sources_str}")
        else:
            outputs.append("Nessuna fonte specifica utilizzata.")

        yield "".join(outputs)"""


    except Exception as e:
        yield f"Error processing query: {str(e)}"