raksama19 committed
Commit 42a58bc · verified · 1 Parent(s): 9f7d9e3

Upload 2 files

Files changed (2):
  1. gradio_gemma.py +351 -0
  2. requirements_gemma.txt +17 -0
gradio_gemma.py ADDED
@@ -0,0 +1,351 @@
+ """
+ Standalone RAG Chatbot with Gemma 3n
+ A simple PDF chatbot using Retrieval-Augmented Generation
+ """
+
+ import gradio as gr
+ import torch
+ import os
+ import io
+ import numpy as np
+ from PIL import Image
+ import pymupdf  # PyMuPDF for PDF processing
+
+ # RAG dependencies
+ try:
+     from sentence_transformers import SentenceTransformer
+     from sklearn.metrics.pairwise import cosine_similarity
+     from transformers import Gemma3nForConditionalGeneration, AutoProcessor
+     RAG_AVAILABLE = True
+ except ImportError as e:
+     print(f"Missing dependencies: {e}")
+     RAG_AVAILABLE = False
+
+ # Global variables
+ embedding_model = None
+ chatbot_model = None
+ chatbot_processor = None
+ document_chunks = []
+ document_embeddings = None
+ processed_text = ""
+
+ def initialize_models():
+     """Initialize embedding model and chatbot model"""
+     global embedding_model, chatbot_model, chatbot_processor
+
+     if not RAG_AVAILABLE:
+         return False, "Required dependencies not installed"
+
+     try:
+         # Initialize embedding model (CPU to save GPU memory)
+         print("Loading embedding model...")
+         embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
+         print("✅ Embedding model loaded successfully")
+
+         # Initialize chatbot model
+         hf_token = os.getenv('HF_TOKEN')
+         if not hf_token:
+             return False, "HF_TOKEN not found in environment"
+
+         print("Loading Gemma 3n model...")
+         chatbot_model = Gemma3nForConditionalGeneration.from_pretrained(
+             "google/gemma-3n-e4b-it",
+             device_map="auto",
+             torch_dtype=torch.bfloat16,
+             token=hf_token
+         ).eval()
+
+         chatbot_processor = AutoProcessor.from_pretrained(
+             "google/gemma-3n-e4b-it",
+             token=hf_token
+         )
+
+         print("✅ Gemma 3n model loaded successfully")
+         return True, "All models loaded successfully"
+
+     except Exception as e:
+         print(f"Error loading models: {e}")
+         return False, f"Error: {str(e)}"
+
+ def extract_text_from_pdf(pdf_file):
+     """Extract text from uploaded PDF file"""
+     try:
+         if isinstance(pdf_file, str):
+             # File path
+             pdf_document = pymupdf.open(pdf_file)
+         else:
+             # File object
+             pdf_bytes = pdf_file.read()
+             pdf_document = pymupdf.open(stream=pdf_bytes, filetype="pdf")
+
+         text_content = ""
+         for page_num in range(len(pdf_document)):
+             page = pdf_document[page_num]
+             text_content += f"\n--- Page {page_num + 1} ---\n"
+             text_content += page.get_text()
+
+         pdf_document.close()
+         return text_content
+
+     except Exception as e:
+         raise Exception(f"Error extracting text from PDF: {str(e)}")
+
+ def chunk_text(text, chunk_size=500, overlap=50):
+     """Split text into overlapping chunks"""
+     words = text.split()
+     chunks = []
+
+     for i in range(0, len(words), chunk_size - overlap):
+         chunk = ' '.join(words[i:i + chunk_size])
+         if chunk.strip():
+             chunks.append(chunk)
+
+     return chunks
+
+ def create_embeddings(chunks):
+     """Create embeddings for text chunks"""
+     if embedding_model is None:
+         return None
+
+     try:
+         print(f"Creating embeddings for {len(chunks)} chunks...")
+         embeddings = embedding_model.encode(chunks, show_progress_bar=True)
+         return np.array(embeddings)
+     except Exception as e:
+         print(f"Error creating embeddings: {e}")
+         return None
+
+ def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
+     """Retrieve most relevant chunks for a question"""
+     if embedding_model is None or embeddings is None:
+         return chunks[:top_k]
+
+     try:
+         question_embedding = embedding_model.encode([question])
+         similarities = cosine_similarity(question_embedding, embeddings)[0]
+
+         # Get top-k most similar chunks
+         top_indices = np.argsort(similarities)[-top_k:][::-1]
+         relevant_chunks = [chunks[i] for i in top_indices]
+
+         return relevant_chunks
+     except Exception as e:
+         print(f"Error retrieving chunks: {e}")
+         return chunks[:top_k]
+
+ def process_pdf(pdf_file, progress=gr.Progress()):
+     """Process uploaded PDF and prepare for Q&A"""
+     global document_chunks, document_embeddings, processed_text
+
+     if pdf_file is None:
+         return "❌ Please upload a PDF file first"
+
+     try:
+         # Extract text from PDF
+         progress(0.2, desc="Extracting text from PDF...")
+         text = extract_text_from_pdf(pdf_file)
+
+         if not text.strip():
+             return "❌ No text found in PDF"
+
+         processed_text = text
+
+         # Create chunks
+         progress(0.4, desc="Creating text chunks...")
+         document_chunks = chunk_text(text)
+
+         # Create embeddings
+         progress(0.6, desc="Creating embeddings...")
+         document_embeddings = create_embeddings(document_chunks)
+
+         if document_embeddings is None:
+             return "❌ Failed to create embeddings"
+
+         progress(1.0, desc="PDF processed successfully!")
+         return f"✅ PDF processed successfully! Created {len(document_chunks)} chunks. You can now ask questions about the document."
+
+     except Exception as e:
+         return f"❌ Error processing PDF: {str(e)}"
+
+ def chat_with_pdf(message, history):
+     """Generate response using RAG"""
+     if not message.strip():
+         return history
+
+     if not processed_text:
+         return history + [[message, "❌ Please upload and process a PDF first"]]
+
+     if chatbot_model is None or chatbot_processor is None:
+         return history + [[message, "❌ Chatbot model not loaded"]]
+
+     try:
+         # Retrieve relevant chunks
+         if document_chunks and document_embeddings is not None:
+             relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings)
+             context = "\n\n".join(relevant_chunks)
+         else:
+             # Fallback to truncated text
+             context = processed_text[:2000] + "..." if len(processed_text) > 2000 else processed_text
+
+         # Create messages for Gemma
+         messages = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": "You are a helpful assistant that answers questions about documents. Use the provided context to answer questions accurately and concisely."}]
+             },
+             {
+                 "role": "user",
+                 "content": [{"type": "text", "text": f"Context:\n{context}\n\nQuestion: {message}"}]
+             }
+         ]
+
+         # Process with Gemma
+         inputs = chatbot_processor.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=True,
+             return_dict=True,
+             return_tensors="pt"
+         ).to(chatbot_model.device)
+
+         input_len = inputs["input_ids"].shape[-1]
+
+         with torch.inference_mode():
+             generation = chatbot_model.generate(
+                 **inputs,
+                 max_new_tokens=300,
+                 do_sample=False,  # greedy decoding; a temperature setting has no effect when sampling is disabled
+                 pad_token_id=chatbot_processor.tokenizer.pad_token_id,
+                 use_cache=True
+             )
+             generation = generation[0][input_len:]
+
+         response = chatbot_processor.decode(generation, skip_special_tokens=True)
+
+         return history + [[message, response]]
+
+     except Exception as e:
+         error_msg = f"❌ Error generating response: {str(e)}"
+         return history + [[message, error_msg]]
+
+ def clear_chat():
+     """Clear chat history and processed data"""
+     global document_chunks, document_embeddings, processed_text
+     document_chunks = []
+     document_embeddings = None
+     processed_text = ""
+
+     # Clear GPU cache
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     return [], "Ready to process a new PDF"
+
+ # Initialize models on startup
+ model_status = "⏳ Initializing models..."
+ if RAG_AVAILABLE:
+     success, message = initialize_models()
+     model_status = "✅ Models ready" if success else f"❌ {message}"
+ else:
+     model_status = "❌ Dependencies not installed"
+
+ # Create Gradio interface
+ with gr.Blocks(
+     title="RAG Chatbot with Gemma 3n",
+     theme=gr.themes.Soft(),
+     css="""
+     .main-container { max-width: 1200px; margin: 0 auto; }
+     .status-box { padding: 15px; margin: 10px 0; border-radius: 8px; }
+     .chat-container { height: 500px; }
+     """
+ ) as demo:
+
+     gr.Markdown("# 🤖 RAG Chatbot with Gemma 3n")
+     gr.Markdown("### Upload a PDF and ask questions about it using Retrieval-Augmented Generation")
+
+     with gr.Row():
+         gr.Markdown(f"**Status:** {model_status}")
+
+     with gr.Row():
+         # Left column - PDF upload
+         with gr.Column(scale=1):
+             gr.Markdown("## 📄 Upload PDF")
+
+             pdf_input = gr.File(
+                 file_types=[".pdf"],
+                 label="Upload PDF Document"
+             )
+
+             process_btn = gr.Button(
+                 "🔄 Process PDF",
+                 variant="primary",
+                 size="lg"
+             )
+
+             status_output = gr.Markdown(
+                 "Upload a PDF to get started",
+                 elem_classes="status-box"
+             )
+
+             clear_btn = gr.Button(
+                 "🗑️ Clear All",
+                 variant="secondary"
+             )
+
+         # Right column - Chat
+         with gr.Column(scale=2):
+             gr.Markdown("## 💬 Ask Questions")
+
+             chatbot = gr.Chatbot(
+                 value=[],
+                 height=400,
+                 elem_classes="chat-container"
+             )
+
+             with gr.Row():
+                 msg_input = gr.Textbox(
+                     placeholder="Ask a question about your PDF...",
+                     scale=4,
+                     container=False
+                 )
+                 send_btn = gr.Button("Send", variant="primary", scale=1)
+
+     # Event handlers
+     process_btn.click(
+         fn=process_pdf,
+         inputs=[pdf_input],
+         outputs=[status_output],
+         show_progress=True
+     )
+
+     send_btn.click(
+         fn=chat_with_pdf,
+         inputs=[msg_input, chatbot],
+         outputs=[chatbot]
+     ).then(
+         lambda: "",
+         outputs=[msg_input]
+     )
+
+     msg_input.submit(
+         fn=chat_with_pdf,
+         inputs=[msg_input, chatbot],
+         outputs=[chatbot]
+     ).then(
+         lambda: "",
+         outputs=[msg_input]
+     )
+
+     clear_btn.click(
+         fn=clear_chat,
+         outputs=[chatbot, status_output]
+     )
+
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True
+     )
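
For reference, a minimal sketch of driving the retrieval half of this pipeline from a Python shell, without the Gradio UI. It is an illustration, not part of the commit: it assumes gradio_gemma.py is on the import path, that HF_TOKEN is set (importing the module loads both models at import time), and "sample.pdf" is a placeholder for any local PDF.

    # Hypothetical smoke test for the RAG helpers in gradio_gemma.py
    from gradio_gemma import (
        extract_text_from_pdf,
        chunk_text,
        create_embeddings,
        retrieve_relevant_chunks,
    )

    text = extract_text_from_pdf("sample.pdf")   # accepts a path or a file object
    chunks = chunk_text(text, chunk_size=500, overlap=50)
    embeddings = create_embeddings(chunks)       # numpy array, one row per chunk
    top = retrieve_relevant_chunks("What is this document about?", chunks, embeddings, top_k=3)
    for chunk in top:
        print(chunk[:120], "...")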
requirements_gemma.txt ADDED
@@ -0,0 +1,17 @@
+ # Core dependencies
+ gradio>=4.0.0
+ torch>=2.0.0
+ transformers>=4.53.0
+ numpy>=1.24.0
+ Pillow>=9.0.0
+
+ # PDF processing
+ PyMuPDF>=1.23.0
+
+ # RAG dependencies
+ sentence-transformers>=2.2.0
+ scikit-learn>=1.3.0
+
+ # Additional utilities
+ accelerate>=0.20.0
+ safetensors>=0.3.0
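
To run the app locally: install the dependencies with pip install -r requirements_gemma.txt, export an HF_TOKEN with access to google/gemma-3n-e4b-it, and start python gradio_gemma.py; the interface is then served on port 7860 (server_name="0.0.0.0", share=False).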