import uuid
import os
import io
import time
from functools import lru_cache

from dotenv import load_dotenv
from database import SessionLocal, ChatMessage
from qdrant_client import QdrantClient
from qdrant_client.models import (
    PointStruct, Distance, VectorParams,
    Filter, FieldCondition, MatchValue, PointIdsList
)
from sentence_transformers import SentenceTransformer
from groq import Groq
import pdfplumber
from tabulate import tabulate
import pytesseract
from PIL import Image, ImageEnhance, ImageFilter
import fitz  # PyMuPDF
import torch
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

import warnings
warnings.filterwarnings("ignore", message="Could get FontBBox from font descriptor")

# Configure Tesseract path (Windows-specific)
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
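# On Linux/macOS the tesseract binary is usually on PATH, so the override
# above can be dropped (or guarded with a platform check).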

load_dotenv()

# Initialize clients
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
qdrant = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)

COLLECTION_NAME = "chatbot_sessions"
PDF_COLLECTION_NAME = "pdf_documents"
MAX_HISTORY_LENGTH = 5
SUMMARY_CACHE_SIZE = 100

embedder = SentenceTransformer("all-MiniLM-L6-v2")
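# NOTE: all-MiniLM-L6-v2 produces 384-dimensional embeddings, which must match
# VectorParams(size=384) used when creating both collections below.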

# Initialize DePlot
device = "cuda" if torch.cuda.is_available() else "cpu"
deplot_processor = Pix2StructProcessor.from_pretrained("google/deplot")
deplot_model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot").to(device)
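# Inference only: eval() disables dropout; generate() already runs without
# gradients. The first from_pretrained() call downloads the model weights.
deplot_model.eval()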

def create_collections():
    """Initialize Qdrant collections if they don't exist"""
    existing_collections = [c.name for c in qdrant.get_collections().collections]
    if COLLECTION_NAME not in existing_collections:
        qdrant.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(size=384, distance=Distance.COSINE),
            timeout=1200
        )
        qdrant.create_payload_index(
            collection_name=COLLECTION_NAME,
            field_name="session_id",
            field_schema="keyword"
        )
    if PDF_COLLECTION_NAME not in existing_collections:
        qdrant.create_collection(
            collection_name=PDF_COLLECTION_NAME,
            vectors_config=VectorParams(size=384, distance=Distance.COSINE),
            timeout=1200
        )
        qdrant.create_payload_index(
            collection_name=PDF_COLLECTION_NAME,
            field_name="document_id",
            field_schema="keyword"
        )

def generate_session_id():
    """Generate a unique session ID"""
    return str(uuid.uuid4())

def store_message(session_id, role, message):
    """Store a message in both the relational database and the vector store"""
    db = SessionLocal()
    try:
        chat_record = ChatMessage(session_id=session_id, role=role, message=message)
        db.add(chat_record)
        db.commit()
        db.refresh(chat_record)
    finally:
        db.close()

    # Store in vector database (Qdrant accepts UUID strings as point IDs,
    # avoiding the precision loss of a float modulo on uuid4().int)
    embedding = embedder.encode(message).tolist()
    point = PointStruct(
        id=str(uuid.uuid4()),
        vector=embedding,
        payload={
            "session_id": session_id,
            "role": role,
            "message": message,
            "timestamp": int(time.time())
        }
    )
    qdrant.upsert(collection_name=COLLECTION_NAME, points=[point])

    # Clean up old messages, keeping only the most recent MAX_HISTORY_LENGTH
    existing = qdrant.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=Filter(must=[
            FieldCondition(key="session_id", match=MatchValue(value=session_id))
        ]),
        limit=100,
        with_payload=True
    )
    if len(existing[0]) > MAX_HISTORY_LENGTH:
        old_points = sorted(existing[0], key=lambda x: x.payload.get("timestamp", 0))
        old_ids = [p.id for p in old_points[:-MAX_HISTORY_LENGTH]]
        qdrant.delete(
            collection_name=COLLECTION_NAME,
            points_selector=PointIdsList(points=old_ids)
        )

@lru_cache(maxsize=SUMMARY_CACHE_SIZE)
def get_conversation_summary(session_id):
    """Generate a concise summary of the conversation (cached per session;
    the cache is cleared whenever new messages are stored)"""
    db = SessionLocal()
    messages = db.query(ChatMessage).filter(
        ChatMessage.session_id == session_id
    ).order_by(ChatMessage.id).all()
    db.close()
    if not messages:
        return "No previous conversation history"
    conversation = "\n".join(
        f"{msg.role}: {msg.message}" for msg in messages[-10:]
    )
    summary_prompt = (
        "Create a very concise summary (1-2 sentences max) focusing on:\n"
        "1. Main topic being discussed\n"
        "2. Any specific numbers/dates mentioned\n"
        "3. The most recent question\n\n"
        "Conversation:\n" + conversation
    )
    try:
        response = client.chat.completions.create(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            messages=[{"role": "user", "content": summary_prompt}],
            temperature=0.3,
            max_tokens=100
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Summary generation failed: {e}")
        return "Current conversation context unavailable"

def get_session_history(session_id):
    """Retrieve conversation history from the vector store"""
    result = qdrant.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=Filter(must=[
            FieldCondition(key="session_id", match=MatchValue(value=session_id))
        ]),
        limit=MAX_HISTORY_LENGTH,
        with_payload=True
    )
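    # scroll() does not guarantee any particular order, so sort by the
    # timestamp stored in each payload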
    messages = sorted(result[0], key=lambda x: x.payload.get("timestamp", 0))
    return [{"role": p.payload["role"], "content": p.payload["message"]} for p in messages]

def extract_pdf_content(pdf_path):
    """Extract text, tables, and embedded images from a PDF"""
    full_text = ""
    images = []
    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            page_text = page.extract_text()
            if page_text:
                full_text += page_text + "\n\n"
            tables = page.extract_tables()
            for table in tables:
                formatted_table = tabulate(table, headers="firstrow", tablefmt="grid")
                full_text += f"\n\nTABLE:\n{formatted_table}\n\n"
            if page.images:
                page_image = page.to_image(resolution=300)
                # Image coordinates are in PDF points (72 dpi); scale them up
                # to the 300 dpi raster before cropping
                scale = 300 / 72
                for img in page.images:
                    try:
                        bbox = (int(img["x0"] * scale), int(img["top"] * scale),
                                int(img["x1"] * scale), int(img["bottom"] * scale))
                        cropped = page_image.original.crop(bbox)
                        images.append(cropped)
                    except Exception as e:
                        print(f"Image extraction failed: {e}")
    return full_text, images

def extract_chart_data(image: Image.Image) -> str:
    """Extract text from chart images using OCR"""
    try:
        image = image.convert("L")
        image = image.filter(ImageFilter.SHARPEN)
        enhancer = ImageEnhance.Contrast(image)
        image = enhancer.enhance(2.0)
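        # --psm 6 tells Tesseract to assume a single uniform block of text,
        # which tends to work reasonably well for chart labels and legends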
        chart_text = pytesseract.image_to_string(image, config="--psm 6")
        if chart_text.strip():
            return f"Chart contains: {chart_text.strip()}"
        else:
            width, height = image.size
            return f"Visual chart approximately {width}x{height} pixels with data points"
    except Exception as e:
        return f"[Chart content could not be extracted: {str(e)}]"

def extract_charts_with_deplot(pdf_path: str, document_id: str, chunk_size: int = 500):
    """
    Extract charts from a PDF using DePlot and store them in the vector database.

    Args:
        pdf_path: Path to PDF file
        document_id: Unique document identifier
        chunk_size: Size for text chunks
    Returns:
        List of processing results
    """
    doc = fitz.open(pdf_path)
    results = []
    for page_num in range(len(doc)):
        page = doc[page_num]
        image_list = page.get_images(full=True)
        for img_index, img in enumerate(image_list):
            try:
                # Extract and process image
                xref = img[0]
                base_image = doc.extract_image(xref)
                image_bytes = base_image["image"]
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
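                # Note: google/deplot was fine-tuned on the single prompt
                # "Generate underlying data table of the figure below:", so the
                # custom instructions used here may only be followed loosely.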
                # Extract table data
                text_table = "Extract all data from this chart in table format with clear headers."
                inputs_table = deplot_processor(images=image, text=text_table, return_tensors="pt").to(device)
                table_ids = deplot_model.generate(**inputs_table, max_new_tokens=512)
                table_data = deplot_processor.decode(table_ids[0], skip_special_tokens=True)

                # Generate summary
                text_summary = ("Provide a comprehensive summary of this chart including: "
                                "1. Chart title and type, 2. Key trends and patterns, "
                                "3. Notable data points, 4. Overall conclusion.")
                inputs_summary = deplot_processor(images=image, text=text_summary, return_tensors="pt").to(device)
                summary_ids = deplot_model.generate(**inputs_summary, max_new_tokens=512)
                chart_summary = deplot_processor.decode(summary_ids[0], skip_special_tokens=True)

                # Create and store chunks
                combined_content = f"CHART SUMMARY:\n{chart_summary}\n\nEXTRACTED DATA:\n{table_data}"
                chunks = []
                current_chunk = ""
                for para in [p for p in combined_content.split('\n') if p.strip()]:
                    if len(current_chunk) + len(para) + 1 <= chunk_size:
                        current_chunk += para + "\n"
                    else:
                        if current_chunk:
                            chunks.append(current_chunk.strip())
                        current_chunk = para + "\n"
                if current_chunk:
                    chunks.append(current_chunk.strip())

                # Store chunks in vector database
                points = []
                for i, chunk in enumerate(chunks):
                    embedding = embedder.encode(chunk).tolist()
                    point = PointStruct(
                        id=str(uuid.uuid4()),
                        vector=embedding,
                        payload={
                            "document_id": document_id,
                            "page": page_num + 1,
                            "image_index": img_index + 1,
                            "type": "chart_chunk",
                            "chunk_index": i,
                            "total_chunks": len(chunks),
                            "content": chunk,
                            "full_summary": chart_summary,
                            "full_table": table_data
                        }
                    )
                    points.append(point)
                if points:
                    qdrant.upsert(collection_name=PDF_COLLECTION_NAME, points=points)
                results.append({
                    "page": page_num + 1,
                    "image_index": img_index + 1,
                    "summary": chart_summary,
                    "table_data": table_data,
                    "num_chunks": len(chunks)
                })
            except Exception as e:
                print(f"❌ Error processing page {page_num+1} image {img_index+1}: {str(e)}")
                continue
    doc.close()
    return results

def store_pdf_chunks(text: str, document_id: str):
    """Store PDF text content in the vector database"""
    # Greedily pack paragraphs into chunks of roughly 1000 characters
    paragraphs = text.split('\n\n')
    chunks = []
    current_chunk = ""
    for para in paragraphs:
        if len(current_chunk) + len(para) < 1000:
            current_chunk += para + "\n\n"
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = para + "\n\n"
    if current_chunk:
        chunks.append(current_chunk.strip())
    for chunk in chunks:
        embedding = embedder.encode(chunk).tolist()
        point = PointStruct(
            id=str(uuid.uuid4()),
            vector=embedding,
            payload={
                "document_id": document_id,
                "content": chunk,
                "source": "pdf"
            }
        )
        qdrant.upsert(collection_name=PDF_COLLECTION_NAME, points=[point])

def process_pdf(pdf_path: str):
    """Process a PDF file and store its contents"""
    text, images = extract_pdf_content(pdf_path)
    ocr_text = ""
    chart_summaries = []
    for i, image in enumerate(images):
        try:
            ocr_text += pytesseract.image_to_string(image)
            chart_summary = extract_chart_data(image)
            chart_summaries.append(f"Chart {i+1}: {chart_summary}")
        except Exception as e:
            print(f"Image processing failed: {e}")
            chart_summaries.append(f"Chart {i+1}: [Content not extracted]")
    full_text = (
        "PDF TEXT CONTENT:\n" + text +
        "\n\nIMAGE TEXT CONTENT:\n" + ocr_text +
        "\n\nCHART SUMMARIES:\n" + "\n".join(chart_summaries)
    )
    document_id = os.path.basename(pdf_path)
    store_pdf_chunks(full_text, document_id)

    # Process charts with DePlot
    deplot_results = extract_charts_with_deplot(pdf_path, document_id)
    print(f"✅ DePlot processed {len(deplot_results)} charts")

def get_relevant_context(user_message: str, session_id: str):
    """Retrieve relevant context from the vector stores"""
    question_embedding = embedder.encode(user_message).tolist()

    # Search PDF content
    pdf_results = qdrant.search(
        collection_name=PDF_COLLECTION_NAME,
        query_vector=question_embedding,
        limit=10,
        score_threshold=0.4
    )

    # Get conversation history
    history = get_session_history(session_id)
    recent_history = history[-3:]

    pdf_context = "\n".join([hit.payload.get("content", "") for hit in pdf_results])
    history_context = "\n".join([msg["content"] for msg in recent_history])
    return pdf_context, history_context

def get_verified_context(session_id):
    """Retrieve recent messages containing numerical data"""
    db = SessionLocal()
    messages = db.query(ChatMessage).filter(
        ChatMessage.session_id == session_id
    ).order_by(ChatMessage.id.desc()).limit(10).all()
    db.close()
    # The query returns newest-first; reverse back to chronological order
    return [msg for msg in reversed(messages) if any(char.isdigit() for char in msg.message)]

def chat_with_session(session_id, user_message):
    """Main chat function with context-aware responses"""
    try:
        uuid.UUID(session_id)  # validate the format only
    except ValueError:
        return "❌ Invalid session ID format. Please generate a valid session."

    # Gather all context sources
    conversation_summary = get_conversation_summary(session_id)
    pdf_context, history_context = get_relevant_context(user_message, session_id)
    verified_contexts = get_verified_context(session_id)
    verified_text = "\n".join([msg.message for msg in verified_contexts])

    # Construct system prompt
    system_prompt = (
        "You are a context-aware assistant. Follow these rules strictly:\n"
        "1. CONVERSATION SUMMARY:\n" + conversation_summary + "\n\n"
        "2. Maintain context for follow-up questions\n"
        "3. DOCUMENT CONTEXT:\n" + (pdf_context if pdf_context else "None") + "\n\n"
        "4. VERIFIED NUMERICAL CONTEXT:\n" + (verified_text if verified_text else "None") + "\n\n"
        "5. Respond clearly and concisely to the latest user query while maintaining continuity.\n"
    )

    # Prepare messages for the LLM
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(get_session_history(session_id)[-3:])
    messages.append({"role": "user", "content": user_message})

    try:
        completion = client.chat.completions.create(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            messages=messages,
            temperature=0.7,
            max_tokens=1024,
            top_p=0.9
        )
        reply = completion.choices[0].message.content
    except Exception as e:
        print(f"❌ LLM generation failed: {e}")
        return "Sorry, I couldn't generate a response at this time."

    # Store the exchange and invalidate the cached summary for this session
    store_message(session_id, "user", user_message)
    store_message(session_id, "assistant", reply)
    get_conversation_summary.cache_clear()
    return reply

# Initialize collections on startup
create_collections()
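
# Example usage: a minimal sketch for local testing. "report.pdf" is a
# hypothetical path, and the .env file is assumed to provide GROQ_API_KEY,
# QDRANT_URL, and QDRANT_API_KEY.
if __name__ == "__main__":
    process_pdf("report.pdf")  # index a document before asking about it
    session_id = generate_session_id()
    print(chat_with_session(session_id, "Summarize the key figures in the report."))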