File size: 16,608 Bytes
3a2241a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
import uuid
import os
import io
import time
from functools import lru_cache
from dotenv import load_dotenv
from database import SessionLocal, ChatMessage
from qdrant_client import QdrantClient
from qdrant_client.models import (
    PointStruct, Distance, VectorParams,
    Filter, FieldCondition, MatchValue, PointIdsList
)
from sentence_transformers import SentenceTransformer
from groq import Groq
import pdfplumber
from tabulate import tabulate
import pytesseract
from PIL import Image, ImageEnhance, ImageFilter
import fitz  # PyMuPDF
import torch
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
import warnings
warnings.filterwarnings("ignore", message="Could get FontBBox from font descriptor*")

# Configure Tesseract path (Windows specific)
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"

load_dotenv()

# Initialize clients
# Groq LLM client and Qdrant vector-store client; credentials are read
# from environment variables populated by load_dotenv() above.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
qdrant = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)

# Qdrant collection holding per-session chat-history vectors.
COLLECTION_NAME = "chatbot_sessions"
# Qdrant collection holding embedded PDF text/table/chart chunks.
PDF_COLLECTION_NAME = "pdf_documents"
# Maximum number of vector-store messages retained per session.
MAX_HISTORY_LENGTH = 5
# lru_cache size for get_conversation_summary.
SUMMARY_CACHE_SIZE = 100
# Sentence embedder; produces 384-dim vectors, which must match the
# VectorParams(size=384) used when creating the collections.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Initialize DePlot
# Chart-to-table model (Pix2Struct); runs on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
deplot_processor = Pix2StructProcessor.from_pretrained("google/deplot")
deplot_model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot").to(device)

def create_collections():
    """Create the chat-history and PDF Qdrant collections if missing.

    Each collection gets a 384-dim cosine vector config (matching the
    all-MiniLM-L6-v2 embedder) and a keyword payload index on the field
    used for filtered lookups. Idempotent: existing collections are
    left untouched.
    """
    existing = {c.name for c in qdrant.get_collections().collections}

    # (collection name, payload field used in scroll/search filters)
    required = [
        (COLLECTION_NAME, "session_id"),
        (PDF_COLLECTION_NAME, "document_id"),
    ]
    for name, index_field in required:
        if name in existing:
            continue
        # create_collection instead of the deprecated recreate_collection:
        # recreate would drop an existing collection if the existence check
        # raced with another process, silently destroying data.
        qdrant.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=384, distance=Distance.COSINE),
            timeout=1200,
        )
        # Keyword index so filtered scrolls/searches on this field are fast.
        qdrant.create_payload_index(
            collection_name=name,
            field_name=index_field,
            field_schema="keyword",
        )

def generate_session_id():
    """Return a fresh random session identifier (UUID4 as a string)."""
    return f"{uuid.uuid4()}"

def store_message(session_id, role, message):
    """Persist a chat message to the relational DB and the vector store.

    After inserting, the session's vector history is trimmed down to
    MAX_HISTORY_LENGTH points by deleting the oldest (by timestamp).

    Args:
        session_id: Chat session identifier (UUID string).
        role: Message author role ("user" or "assistant").
        message: Message text.
    """
    # try/finally guarantees the DB session is released even if the
    # commit raises (the original leaked it on error).
    db = SessionLocal()
    try:
        chat_record = ChatMessage(session_id=session_id, role=role, message=message)
        db.add(chat_record)
        db.commit()
        db.refresh(chat_record)
    finally:
        db.close()

    # Vector-store persistence. Point IDs are random 12-digit ints;
    # 10**12 keeps the modulo in exact integer arithmetic (1e12 is a
    # float and loses precision on 128-bit UUID ints).
    embedding = embedder.encode(message).tolist()
    point = PointStruct(
        id=uuid.uuid4().int % 10**12,
        vector=embedding,
        payload={
            "session_id": session_id,
            "role": role,
            "message": message,
            "timestamp": int(time.time()),
        },
    )
    qdrant.upsert(collection_name=COLLECTION_NAME, points=[point])

    # Trim history: keep only the newest MAX_HISTORY_LENGTH points.
    points, _ = qdrant.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=Filter(must=[
            FieldCondition(key="session_id", match=MatchValue(value=session_id))
        ]),
        limit=100,
        with_payload=True,
    )
    if len(points) > MAX_HISTORY_LENGTH:
        by_age = sorted(points, key=lambda p: p.payload.get("timestamp", 0))
        stale_ids = [p.id for p in by_age[:-MAX_HISTORY_LENGTH]]
        qdrant.delete(
            collection_name=COLLECTION_NAME,
            points_selector=PointIdsList(points=stale_ids),
        )

@lru_cache(maxsize=SUMMARY_CACHE_SIZE)
def get_conversation_summary(session_id):
    """Summarize a session's conversation in 1-2 sentences via the LLM.

    Results are memoized per session_id; callers invalidate with
    get_conversation_summary.cache_clear() after storing new messages
    (see chat_with_session).

    Returns:
        The summary text, or a fallback string when there is no history
        or the LLM call fails.
    """
    # try/finally ensures the DB session is released even if the query
    # raises (the original leaked it on error).
    db = SessionLocal()
    try:
        messages = db.query(ChatMessage).filter(
            ChatMessage.session_id == session_id
        ).order_by(ChatMessage.id).all()
    finally:
        db.close()

    if not messages:
        return "No previous conversation history"

    # Only the last 10 messages are summarized to bound prompt size.
    conversation = "\n".join(
        f"{msg.role}: {msg.message}" for msg in messages[-10:]
    )

    summary_prompt = (
        "Create a very concise summary (1-2 sentences max) focusing on:\n"
        "1. Main topic being discussed\n"
        "2. Any specific numbers/dates mentioned\n"
        "3. The most recent question\n\n"
        "Conversation:\n" + conversation
    )

    try:
        response = client.chat.completions.create(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            messages=[{"role": "user", "content": summary_prompt}],
            temperature=0.3,
            max_tokens=100
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort: a missing summary must not break the chat flow.
        print(f"Summary generation failed: {e}")
        return "Current conversation context unavailable"

def get_session_history(session_id):
    """Return the session's stored messages as chat-format dicts,
    ordered oldest-to-newest by stored timestamp."""
    points, _ = qdrant.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=Filter(must=[
            FieldCondition(key="session_id", match=MatchValue(value=session_id))
        ]),
        limit=MAX_HISTORY_LENGTH,
        with_payload=True,
    )
    ordered = sorted(points, key=lambda p: p.payload.get("timestamp", 0))
    return [
        {"role": p.payload["role"], "content": p.payload["message"]}
        for p in ordered
    ]

def extract_pdf_content(pdf_path):
    """Pull text, tables, and embedded images out of a PDF.

    Tables are rendered as grid-formatted text and appended to the text
    stream; embedded images are cropped from a 300-dpi render of each
    page.

    Returns:
        (full_text, images): the concatenated text and a list of PIL
        image crops.
    """
    text_parts = []
    images = []

    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            extracted = page.extract_text()
            if extracted:
                text_parts.append(extracted + "\n\n")

            for table in page.extract_tables():
                rendered = tabulate(table, headers="firstrow", tablefmt="grid")
                text_parts.append(f"\n\nTABLE:\n{rendered}\n\n")

            if page.images:
                # Render the page once, then crop each embedded image out.
                rendered_page = page.to_image(resolution=300)
                for img in page.images:
                    try:
                        box = (img["x0"], img["top"], img["x1"], img["bottom"])
                        images.append(rendered_page.original.crop(box))
                    except Exception as e:
                        print(f"Image extraction failed: {e}")

    return "".join(text_parts), images

def extract_chart_data(image: Image.Image) -> str:
    """OCR a chart image and return a short textual description.

    The image is grayscaled, sharpened, and contrast-boosted before OCR;
    on failure a bracketed error string is returned instead of raising.
    """
    try:
        # Preprocess: grayscale -> sharpen -> 2x contrast boost.
        prepared = ImageEnhance.Contrast(
            image.convert("L").filter(ImageFilter.SHARPEN)
        ).enhance(2.0)

        # --psm 6: treat the image as a single uniform block of text.
        ocr_text = pytesseract.image_to_string(prepared, config="--psm 6").strip()
        if ocr_text:
            return f"Chart contains: {ocr_text}"

        # No text found: fall back to a size-only description.
        width, height = prepared.size
        return f"Visual chart approximately {width}x{height} pixels with data points"
    except Exception as e:
        return f"[Chart content could not be extracted: {str(e)}]"

def _deplot_query(image: Image.Image, prompt: str) -> str:
    """Run a single DePlot prompt against *image* and return decoded text."""
    inputs = deplot_processor(images=image, text=prompt, return_tensors="pt").to(device)
    output_ids = deplot_model.generate(**inputs, max_new_tokens=512)
    return deplot_processor.decode(output_ids[0], skip_special_tokens=True)


def _chunk_lines(text: str, chunk_size: int) -> list:
    """Greedily pack the non-empty lines of *text* into chunks of at most
    roughly *chunk_size* characters each, preserving line order."""
    chunks = []
    current = ""
    for line in (ln for ln in text.split('\n') if ln.strip()):
        if len(current) + len(line) + 1 <= chunk_size:
            current += line + "\n"
        else:
            if current:
                chunks.append(current.strip())
            current = line + "\n"
    if current:
        chunks.append(current.strip())
    return chunks


def extract_charts_with_deplot(pdf_path: str, document_id: str, chunk_size: int = 500):
    """
    Extract charts from a PDF using DePlot and store them in the vector DB.

    For every embedded image, DePlot produces a data table and a textual
    summary; the combined text is chunked and upserted into the PDF
    collection tagged with page/image provenance.

    Args:
        pdf_path: Path to PDF file.
        document_id: Unique document identifier.
        chunk_size: Maximum size of each stored text chunk.

    Returns:
        List of per-image processing results.
    """
    doc = fitz.open(pdf_path)
    results = []
    try:
        for page_num in range(len(doc)):
            page = doc[page_num]
            for img_index, img in enumerate(page.get_images(full=True)):
                try:
                    # Decode the embedded image bytes into a PIL RGB image.
                    xref = img[0]
                    base_image = doc.extract_image(xref)
                    image = Image.open(io.BytesIO(base_image["image"])).convert("RGB")

                    # Two DePlot passes: structured table, then prose summary.
                    table_data = _deplot_query(
                        image,
                        "Extract all data from this chart in table format with clear headers."
                    )
                    chart_summary = _deplot_query(
                        image,
                        "Provide a comprehensive summary of this chart including: "
                        "1. Chart title and type, 2. Key trends and patterns, "
                        "3. Notable data points, 4. Overall conclusion."
                    )

                    combined_content = f"CHART SUMMARY:\n{chart_summary}\n\nEXTRACTED DATA:\n{table_data}"
                    chunks = _chunk_lines(combined_content, chunk_size)

                    # One point per chunk; payload keeps the full summary and
                    # table so any single hit can reconstruct the whole chart.
                    points = [
                        PointStruct(
                            id=uuid.uuid4().int % 10**12,
                            vector=embedder.encode(chunk).tolist(),
                            payload={
                                "document_id": document_id,
                                "page": page_num + 1,
                                "image_index": img_index + 1,
                                "type": "chart_chunk",
                                "chunk_index": i,
                                "total_chunks": len(chunks),
                                "content": chunk,
                                "full_summary": chart_summary,
                                "full_table": table_data,
                            },
                        )
                        for i, chunk in enumerate(chunks)
                    ]
                    if points:
                        qdrant.upsert(collection_name=PDF_COLLECTION_NAME, points=points)

                    results.append({
                        "page": page_num + 1,
                        "image_index": img_index + 1,
                        "summary": chart_summary,
                        "table_data": table_data,
                        "num_chunks": len(chunks)
                    })
                except Exception as e:
                    # Best-effort: log and continue with the remaining images.
                    print(f"❌ Error processing page {page_num+1} image {img_index+1}: {str(e)}")
                    continue
    finally:
        # Always release the PyMuPDF handle (the original leaked it).
        doc.close()

    return results

def store_pdf_chunks(text: str, document_id: str):
    """Split *text* into ~1000-character paragraph chunks and upsert
    them into the PDF collection in one batch.

    Args:
        text: Full document text (paragraphs separated by blank lines).
        document_id: Identifier stored in each chunk's payload.
    """
    chunks = []
    current = ""
    for para in text.split('\n\n'):
        if len(current) + len(para) < 1000:
            current += para + "\n\n"
        else:
            # Guard against appending an empty chunk when the very first
            # paragraph alone exceeds the limit (bug in the original).
            if current.strip():
                chunks.append(current.strip())
            current = para + "\n\n"
    if current.strip():
        chunks.append(current.strip())

    points = [
        PointStruct(
            id=uuid.uuid4().int % 10**12,
            vector=embedder.encode(chunk).tolist(),
            payload={
                "document_id": document_id,
                "content": chunk,
                "source": "pdf",
            },
        )
        for chunk in chunks
    ]
    if points:
        # Single batched upsert instead of one network round-trip per chunk.
        qdrant.upsert(collection_name=PDF_COLLECTION_NAME, points=points)

def process_pdf(pdf_path: str):
    """End-to-end ingestion of one PDF: extract text/tables/images, OCR
    the images, describe the charts, and store everything in Qdrant.

    The document is keyed by its file name (basename of *pdf_path*)."""
    text, images = extract_pdf_content(pdf_path)
    ocr_fragments = []
    chart_summaries = []

    for idx, image in enumerate(images, start=1):
        try:
            ocr_fragments.append(pytesseract.image_to_string(image))
            chart_summaries.append(f"Chart {idx}: {extract_chart_data(image)}")
        except Exception as e:
            print(f"Image processing failed: {e}")
            chart_summaries.append(f"Chart {idx}: [Content not extracted]")

    full_text = (
        "PDF TEXT CONTENT:\n" + text +
        "\n\nIMAGE TEXT CONTENT:\n" + "".join(ocr_fragments) +
        "\n\nCHART SUMMARIES:\n" + "\n".join(chart_summaries)
    )

    document_id = os.path.basename(pdf_path)
    store_pdf_chunks(full_text, document_id)

    # Second pass: structured chart extraction via DePlot.
    deplot_results = extract_charts_with_deplot(pdf_path, document_id)
    print(f"βœ… DePlot processed {len(deplot_results)} charts")

def get_relevant_context(user_message: str, session_id: str):
    """Fetch context for answering *user_message*.

    Returns:
        (pdf_context, history_context): newline-joined PDF chunk texts
        scoring >= 0.4 against the question, and the last three stored
        messages of the session.
    """
    query_vector = embedder.encode(user_message).tolist()

    # Semantic search over ingested PDF chunks.
    pdf_hits = qdrant.search(
        collection_name=PDF_COLLECTION_NAME,
        query_vector=query_vector,
        limit=10,
        score_threshold=0.4,
    )

    recent = get_session_history(session_id)[-3:]

    pdf_context = "\n".join(hit.payload.get("content", "") for hit in pdf_hits)
    history_context = "\n".join(m["content"] for m in recent)

    return pdf_context, history_context

def get_verified_context(session_id):
    """Return the session's 10 most recent DB messages that contain at
    least one digit (ordered newest first).

    Used to pin numerical facts into the system prompt.
    """
    # try/finally ensures the DB session is released even if the query
    # raises (the original leaked it on error).
    db = SessionLocal()
    try:
        messages = db.query(ChatMessage).filter(
            ChatMessage.session_id == session_id
        ).order_by(ChatMessage.id.desc()).limit(10).all()
    finally:
        db.close()

    return [msg for msg in messages if any(ch.isdigit() for ch in msg.message)]

def chat_with_session(session_id, user_message):
    """Generate an LLM reply for *user_message* within a session.

    Builds a system prompt from the conversation summary, relevant PDF
    chunks, and previously stored numerical messages, calls the Groq
    LLM, then persists both sides of the exchange.

    Args:
        session_id: UUID string identifying the conversation.
        user_message: The user's latest message.

    Returns:
        The assistant's reply, or a user-facing error string on failure.
    """
    # Validate the session id up-front; uuid.UUID raises ValueError on
    # malformed input (the parsed value itself is not needed).
    try:
        uuid.UUID(session_id)
    except ValueError:
        return "❌ Invalid session ID format. Please generate a valid session."

    # Assemble every context source.
    conversation_summary = get_conversation_summary(session_id)
    pdf_context, history_context = get_relevant_context(user_message, session_id)
    verified_contexts = get_verified_context(session_id)
    verified_text = "\n".join(msg.message for msg in verified_contexts)

    # Construct system prompt
    system_prompt = (
        "You are a context-aware assistant. Follow these rules strictly:\n"
        "1. CONVERSATION SUMMARY:\n" + conversation_summary + "\n\n"
        "2. Maintain context for follow-up questions\n"
        "3. DOCUMENT CONTEXT:\n" + (pdf_context if pdf_context else "None") + "\n\n"
        "4. VERIFIED NUMERICAL CONTEXT:\n" + (verified_text if verified_text else "None") + "\n\n"
        "5. Respond clearly and concisely to the latest user query while maintaining continuity.\n"
    )

    # System prompt + the last three stored turns + the new user message.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(get_session_history(session_id)[-3:])
    messages.append({"role": "user", "content": user_message})

    try:
        completion = client.chat.completions.create(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            messages=messages,
            temperature=0.7,
            max_tokens=1024,
            top_p=0.9
        )
        reply = completion.choices[0].message.content
    except Exception as e:
        print(f"❌ LLM generation failed: {e}")
        return "Sorry, I couldn't generate a response at this time."

    # Persist the exchange, then invalidate the cached summary so the
    # next call regenerates it with the new messages included.
    store_message(session_id, "user", user_message)
    store_message(session_id, "assistant", reply)
    get_conversation_summary.cache_clear()

    return reply

# Initialize collections on startup
# NOTE(review): this runs at import time; consider moving it behind an
# explicit init function or a __main__ guard to avoid network calls on import.
create_collections()