Futuresony committed
Commit a493165 · verified · 1 Parent(s): 6adc50d

Upload app.py

Files changed (1)
  1. app.py +556 -0
app.py ADDED
@@ -0,0 +1,556 @@
import gradio as gr
import os
import PyPDF2
import logging
import torch
import threading
import time
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers import logging as hf_logging
import spaces
from llama_index.core import (
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
    Document as LlamaDocument,
)
from llama_index.core import Settings
from llama_index.core.node_parser import (
    HierarchicalNodeParser,
    get_leaf_nodes,
    get_root_nodes,
)
from llama_index.core.retrievers import AutoMergingRetriever
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
hf_logging.set_verbosity_error()

MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found in environment variables")

# --- UI Settings ---
TITLE = "<h1 style='text-align:center; margin-bottom: 20px;'>Local Thinking RAG: Llama 3.1 8B</h1>"
DISCORD_BADGE = """<p style="text-align:center; margin-top: -10px;">
<a href="https://discord.gg/openfreeai" target="_blank">
<img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="badge">
</a>
</p>
"""

CSS = """
.upload-section {
    max-width: 400px;
    margin: 0 auto;
    padding: 10px;
    border: 2px dashed #ccc;
    border-radius: 10px;
}
.upload-button {
    background: #34c759 !important;
    color: white !important;
    border-radius: 25px !important;
}
.chatbot-container {
    margin-top: 20px;
}
.status-output {
    margin-top: 10px;
    font-size: 14px;
}
.processing-info {
    margin-top: 5px;
    font-size: 12px;
    color: #666;
}
.info-container {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
}
.file-list {
    margin-top: 0;
    max-height: 200px;
    overflow-y: auto;
    padding: 5px;
    border: 1px solid #eee;
    border-radius: 5px;
}
.stats-box {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
    font-size: 12px;
}
.submit-btn {
    background: #1a73e8 !important;
    color: white !important;
    border-radius: 25px !important;
    margin-left: 10px;
    padding: 5px 10px;
    font-size: 16px;
}
.input-row {
    display: flex;
    align-items: center;
}
@media (min-width: 768px) {
    .main-container {
        display: flex;
        justify-content: space-between;
        gap: 20px;
    }
    .upload-section {
        flex: 1;
        max-width: 300px;
    }
    .chatbot-container {
        flex: 2;
        margin-top: 0;
    }
}
"""

global_model = None
global_tokenizer = None
global_file_info = {}

def initialize_model_and_tokenizer():
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        logger.info("Initializing model and tokenizer...")
        global_tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
        global_model = AutoModelForCausalLM.from_pretrained(
            MODEL,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
            torch_dtype=torch.float16
        )
        logger.info("Model and tokenizer initialized successfully")

def get_llm(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        initialize_model_and_tokenizer()

    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=global_tokenizer,
        model=global_model,
        generate_kwargs={
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p
        }
    )

def extract_text_from_document(file):
    """Return (text, word_count, error) for an uploaded .txt or .pdf file."""
    file_name = file.name
    file_extension = os.path.splitext(file_name)[1].lower()

    if file_extension == '.txt':
        text = file.read().decode('utf-8')
        return text, len(text.split()), None
    elif file_extension == '.pdf':
        pdf_reader = PyPDF2.PdfReader(file)
        text = "\n\n".join(page.extract_text() for page in pdf_reader.pages)
        return text, len(text.split()), None
    else:
        return None, 0, ValueError(f"Unsupported file format: {file_extension}")

@spaces.GPU()
def create_or_update_index(files, request: gr.Request):
    global global_file_info

    if not files:
        return "Please provide files.", ""

    start_time = time.time()
    user_id = request.session_hash
    save_dir = f"./{user_id}_index"
    # Initialize LlamaIndex modules
    llm = get_llm()
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
    Settings.llm = llm
    Settings.embed_model = embed_model
    file_stats = []
    new_documents = []

    for file in tqdm(files, desc="Processing files"):
        file_basename = os.path.basename(file.name)
        text, word_count, error = extract_text_from_document(file)
        if error:
            logger.error(f"Error processing file {file_basename}: {str(error)}")
            file_stats.append({
                "name": file_basename,
                "words": 0,
                "status": f"error: {str(error)}"
            })
            continue

        doc = LlamaDocument(
            text=text,
            metadata={
                "file_name": file_basename,
                "word_count": word_count,
                "source": "user_upload"
            }
        )
        new_documents.append(doc)

        file_stats.append({
            "name": file_basename,
            "words": word_count,
            "status": "processed"
        })

        global_file_info[file_basename] = {
            "word_count": word_count,
            "processed_at": time.time()
        }

    # Split documents into a hierarchy of chunks (2048 -> 512 -> 128)
    node_parser = HierarchicalNodeParser.from_defaults(
        chunk_sizes=[2048, 512, 128],
        chunk_overlap=20
    )
    logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes")
    new_nodes = node_parser.get_nodes_from_documents(new_documents)
    new_leaf_nodes = get_leaf_nodes(new_nodes)
    new_root_nodes = get_root_nodes(new_nodes)
    logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")

    if os.path.exists(save_dir):
        logger.info(f"Loading existing index from {save_dir}")
        storage_context = StorageContext.from_defaults(persist_dir=save_dir)
        index = load_index_from_storage(storage_context, settings=Settings)
        docstore = storage_context.docstore

        docstore.add_documents(new_nodes)
        for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"):
            index.insert_nodes([node])

        total_docs = len(docstore.docs)
        logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files")
    else:
        logger.info("Creating new index")
        docstore = SimpleDocumentStore()
        storage_context = StorageContext.from_defaults(docstore=docstore)
        docstore.add_documents(new_nodes)

        index = VectorStoreIndex(
            new_leaf_nodes,
            storage_context=storage_context,
            settings=Settings
        )
        total_docs = len(new_documents)
        logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files")

    index.storage_context.persist(persist_dir=save_dir)

    # Build the HTML summary shown after processing files
    file_list_html = "<div class='file-list'>"
    for stat in file_stats:
        status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336"
        file_list_html += f"<div><span style='color:{status_color}'>●</span> {stat['name']} - {stat['words']} words</div>"
    file_list_html += "</div>"
    processing_time = time.time() - start_time
    stats_output = "<div class='stats-box'>"
    stats_output += f"✓ Processed {len(files)} files in {processing_time:.2f} seconds<br>"
    stats_output += f"✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)<br>"
    stats_output += f"✓ Total documents in index: {total_docs}<br>"
    stats_output += f"✓ Index saved to: {save_dir}<br>"
    stats_output += "</div>"
    output_container = "<div class='info-container'>"
    output_container += file_list_html
    output_container += stats_output
    output_container += "</div>"
    return f"Successfully indexed {len(files)} files.", output_container

@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float,
    retriever_k: int,
    merge_threshold: float,
    request: gr.Request
):
    if not request:
        yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}]
        return
    user_id = request.session_hash
    index_dir = f"./{user_id}_index"
    if not os.path.exists(index_dir):
        yield history + [{"role": "assistant", "content": "Please upload documents first."}]
        return

    # Fall back to defaults if the UI passes malformed values
    max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 1024
    temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.9
    top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95
    top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
    penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2
    retriever_k = int(retriever_k) if isinstance(retriever_k, (int, float)) else 15
    merge_threshold = float(merge_threshold) if isinstance(merge_threshold, (int, float)) else 0.5

    # Rebuild the retrieval stack over this session's persisted index
    llm = get_llm(temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k)
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
    Settings.llm = llm
    Settings.embed_model = embed_model
    storage_context = StorageContext.from_defaults(persist_dir=index_dir)
    index = load_index_from_storage(storage_context, settings=Settings)
    base_retriever = index.as_retriever(similarity_top_k=retriever_k)
    auto_merging_retriever = AutoMergingRetriever(
        base_retriever,
        storage_context=storage_context,
        simple_ratio_thresh=merge_threshold,
        verbose=True
    )

    logger.info(f"Query: {message}")
    retrieval_start = time.time()
    base_nodes = base_retriever.retrieve(message)
    logger.info(f"Retrieved {len(base_nodes)} base nodes in {time.time() - retrieval_start:.2f}s")
    base_file_sources = {}
    for node in base_nodes:
        if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
            file_name = node.node.metadata['file_name']
            if file_name not in base_file_sources:
                base_file_sources[file_name] = 0
            base_file_sources[file_name] += 1
    logger.info(f"Base retrieval file distribution: {base_file_sources}")

    merging_start = time.time()
    merged_nodes = auto_merging_retriever.retrieve(message)
    logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - merging_start:.2f}s")
    merged_file_sources = {}
    for node in merged_nodes:
        if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
            file_name = node.node.metadata['file_name']
            if file_name not in merged_file_sources:
                merged_file_sources[file_name] = 0
            merged_file_sources[file_name] += 1
    logger.info(f"Merged retrieval file distribution: {merged_file_sources}")

    # Assemble the prompt: system prompt + retrieved context + chat history
    context = "\n\n".join([n.node.text for n in merged_nodes])
    source_info = ""
    if merged_file_sources:
        source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys())
    formatted_system_prompt = f"{system_prompt}\n\nDocument Context:\n{context}{source_info}"
    messages = [{"role": "system", "content": formatted_system_prompt}]
    for entry in history:
        messages.append(entry)
    messages.append({"role": "user", "content": message})
    prompt = global_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Stream generation from a background thread; the stop event lets us
    # abort cleanly if the client disconnects mid-stream.
    stop_event = threading.Event()

    class StopOnEvent(StoppingCriteria):
        def __init__(self, stop_event):
            super().__init__()
            self.stop_event = stop_event

        def __call__(self, input_ids, scores, **kwargs):
            return self.stop_event.is_set()

    stopping_criteria = StoppingCriteriaList([StopOnEvent(stop_event)])
    streamer = TextIteratorStreamer(
        global_tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=penalty,
        do_sample=True,
        stopping_criteria=stopping_criteria
    )
    thread = threading.Thread(target=global_model.generate, kwargs=generation_kwargs)
    thread.start()

    updated_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": ""}
    ]
    yield updated_history
    partial_response = ""
    try:
        for new_text in streamer:
            partial_response += new_text
            updated_history[-1]["content"] = partial_response
            yield updated_history
        yield updated_history
    except GeneratorExit:
        stop_event.set()
        thread.join()
        raise

def create_demo():
    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        # Title
        gr.HTML(TITLE)
        # Discord badge immediately under the title
        gr.HTML(DISCORD_BADGE)

        with gr.Row(elem_classes="main-container"):
            with gr.Column(elem_classes="upload-section"):
                file_upload = gr.File(
                    file_count="multiple",
                    label="Drag & Drop PDF/TXT Files Here",
                    file_types=[".pdf", ".txt"],
                    elem_id="file-upload"
                )
                upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
                status_output = gr.Textbox(
                    label="Status",
                    placeholder="Upload files to start...",
                    interactive=False
                )
                file_info_output = gr.HTML(
                    label="File Information",
                    elem_classes="processing-info"
                )
                upload_button.click(
                    fn=create_or_update_index,
                    inputs=[file_upload],
                    outputs=[status_output, file_info_output]
                )

            with gr.Column(elem_classes="chatbot-container"):
                chatbot = gr.Chatbot(
                    height=500,
                    placeholder="Chat with your documents...",
                    show_label=False,
                    type="messages"
                )
                with gr.Row(elem_classes="input-row"):
                    message_input = gr.Textbox(
                        placeholder="Type your question here...",
                        show_label=False,
                        container=False,
                        lines=1,
                        scale=8
                    )
                    submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)

                with gr.Accordion("Advanced Settings", open=False):
                    system_prompt = gr.Textbox(
                        value="You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. As a knowledgeable assistant, provide detailed answers using the relevant information from all uploaded documents.",
                        label="System Prompt",
                        lines=3
                    )

                    with gr.Tab("Generation Parameters"):
                        temperature = gr.Slider(
                            minimum=0,
                            maximum=1,
                            step=0.1,
                            value=0.9,
                            label="Temperature"
                        )
                        max_new_tokens = gr.Slider(
                            minimum=128,
                            maximum=8192,
                            step=64,
                            value=1024,
                            label="Max New Tokens",
                        )
                        top_p = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            value=0.95,
                            label="Top P"
                        )
                        top_k = gr.Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top K"
                        )
                        penalty = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            step=0.1,
                            value=1.2,
                            label="Repetition Penalty"
                        )

                    with gr.Tab("Retrieval Parameters"):
                        retriever_k = gr.Slider(
                            minimum=5,
                            maximum=30,
                            step=1,
                            value=15,
                            label="Initial Retrieval Size (Top K)"
                        )
                        merge_threshold = gr.Slider(
                            minimum=0.1,
                            maximum=0.9,
                            step=0.1,
                            value=0.5,
                            label="Merge Threshold (lower = more merging)"
                        )

        # Wire both the send button and the textbox's Enter key to the chat handler
        submit_button.click(
            fn=stream_chat,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold
            ],
            outputs=chatbot
        )

        message_input.submit(
            fn=stream_chat,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold
            ],
            outputs=chatbot
        )

    return demo


if __name__ == "__main__":
    initialize_model_and_tokenizer()
    demo = create_demo()
    demo.launch()