import os

from dotenv import load_dotenv
from smolagents import Tool

load_dotenv()
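# Assumed dependencies, inferred from the imports used throughout this file:
# smolagents, colpali-engine, torch, pillow, pdf2image (which itself requires the
# poppler system package), openai, pqdm, tqdm, and python-dotenv.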
class VisualRAGTool(Tool):
    name = "visual_rag"
    description = """Performs a RAG query on your internal PDF documents and returns the generated text response."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents.",
        },
        "k": {
            "type": "number",
            "description": "The number of documents to retrieve.",
            "default": 1,
            "nullable": True,
        },
        "api_key": {
            "type": "string",
            "description": "The OpenAI API key to use for the query. If not provided, the key will be taken from the OPENAI_KEY environment variable.",
            "nullable": True,
        },
    }
    output_type = "string"

    model_name: str = "vidore/colqwen2-v1.0"
    api_key: str = os.getenv("OPENAI_KEY")
    class Page:
        """A single PDF page image together with its retrieval metadata."""

        from typing import Any, Dict, Optional

        from PIL import Image

        image: Image.Image
        metadata: Optional[Dict[str, Any]] = None

        def __init__(self, image, metadata=None):
            self.image = image
            self.metadata = metadata

        # Defined as a property so callers can read `page.caption` directly
        # (as _build_query does below).
        @property
        def caption(self):
            if self.metadata is None:
                return None
            return f"Document: {self.metadata.get('doc_title')}, Context: {self.metadata.get('context')}"

        def __hash__(self):
            return hash(self.image)
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_initialized = False
    def _init_models(self, model_name: str) -> None:
        import torch
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ColQwen2.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # flash_attention_2 requires a CUDA GPU and the flash-attn package;
            # fall back to the default attention implementation on CPU.
            attn_implementation="flash_attention_2" if self.device == "cuda" else None,
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(model_name)
    def setup(self):
        """
        Override this method for any operation that is expensive and needs to be
        executed before you start using your tool, such as loading a big model.
        """
        self._init_models(self.model_name)

        self.embds = []
        self.pages = []

        self.is_initialized = True
    def _encode_image_to_base64(self, image):
        """Encodes a PIL image to a base64 string."""
        import base64
        from io import BytesIO

        buffered = BytesIO()
        # JPEG has no alpha channel, so make sure the image is RGB first.
        image.convert("RGB").save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    def _build_query(self, query: str, pages: list) -> list:
        """Builds the OpenAI user message from the retrieved pages and the query."""
        messages = []
        messages.append({"type": "text", "text": "PDF pages:\n"})
        for page in pages:
            capt = page.caption  # None when the page has no metadata
            if capt is not None:
                messages.append({
                    "type": "text",
                    "text": capt,
                })
            messages.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{self._encode_image_to_base64(page.image)}"
                },
            })
        messages.append({"type": "text", "text": f"Query:\n{query}"})
        return messages
    def query_openai(self, query, pages, api_key=None, system_prompt=None, model="gpt-4o-mini"):
        """Calls an OpenAI chat model (gpt-4o-mini by default) with the query and image data."""
        from smolagents import ChatMessage

        system_prompt = system_prompt or \
            """You are a smart assistant designed to answer questions about a PDF document.
            You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
            Use them to construct a short response to the question, and cite your sources in the following format: (document, page number).
            If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
            Give detailed and extensive answers, only containing info in the pages you are given.
            You can answer using information contained in plots and figures if necessary.
            Answer in the same language as the query."""

        api_key = api_key or self.api_key

        if api_key and api_key.startswith("sk"):
            try:
                from openai import OpenAI

                client = OpenAI(api_key=api_key.strip())

                response = client.chat.completions.create(
                    model=model,
                    messages=[
                        {
                            "role": "system",
                            "content": system_prompt,
                        },
                        {
                            "role": "user",
                            "content": self._build_query(query, pages),
                        },
                    ],
                    max_tokens=500,
                )

                message = ChatMessage.from_dict(
                    response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
                )
                message.raw = response
                return message
            except Exception:
                return "OpenAI API connection failure. Verify the provided key is correct (sk-***)."

        return "Enter your OpenAI API key to get a custom response"
    def _extract_contexts(self, images, api_key, window=10) -> list:
        """Extracts a textual context for each page image, one window of pages at a time."""
        from pqdm.threads import pqdm

        CONTEXT_SYSTEM_PROMPT = \
            """You are a smart assistant designed to extract the context of PDF pages.
            Give concise answers, only containing info in the pages you are given.
            You can answer using information contained in plots and figures if necessary."""

        try:
            args = [
                {
                    'query': "Give the general context about these pages. Give the context in the same language as the documents.",
                    # Window j covers images[i:i+window], matching the i // window
                    # lookup below. (The previous slice, images[max(i-window+1, 0):i+1],
                    # left the first page in a window of its own and shifted every
                    # later window by one.)
                    'pages': [self.Page(image=im) for im in images[i:i + window]],
                    'api_key': api_key,
                    'system_prompt': CONTEXT_SYSTEM_PROMPT,
                } for i in range(0, len(images), window)
            ]
            window_contexts = pqdm(args, self.query_openai, n_jobs=8, argument_type='kwargs')
            # Sequential alternative, for debugging:
            # from tqdm import tqdm
            # window_contexts = [self.query_openai(a['query'], a['pages'], api_key, CONTEXT_SYSTEM_PROMPT)
            #                    for a in tqdm(args)]

            contexts = []
            for i in range(len(images)):
                context = window_contexts[i // window].content
                contexts.append(context)
        except Exception as e:
            print(f"Error extracting contexts: {e}")
            contexts = [None for _ in range(len(images))]

        # Ensure that the number of contexts is equal to the number of images
        assert len(contexts) == len(images)

        return contexts
    def _preprocess_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Converts a PDF file to page images and attaches metadata to each page."""
        from pdf2image import convert_from_path

        title = file.split("/")[-1]
        images = convert_from_path(file, thread_count=4)
        if contextualize and api_key:
            contexts = self._extract_contexts(images, api_key, window=window)
        else:
            contexts = [None for _ in range(len(images))]
        metadatas = [
            {'doc_title': title, 'page_id': i, 'context': contexts[i]}
            for i in range(len(images))
        ]
        return [self.Page(image=img, metadata=metadata) for img, metadata in zip(images, metadatas)]
    def preprocess(self, files: list, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Preprocesses all files into pages with metadata."""
        pages = [
            page
            for file in files
            for page in self._preprocess_file(file, contextualize=contextualize, api_key=api_key, window=window)
        ]
        print(f"Example context:\n{pages[0].metadata.get('context')}")
        return pages
    def compute_embeddings(self, pages) -> list:
        """Embeds the page images with the ColQwen2 model (batched inference, adapted from the ColPali example script)."""
        import torch
        from torch.utils.data import DataLoader
        from tqdm import tqdm

        dataloader = DataLoader(
            pages,
            batch_size=4,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images([p.image for p in x]).to(self.device),
        )

        embds = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                # The collate_fn already moved the batch to self.device.
                embeddings_doc = self.model(**batch_doc)
            embds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

        return embds
    def index(self, files: list, contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
        """Indexes the uploaded files and returns the number of indexed pages."""
        if not self.is_initialized:
            self.setup()

        print("Converting files...")
        # Convert files to images and extract metadata
        pgs = self.preprocess(files, contextualize=contextualize, api_key=api_key or self.api_key)

        # Embed the images
        embds = self.compute_embeddings(pgs)

        # Overwrite the in-memory database if requested
        if overwrite_db:
            self.pages = []
            self.embds = []

        # Extend the pages
        self.pages.extend(pgs)
        # Extend the embeddings
        self.embds.extend(embds)

        print(f"Extracted and indexed {len(pgs)} images from {len(files)} files.")

        return len(embds)
    def retrieve(self, query: str, k: int) -> list:
        """Retrieves the top k pages for the query using the processor's late-interaction scoring."""
        import torch

        k = min(k, len(self.embds))

        qs = []
        with torch.no_grad():
            batch_query = self.processor.process_queries([query]).to(self.model.device)
            embeddings_query = self.model(**batch_query)
            qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

        # Run scoring
        scores = self.processor.score(qs, self.embds, device=self.device)[0]
        top_k_idx = scores.topk(k).indices.tolist()

        print("Top Scores:")
        for idx in top_k_idx:
            print(f"Page {self.pages[idx].metadata.get('page_id')}: {scores[idx]}")

        # Get the top k results
        return [self.pages[idx] for idx in top_k_idx]
    def generate_answer(self, query: str, docs: list, api_key: str = None):
        """Generates an answer based on the query and the retrieved documents."""
        RAG_SYSTEM_PROMPT = \
            """You are a smart assistant designed to answer questions about a PDF document.
            You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
            Use them to construct a response to the question, and cite your sources.
            Use the following citation format:
            "Some information from a first document [1, p.Page Number]. Some information from the same first document but at a different page [1, p.Page Number]. Some more information from another document [2, p.Page Number].
            ...
            Sources:
            [1] Document Title
            [2] Another Document Title"
            You can answer using information contained in plots and figures if necessary.
            If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
            Give detailed answers, only containing info in the pages you are given.
            Answer in the same language as the query."""

        return self.query_openai(query, docs, api_key or self.api_key, system_prompt=RAG_SYSTEM_PROMPT)
    def search(self, query: str, k: int = 1, api_key: str = None) -> tuple:
        """Searches for the most relevant pages and generates an answer from them."""
        print(f"Searching for query: {query}")

        # Retrieve the top k documents
        context = self.retrieve(query, k)

        # Generate a response from the chat model
        rag_answer = self.generate_answer(
            query=query,
            docs=context,
            api_key=api_key,
        )

        # query_openai returns a plain error string when no valid key is available
        answer = rag_answer.content if hasattr(rag_answer, "content") else rag_answer
        return context, answer
    def forward(self, query: str, k: int = 1, api_key: str = None) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        # Online indexing (disabled for now):
        # if files:
        #     _ = self.index(files, api_key)

        # Retrieve the top k documents and generate the response
        return self.search(
            query=query,
            k=k,
            api_key=api_key,
        )[1]
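
# Minimal usage sketch, assuming a CUDA GPU and an OPENAI_KEY set in the
# environment (or a .env file). The PDF path is a hypothetical placeholder,
# not part of the original Space.
if __name__ == "__main__":
    tool = VisualRAGTool()

    # Index one or more local PDFs (downloads the ColQwen2 weights on first run).
    tool.index(files=["docs/report.pdf"], contextualize=True)

    # Ask a question against the indexed pages; forward() returns the answer text.
    answer = tool.forward("What is the main conclusion of the report?", k=2)
    print(answer)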