from pathlib import Path
from urllib.request import urlopen
from uuid import uuid4

import modal

MINUTES = 60

app = modal.App("chat-with-pdf")

CACHE_DIR = "/hf-cache"
model_image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("git")
    .pip_install(
        [
            "transformers>=4.45.0",
            "torch==2.4.1",
            "torchvision==0.19.1",
            "git+https://github.com/illuin-tech/colpali.git@782edcd50108d1842d154730ad3ce72476a2d17d",
            "hf_transfer==0.1.8",
            "qwen-vl-utils==0.0.8",
        ]
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": CACHE_DIR})
)
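# Imports inside this context manager only need to resolve in containers
# running model_image, not on the machine that deploys the app.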
with model_image.imports():
    import torch
    from colpali_engine.models import ColQwen2, ColQwen2Processor
    from qwen_vl_utils import process_vision_info
    from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
MODEL_REVISION = "aca78372505e6cb469c4fa6a35c60265b00ff5a4"
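# A modal.Dict is a distributed key-value store: chat state written by one
# container is visible to every other container in the app.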
sessions = modal.Dict.from_name("colqwen-chat-sessions", create_if_missing=True)


class Session:
    def __init__(self):
        self.images = None
        self.messages = []
        self.pdf_embeddings = None
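# Volumes persist files across function runs: one stores the uploaded PDFs'
# page images, the other the Hugging Face model cache.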
pdf_volume = modal.Volume.from_name("colqwen-chat-pdfs", create_if_missing=True)
PDF_ROOT = Path("/vol/pdfs/")

cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)
@app.function(
    image=model_image, volumes={CACHE_DIR: cache_volume}, timeout=20 * MINUTES
)
def download_model():
    from huggingface_hub import snapshot_download

    result = snapshot_download(
        MODEL_NAME,
        revision=MODEL_REVISION,
        ignore_patterns=["*.pt", "*.bin"],  # skip unused weight formats
    )
    print(f"Downloaded model weights to {result}")
@app.cls(
    image=model_image,
    gpu="A100-80GB",  # assumption: generous headroom; the two ~2B-parameter models may fit on smaller GPUs
    scaledown_window=10 * MINUTES,
    timeout=10 * MINUTES,
    volumes={str(PDF_ROOT): pdf_volume, CACHE_DIR: cache_volume},
)
class Model:
    @modal.enter()  # runs once per container start, before any requests arrive
    def load_models(self):
        import os

        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ["TRANSFORMERS_OFFLINE"] = "0"
        # Load ColQwen2, the vision retrieval model, with explicit configuration
        try:
            self.colqwen2_model = ColQwen2.from_pretrained(
                "vidore/colqwen2-v0.1",
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            print(f"Error loading ColQwen2: {e}")
            # Fall back to loading on CPU, then move to GPU
            self.colqwen2_model = ColQwen2.from_pretrained(
                "vidore/colqwen2-v0.1",
                torch_dtype=torch.bfloat16,
                device_map=None,
                trust_remote_code=True,
            )
            self.colqwen2_model = self.colqwen2_model.to("cuda:0")
        self.colqwen2_processor = ColQwen2Processor.from_pretrained(
            "vidore/colqwen2-v0.1"
        )
        # Load Qwen2-VL, the generative model, with explicit configuration
        try:
            self.qwen2_vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            print(f"Error loading Qwen2VL: {e}")
            # Fall back to loading on CPU, then move to GPU
            self.qwen2_vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=None,
            )
            self.qwen2_vl_model = self.qwen2_vl_model.to("cuda:0")
        self.qwen2_vl_processor = AutoProcessor.from_pretrained(
            MODEL_NAME, revision=MODEL_REVISION, trust_remote_code=True
        )
    @modal.method()
    def index_pdf(self, session_id, target: bytes | list):
        # Retrieve the session, or create one if it doesn't exist
        session = sessions.get(session_id)
        if session is None:
            session = Session()

        # Convert raw PDF bytes to page images; lists are assumed to be images already
        if isinstance(target, bytes):
            images = convert_pdf_to_images.remote(target)
        else:
            images = target

        # Save the page images to the volume so the UI and other containers can read them
        session_dir = PDF_ROOT / f"{session_id}"
        session_dir.mkdir(exist_ok=True, parents=True)
        for ii, image in enumerate(images):
            filename = session_dir / f"{str(ii).zfill(3)}.jpg"
            image.save(filename)

        # Embed the page images in batches with ColQwen2
        BATCH_SZ = 4
        pdf_embeddings = []
        batches = [images[i : i + BATCH_SZ] for i in range(0, len(images), BATCH_SZ)]
        for batch in batches:
            batch_images = self.colqwen2_processor.process_images(batch).to(
                self.colqwen2_model.device
            )
            pdf_embeddings += list(self.colqwen2_model(**batch_images).to("cpu"))

        # Store the embeddings in the shared session state
        session.pdf_embeddings = pdf_embeddings
        sessions[session_id] = session
    @modal.method()
    def respond_to_message(self, session_id, message):
        session = sessions.get(session_id)
        if session is None:
            session = Session()

        # Pick up any page images that other containers wrote to the volume
        pdf_volume.reload()
        images = (PDF_ROOT / str(session_id)).glob("*.jpg")
        images = list(sorted(images, key=lambda p: int(p.stem)))

        # Handle sessions that aren't ready for chat yet
        if not images:
            return "Please upload a PDF first"
        elif session.pdf_embeddings is None:
            return "Indexing PDF..."

        # Retrieve the most relevant page, then answer based on it
        relevant_image = self.get_relevant_image(message, session, images)
        output_text = self.generate_response(message, session, relevant_image)

        # Update the conversation history
        append_to_messages(message, session, user_type="user")
        append_to_messages(output_text, session, user_type="assistant")
        sessions[session_id] = session
        return output_text
    def get_relevant_image(self, message, session, images):
        import PIL.Image

        # Embed the query with ColQwen2, then score it against every page embedding
        batch_queries = self.colqwen2_processor.process_queries([message]).to(
            self.colqwen2_model.device
        )
        query_embeddings = self.colqwen2_model(**batch_queries)
        scores = self.colqwen2_processor.score_multi_vector(
            query_embeddings, session.pdf_embeddings
        )[0]

        # Return the page with the highest relevance score
        max_index = max(range(len(scores)), key=lambda index: scores[index])
        return PIL.Image.open(images[max_index])
    def generate_response(self, message, session, image):
        # Build the multimodal prompt: chat history plus the new message and page image
        chatbot_message = get_chatbot_message_with_image(message, image)
        query = self.qwen2_vl_processor.apply_chat_template(
            [*session.messages, chatbot_message],
            tokenize=False,
            add_generation_prompt=True,
        )
        image_inputs, _ = process_vision_info([chatbot_message])
        inputs = self.qwen2_vl_processor(
            text=[query],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.qwen2_vl_model.device)

        # Generate, then strip the prompt tokens from the output before decoding
        generated_ids = self.qwen2_vl_model.generate(**inputs, max_new_tokens=512)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.qwen2_vl_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return output_text
pdf_image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("poppler-utils")
    .pip_install("pdf2image==1.17.0", "pillow==10.4.0")
)


@app.function(image=pdf_image)
def convert_pdf_to_images(pdf_bytes):
    from pdf2image import convert_from_bytes

    images = convert_from_bytes(pdf_bytes, fmt="jpeg")
    return images
@app.local_entrypoint()
def main(question: str = None, pdf_path: str = None, session_id: str = None):
    model = Model()
    if session_id is None:
        session_id = str(uuid4())
        print("Starting a new session with id", session_id)

        if pdf_path is None:  # default to "Attention Is All You Need"
            pdf_path = "https://arxiv.org/pdf/1706.03762"

        if pdf_path.startswith("http"):
            pdf_bytes = urlopen(pdf_path).read()
        else:
            pdf_path = Path(pdf_path)
            pdf_bytes = pdf_path.read_bytes()

        print("Indexing PDF from", pdf_path)
        model.index_pdf.remote(session_id, pdf_bytes)
    else:
        if pdf_path is not None:
            raise ValueError("Start a new session to chat with a new PDF")
        print("Resuming session with id", session_id)

    if question is None:
        question = "What is this document about?"

    print("QUESTION:", question)
    print(model.respond_to_message.remote(session_id, question))
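# Chat from the command line; assuming the file is saved as chat_with_pdf.py:
#   modal run chat_with_pdf.py --question "What architecture does this paper propose?"
# Pass --session-id from an earlier run to continue that conversation.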
web_image = pdf_image.pip_install(
    "fastapi[standard]==0.115.4",
    "pydantic==2.9.2",
    "starlette==0.41.2",
    "gradio==4.44.1",
    "pillow==10.4.0",
    "gradio-pdf==0.0.15",
    "pdf2image==1.17.0",
)
@app.function(
    image=web_image,
    max_containers=1,  # Gradio keeps session state in-process, so run a single container
)
@modal.concurrent(max_inputs=1000)  # and let that container serve many users at once
@modal.asgi_app()
def ui():
    import uuid

    import gradio as gr
    from fastapi import FastAPI
    from gradio.routes import mount_gradio_app
    from gradio_pdf import PDF
    from pdf2image import convert_from_path

    web_app = FastAPI()
    model = Model()

    def upload_pdf(path, session_id):
        if session_id == "" or session_id is None:
            session_id = str(uuid.uuid4())  # start a fresh session
        images = convert_from_path(path)  # rasterize the PDF in the web container
        model.index_pdf.remote(session_id, images)  # embed the pages on the GPU container
        return session_id

    def respond_to_message(message, _, session_id):
        return model.respond_to_message.remote(session_id, message)

    with gr.Blocks(theme="soft") as demo:
        session_id = gr.State("")
        gr.Markdown("# Chat with PDF")
        with gr.Row():
            with gr.Column(scale=1):
                gr.ChatInterface(
                    fn=respond_to_message,
                    additional_inputs=[session_id],
                    retry_btn=None,
                    undo_btn=None,
                    clear_btn=None,
                )
            with gr.Column(scale=1):
                pdf = PDF(label="Upload a PDF")
                pdf.upload(upload_pdf, [pdf, session_id], session_id)

    return mount_gradio_app(app=web_app, blocks=demo, path="/")
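# Serve a hot-reloading dev copy with `modal serve chat_with_pdf.py`, or put the
# app online with `modal deploy chat_with_pdf.py` (filename assumed, as above).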
def get_chatbot_message_with_image(message, image):
    # Qwen2-VL expects multimodal turns as a list of typed content parts
    return {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": message},
        ],
    }
def append_to_messages(message, session, user_type="user"):
    session.messages.append(
        {
            "role": user_type,
            # content is a list of typed parts, matching the chat template's format
            "content": [{"type": "text", "text": message}],
        }
    )
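# A deployed copy can also be called from other Python code; a minimal sketch,
# assuming the app has been deployed under the name used above:
#
#   import modal
#   Model = modal.Cls.from_name("chat-with-pdf", "Model")
#   print(Model().respond_to_message.remote("some-session-id", "Summarize page 1"))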