import uuid
import os
import io
import time
from functools import lru_cache

from dotenv import load_dotenv
from database import SessionLocal, ChatMessage
from qdrant_client import QdrantClient
from qdrant_client.models import (
    PointStruct, Distance, VectorParams, Filter,
    FieldCondition, MatchValue, PointIdsList
)
from sentence_transformers import SentenceTransformer
from groq import Groq
import pdfplumber
from tabulate import tabulate
import pytesseract
from PIL import Image, ImageEnhance, ImageFilter
import fitz  # PyMuPDF
import torch
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
import warnings

# pdfminer emits this (mis-spelled) warning for PDFs with broken font metadata;
# the trailing ".*" makes the filter a valid regex prefix match.
warnings.filterwarnings("ignore", message="Could get FontBBox from font descriptor.*")

# Configure Tesseract path (Windows-specific)
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"

load_dotenv()

# Initialize clients
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
qdrant = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)

COLLECTION_NAME = "chatbot_sessions"
PDF_COLLECTION_NAME = "pdf_documents"
MAX_HISTORY_LENGTH = 5
SUMMARY_CACHE_SIZE = 100

embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Initialize DePlot (loaded once at import time; slow on first run while weights download)
device = "cuda" if torch.cuda.is_available() else "cpu"
deplot_processor = Pix2StructProcessor.from_pretrained("google/deplot")
deplot_model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot").to(device)


def create_collections():
    """Initialize Qdrant collections if they don't exist."""
    existing_collections = [c.name for c in qdrant.get_collections().collections]

    if COLLECTION_NAME not in existing_collections:
        # create_collection (not the deprecated, destructive recreate_collection)
        # is correct here: the existence check above already guards duplicates.
        qdrant.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(size=384, distance=Distance.COSINE),
            timeout=1200
        )
        qdrant.create_payload_index(
            collection_name=COLLECTION_NAME,
            field_name="session_id",
            field_schema="keyword"
        )

    if PDF_COLLECTION_NAME not in existing_collections:
        qdrant.create_collection(
            collection_name=PDF_COLLECTION_NAME,
            vectors_config=VectorParams(size=384, distance=Distance.COSINE),
            timeout=1200
        )
        qdrant.create_payload_index(
            collection_name=PDF_COLLECTION_NAME,
            field_name="document_id",
            field_schema="keyword"
        )


def generate_session_id():
    """Generate a unique session ID."""
    return str(uuid.uuid4())


def store_message(session_id, role, message):
    """Store a message in both the relational database and the vector store."""
    db = SessionLocal()
    chat_record = ChatMessage(session_id=session_id, role=role, message=message)
    db.add(chat_record)
    db.commit()
    db.refresh(chat_record)
    db.close()

    # Store in the vector database. Integer modulo (10**12, not the float 1e12)
    # avoids precision loss when reducing a 128-bit UUID integer to a point ID.
    embedding = embedder.encode(message).tolist()
    point = PointStruct(
        id=uuid.uuid4().int % 10**12,
        vector=embedding,
        payload={
            "session_id": session_id,
            "role": role,
            "message": message,
            "timestamp": int(time.time())
        }
    )
    qdrant.upsert(collection_name=COLLECTION_NAME, points=[point])

    # Trim the session to the MAX_HISTORY_LENGTH most recent messages
    existing = qdrant.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=Filter(must=[
            FieldCondition(key="session_id", match=MatchValue(value=session_id))
        ]),
        limit=100,
        with_payload=True
    )
    if len(existing[0]) > MAX_HISTORY_LENGTH:
        old_points = sorted(existing[0], key=lambda x: x.payload.get("timestamp", 0))
        old_ids = [p.id for p in old_points[:-MAX_HISTORY_LENGTH]]
        qdrant.delete(
            collection_name=COLLECTION_NAME,
            points_selector=PointIdsList(points=old_ids)
        )
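
# Illustrative sketch (not executed at import): a minimal round-trip through the
# session store. The question text is hypothetical; a reachable Qdrant instance
# and a populated .env are assumed.
#
#   session_id = generate_session_id()
#   store_message(session_id, "user", "What was revenue in Q3?")
#   store_message(session_id, "assistant", "Q3 revenue was ...")
#   # After MAX_HISTORY_LENGTH + 1 messages the oldest Qdrant point is deleted,
#   # while the full transcript is still kept in the ChatMessage table.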

@lru_cache(maxsize=SUMMARY_CACHE_SIZE)
def get_conversation_summary(session_id):
    """Generate a concise summary of the conversation."""
    db = SessionLocal()
    messages = db.query(ChatMessage).filter(
        ChatMessage.session_id == session_id
    ).order_by(ChatMessage.id).all()
    db.close()

    if not messages:
        return "No previous conversation history"

    conversation = "\n".join(
        f"{msg.role}: {msg.message}" for msg in messages[-10:]
    )
    summary_prompt = (
        "Create a very concise summary (1-2 sentences max) focusing on:\n"
        "1. Main topic being discussed\n"
        "2. Any specific numbers/dates mentioned\n"
        "3. The most recent question\n\n"
        "Conversation:\n" + conversation
    )
    try:
        response = client.chat.completions.create(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            messages=[{"role": "user", "content": summary_prompt}],
            temperature=0.3,
            max_tokens=100
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Summary generation failed: {e}")
        return "Current conversation context unavailable"


def get_session_history(session_id):
    """Retrieve conversation history from the vector store."""
    # store_message() trims each session to MAX_HISTORY_LENGTH points, so a
    # scroll with that limit returns the whole recent history; sorting by
    # timestamp restores chronological order (scroll order is by point ID).
    result = qdrant.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=Filter(must=[
            FieldCondition(key="session_id", match=MatchValue(value=session_id))
        ]),
        limit=MAX_HISTORY_LENGTH,
        with_payload=True
    )
    messages = sorted(result[0], key=lambda x: x.payload.get("timestamp", 0))
    return [{"role": p.payload["role"], "content": p.payload["message"]}
            for p in messages]


def extract_pdf_content(pdf_path):
    """Extract text, tables, and embedded images from a PDF."""
    full_text = ""
    images = []
    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            page_text = page.extract_text()
            if page_text:
                full_text += page_text + "\n\n"

            tables = page.extract_tables()
            for table in tables:
                formatted_table = tabulate(table, headers="firstrow", tablefmt="grid")
                full_text += f"\n\nTABLE:\n{formatted_table}\n\n"

            if page.images:
                resolution = 300
                # pdfplumber bounding boxes are in PDF points (72 per inch);
                # scale them to pixel coordinates before cropping the image
                # rendered at the higher resolution.
                scale = resolution / 72
                page_image = page.to_image(resolution=resolution)
                for img in page.images:
                    try:
                        bbox = (img["x0"] * scale, img["top"] * scale,
                                img["x1"] * scale, img["bottom"] * scale)
                        cropped = page_image.original.crop(bbox)
                        images.append(cropped)
                    except Exception as e:
                        print(f"Image extraction failed: {e}")
    return full_text, images


def extract_chart_data(image: Image.Image) -> str:
    """Extract text from chart images using OCR."""
    try:
        # Preprocess for OCR: grayscale, sharpen, then boost contrast
        image = image.convert("L")
        image = image.filter(ImageFilter.SHARPEN)
        enhancer = ImageEnhance.Contrast(image)
        image = enhancer.enhance(2.0)

        # --psm 6 tells Tesseract to treat the image as one uniform text block
        chart_text = pytesseract.image_to_string(image, config="--psm 6")
        if chart_text.strip():
            return f"Chart contains: {chart_text.strip()}"
        else:
            width, height = image.size
            return f"Visual chart approximately {width}x{height} pixels with data points"
    except Exception as e:
        return f"[Chart content could not be extracted: {str(e)}]"
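
# For reference: DePlot linearizes a chart into a textual table, typically with
# "|"-separated cells and "<0x0A>" tokens as row breaks. An illustrative (not
# actual) decoded output might look like:
#
#   TITLE | Quarterly revenue <0x0A> Quarter | Revenue <0x0A> Q1 | 1.2 <0x0A> Q2 | 1.5
#
# Because the decoded result is plain text, the chunking inside
# extract_charts_with_deplot() below needs no chart-specific parsing.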

def extract_charts_with_deplot(pdf_path: str, document_id: str, chunk_size: int = 500):
    """
    Extract charts from a PDF using DePlot and store them in the vector database.

    Args:
        pdf_path: Path to the PDF file
        document_id: Unique document identifier
        chunk_size: Maximum character length of each text chunk

    Returns:
        List of per-chart processing results
    """
    doc = fitz.open(pdf_path)
    results = []

    for page_num in range(len(doc)):
        page = doc[page_num]
        image_list = page.get_images(full=True)

        for img_index, img in enumerate(image_list):
            try:
                # Extract the embedded image via its xref
                xref = img[0]
                base_image = doc.extract_image(xref)
                image_bytes = base_image["image"]
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

                # Extract table data
                text_table = "Extract all data from this chart in table format with clear headers."
                inputs_table = deplot_processor(images=image, text=text_table,
                                                return_tensors="pt").to(device)
                table_ids = deplot_model.generate(**inputs_table, max_new_tokens=512)
                table_data = deplot_processor.decode(table_ids[0], skip_special_tokens=True)

                # Generate summary
                text_summary = ("Provide a comprehensive summary of this chart including: "
                                "1. Chart title and type, 2. Key trends and patterns, "
                                "3. Notable data points, 4. Overall conclusion.")
                inputs_summary = deplot_processor(images=image, text=text_summary,
                                                  return_tensors="pt").to(device)
                summary_ids = deplot_model.generate(**inputs_summary, max_new_tokens=512)
                chart_summary = deplot_processor.decode(summary_ids[0], skip_special_tokens=True)

                # Split the combined output into chunks no longer than chunk_size
                combined_content = f"CHART SUMMARY:\n{chart_summary}\n\nEXTRACTED DATA:\n{table_data}"
                chunks = []
                current_chunk = ""
                for para in [p for p in combined_content.split('\n') if p.strip()]:
                    if len(current_chunk) + len(para) + 1 <= chunk_size:
                        current_chunk += para + "\n"
                    else:
                        if current_chunk:
                            chunks.append(current_chunk.strip())
                        current_chunk = para + "\n"
                if current_chunk:
                    chunks.append(current_chunk.strip())

                # Store chunks in the vector database
                points = []
                for i, chunk in enumerate(chunks):
                    embedding = embedder.encode(chunk).tolist()
                    point = PointStruct(
                        id=uuid.uuid4().int % 10**12,
                        vector=embedding,
                        payload={
                            "document_id": document_id,
                            "page": page_num + 1,
                            "image_index": img_index + 1,
                            "type": "chart_chunk",
                            "chunk_index": i,
                            "total_chunks": len(chunks),
                            "content": chunk,
                            "full_summary": chart_summary,
                            "full_table": table_data
                        }
                    )
                    points.append(point)

                if points:
                    qdrant.upsert(collection_name=PDF_COLLECTION_NAME, points=points)

                results.append({
                    "page": page_num + 1,
                    "image_index": img_index + 1,
                    "summary": chart_summary,
                    "table_data": table_data,
                    "num_chunks": len(chunks)
                })
            except Exception as e:
                print(f"❌ Error processing page {page_num+1} image {img_index+1}: {str(e)}")
                continue

    return results


def store_pdf_chunks(text: str, document_id: str):
    """Chunk PDF text by paragraph and store each chunk in the vector database."""
    paragraphs = text.split('\n\n')
    chunks = []
    current_chunk = ""
    for para in paragraphs:
        if len(current_chunk) + len(para) < 1000:
            current_chunk += para + "\n\n"
        else:
            # Guard against appending an empty chunk when the first paragraph
            # alone exceeds the size limit
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = para + "\n\n"
    if current_chunk:
        chunks.append(current_chunk.strip())

    for chunk in chunks:
        embedding = embedder.encode(chunk).tolist()
        point = PointStruct(
            id=uuid.uuid4().int % 10**12,
            vector=embedding,
            payload={
                "document_id": document_id,
                "content": chunk,
                "source": "pdf"
            }
        )
        qdrant.upsert(collection_name=PDF_COLLECTION_NAME, points=[point])


def process_pdf(pdf_path: str):
    """Process a PDF file and store its text, OCR, and chart contents."""
    text, images = extract_pdf_content(pdf_path)

    ocr_text = ""
    chart_summaries = []
    for i, image in enumerate(images):
        try:
            ocr_text += pytesseract.image_to_string(image)
            chart_summary = extract_chart_data(image)
            chart_summaries.append(f"Chart {i+1}: {chart_summary}")
        except Exception as e:
            print(f"Image processing failed: {e}")
            chart_summaries.append(f"Chart {i+1}: [Content not extracted]")

    full_text = (
        "PDF TEXT CONTENT:\n" + text +
        "\n\nIMAGE TEXT CONTENT:\n" + ocr_text +
        "\n\nCHART SUMMARIES:\n" + "\n".join(chart_summaries)
    )
    document_id = os.path.basename(pdf_path)
    store_pdf_chunks(full_text, document_id)

    # Process charts with DePlot
    deplot_results = extract_charts_with_deplot(pdf_path, document_id)
    print(f"✅ DePlot processed {len(deplot_results)} charts")
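
# Illustrative sketch: ingesting a document end to end. The file path is
# hypothetical; its basename becomes the document_id stored in each payload.
#
#   process_pdf("reports/q3_financials.pdf")
#   # -> text chunks land in "pdf_documents"; chart chunks are added by DePlot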

def get_relevant_context(user_message: str, session_id: str):
    """Retrieve relevant context from the vector stores."""
    question_embedding = embedder.encode(user_message).tolist()

    # Search PDF content
    pdf_results = qdrant.search(
        collection_name=PDF_COLLECTION_NAME,
        query_vector=question_embedding,
        limit=10,
        score_threshold=0.4
    )

    # Get recent conversation history
    history = get_session_history(session_id)
    recent_history = history[-3:]

    pdf_context = "\n".join([hit.payload.get("content", "") for hit in pdf_results])
    history_context = "\n".join([msg["content"] for msg in recent_history])
    return pdf_context, history_context


def get_verified_context(session_id):
    """Retrieve recent messages containing numerical data."""
    db = SessionLocal()
    messages = db.query(ChatMessage).filter(
        ChatMessage.session_id == session_id
    ).order_by(ChatMessage.id.desc()).limit(10).all()
    db.close()
    return [msg for msg in messages if any(char.isdigit() for char in msg.message)]


def chat_with_session(session_id, user_message):
    """Main chat function with context-aware responses."""
    try:
        uuid.UUID(session_id)  # validate the format only; the parsed value is unused
    except ValueError:
        return "❌ Invalid session ID format. Please generate a valid session."

    # Gather all context sources
    conversation_summary = get_conversation_summary(session_id)
    pdf_context, history_context = get_relevant_context(user_message, session_id)
    verified_contexts = get_verified_context(session_id)
    verified_text = "\n".join([msg.message for msg in verified_contexts])

    # Construct system prompt
    system_prompt = (
        "You are a context-aware assistant. Follow these rules strictly:\n"
        "1. CONVERSATION SUMMARY:\n" + conversation_summary + "\n\n"
        "2. Maintain context for follow-up questions\n"
        "3. DOCUMENT CONTEXT:\n" + (pdf_context if pdf_context else "None") + "\n\n"
        "4. VERIFIED NUMERICAL CONTEXT:\n" + (verified_text if verified_text else "None") + "\n\n"
        "5. Respond clearly and concisely to the latest user query while maintaining continuity.\n"
    )

    # Prepare messages for the LLM
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(get_session_history(session_id)[-3:])
    messages.append({"role": "user", "content": user_message})

    try:
        completion = client.chat.completions.create(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            messages=messages,
            temperature=0.7,
            max_tokens=1024,
            top_p=0.9
        )
        reply = completion.choices[0].message.content
    except Exception as e:
        print(f"❌ LLM generation failed: {e}")
        return "Sorry, I couldn't generate a response at this time."

    # Store the exchange and invalidate the cached summary for this session
    store_message(session_id, "user", user_message)
    store_message(session_id, "assistant", reply)
    get_conversation_summary.cache_clear()

    return reply


# Initialize collections on startup
create_collections()
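
# Minimal interactive sketch, assuming Qdrant, Groq, and Tesseract are all
# configured; the sample PDF path above is hypothetical and ingestion is left
# commented out.
if __name__ == "__main__":
    # process_pdf("reports/q3_financials.pdf")
    sid = generate_session_id()
    print(f"Session: {sid} (type 'quit' to exit)")
    while True:
        question = input("> ")
        if question.strip().lower() == "quit":
            break
        print(chat_with_session(sid, question))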