from pathlib import Path
from urllib.request import urlopen
from uuid import uuid4

import modal

MINUTES = 60

app = modal.App("chat-with-pdf")

CACHE_DIR = "/hf-cache"

# Container image with the model dependencies; hf_transfer speeds up
# weight downloads into the shared Hugging Face cache.
model_image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("git")
    .pip_install(
        [
            "transformers>=4.45.0",
            "torch==2.4.1",
            "torchvision==0.19.1",
            "git+https://github.com/illuin-tech/colpali.git@782edcd50108d1842d154730ad3ce72476a2d17d",
            "hf_transfer==0.1.8",
            "qwen-vl-utils==0.0.8",
        ]
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": CACHE_DIR})
)

with model_image.imports():
    import torch
    from colpali_engine.models import ColQwen2, ColQwen2Processor
    from qwen_vl_utils import process_vision_info
    from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
MODEL_REVISION = "aca78372505e6cb469c4fa6a35c60265b00ff5a4"

# Chat sessions are stored in a Modal Dict so they survive across containers.
sessions = modal.Dict.from_name("colqwen-chat-sessions", create_if_missing=True)


class Session:
    def __init__(self):
        self.images = None
        self.messages = []
        self.pdf_embeddings = None


# Page images live on a Volume shared between the indexer and the web UI.
pdf_volume = modal.Volume.from_name("colqwen-chat-pdfs", create_if_missing=True)
PDF_ROOT = Path("/vol/pdfs/")

cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)


@app.function(
    image=model_image, volumes={CACHE_DIR: cache_volume}, timeout=20 * MINUTES
)
def download_model():
    from huggingface_hub import snapshot_download

    result = snapshot_download(
        MODEL_NAME,
        revision=MODEL_REVISION,
        ignore_patterns=["*.pt", "*.bin"],
    )
    print(f"Downloaded model weights to {result}")


@app.cls(
    image=model_image,
    gpu="B200",
    scaledown_window=10 * MINUTES,
    volumes={"/vol/pdfs/": pdf_volume, CACHE_DIR: cache_volume},
)
class Model:
    @modal.enter()
    def load_models(self):
        import os

        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ["TRANSFORMERS_OFFLINE"] = "0"

        # Load ColQwen2 (the retrieval model) with explicit configuration.
        try:
            self.colqwen2_model = ColQwen2.from_pretrained(
                "vidore/colqwen2-v0.1",
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            print(f"Error loading ColQwen2: {e}")
            # Fall back to CPU loading, then move to GPU.
            self.colqwen2_model = ColQwen2.from_pretrained(
                "vidore/colqwen2-v0.1",
                torch_dtype=torch.bfloat16,
                device_map=None,
                trust_remote_code=True,
            )
            self.colqwen2_model = self.colqwen2_model.to("cuda:0")

        self.colqwen2_processor = ColQwen2Processor.from_pretrained(
            "vidore/colqwen2-v0.1"
        )

        # Load Qwen2-VL (the generation model) with explicit configuration.
        try:
            self.qwen2_vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            print(f"Error loading Qwen2VL: {e}")
            # Fall back to CPU loading, then move to GPU.
            self.qwen2_vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=None,
            )
            self.qwen2_vl_model = self.qwen2_vl_model.to("cuda:0")

        self.qwen2_vl_processor = AutoProcessor.from_pretrained(
            MODEL_NAME, revision=MODEL_REVISION, trust_remote_code=True
        )

    @modal.method()
    def index_pdf(self, session_id, target: bytes | list):
        session = sessions.get(session_id)
        if session is None:
            session = Session()

        # Accept either raw PDF bytes or a pre-rendered list of page images.
        if isinstance(target, bytes):
            images = convert_pdf_to_images.remote(target)
        else:
            images = target

        session_dir = PDF_ROOT / f"{session_id}"
        session_dir.mkdir(exist_ok=True, parents=True)
        for ii, image in enumerate(images):
            filename = session_dir / f"{str(ii).zfill(3)}.jpg"
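            # Persist each page image to the shared Volume so that
            # respond_to_message can reload it from another container.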
            image.save(filename)

        # Embed the page images in small batches to bound GPU memory use.
        BATCH_SZ = 4
        pdf_embeddings = []
        batches = [images[i : i + BATCH_SZ] for i in range(0, len(images), BATCH_SZ)]
        for batch in batches:
            batch_images = self.colqwen2_processor.process_images(batch).to(
                self.colqwen2_model.device
            )
            pdf_embeddings += list(self.colqwen2_model(**batch_images).to("cpu"))

        session.pdf_embeddings = pdf_embeddings
        sessions[session_id] = session

    @modal.method()
    def respond_to_message(self, session_id, message):
        session = sessions.get(session_id)
        if session is None:
            session = Session()

        # Pick up any page images written by other containers.
        pdf_volume.reload()
        images = (PDF_ROOT / str(session_id)).glob("*.jpg")
        images = list(sorted(images, key=lambda p: int(p.stem)))

        if not images:
            return "Please upload a PDF first"
        elif session.pdf_embeddings is None:
            return "Indexing PDF..."

        relevant_image = self.get_relevant_image(message, session, images)
        output_text = self.generate_response(message, session, relevant_image)

        append_to_messages(message, session, user_type="user")
        append_to_messages(output_text, session, user_type="assistant")
        sessions[session_id] = session

        return output_text

    def get_relevant_image(self, message, session, images):
        # Import the submodule explicitly; `import PIL` alone does not expose PIL.Image.
        import PIL.Image

        batch_queries = self.colqwen2_processor.process_queries([message]).to(
            self.colqwen2_model.device
        )
        query_embeddings = self.colqwen2_model(**batch_queries)

        # Score the query against every page embedding; the best-scoring
        # page becomes the visual context for generation.
        scores = self.colqwen2_processor.score_multi_vector(
            query_embeddings, session.pdf_embeddings
        )[0]
        max_index = max(range(len(scores)), key=lambda index: scores[index])

        return PIL.Image.open(images[max_index])

    def generate_response(self, message, session, image):
        chatbot_message = get_chatbot_message_with_image(message, image)
        query = self.qwen2_vl_processor.apply_chat_template(
            [*session.messages, chatbot_message],
            tokenize=False,
            add_generation_prompt=True,
        )

        image_inputs, _ = process_vision_info([chatbot_message])

        inputs = self.qwen2_vl_processor(
            text=[query],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cuda:0")

        generated_ids = self.qwen2_vl_model.generate(**inputs, max_new_tokens=512)
        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.qwen2_vl_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        return output_text


# A lighter image for PDF rasterization only; poppler-utils backs pdf2image.
pdf_image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("poppler-utils")
    .pip_install("pdf2image==1.17.0", "pillow==10.4.0")
)


@app.function(image=pdf_image)
def convert_pdf_to_images(pdf_bytes):
    from pdf2image import convert_from_bytes

    images = convert_from_bytes(pdf_bytes, fmt="jpeg")
    return images


@app.local_entrypoint()
def main(question: str = None, pdf_path: str = None, session_id: str = None):
    model = Model()

    if session_id is None:
        session_id = str(uuid4())
        print("Starting a new session with id", session_id)

        if pdf_path is None:  # default to "Attention Is All You Need"
            pdf_path = "https://arxiv.org/pdf/1706.03762"

        if pdf_path.startswith("http"):
            pdf_bytes = urlopen(pdf_path).read()
        else:
            pdf_path = Path(pdf_path)
            pdf_bytes = pdf_path.read_bytes()

        print("Indexing PDF from", pdf_path)
        model.index_pdf.remote(session_id, pdf_bytes)
    else:
        if pdf_path is not None:
            raise ValueError("Start a new session to chat with a new PDF")
        print("Resuming session with id", session_id)

    if question is None:
        question = "What is this document about?"
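    # Relay the question to the remote Model and print its answer.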
print("QUESTION:", question) print(model.respond_to_message.remote(session_id, question)) web_image = pdf_image.pip_install( "fastapi[standard]==0.115.4", "pydantic==2.9.2", "starlette==0.41.2", "gradio==4.44.1", "pillow==10.4.0", "gradio-pdf==0.0.15", "pdf2image==1.17.0", ) @app.function( image=web_image, max_containers=1, ) @modal.concurrent(max_inputs=1000) @modal.asgi_app() def ui(): import uuid import gradio as gr from fastapi import FastAPI from gradio.routes import mount_gradio_app from gradio_pdf import PDF from pdf2image import convert_from_path web_app = FastAPI() model = Model() def upload_pdf(path, session_id): if session_id == "" or session_id is None: session_id = str(uuid.uuid4()) images = convert_from_path(path) model.index_pdf.remote(session_id, images) return session_id def respond_to_message(message, _, session_id): return model.respond_to_message.remote(session_id, message) with gr.Blocks(theme="soft") as demo: session_id = gr.State("") gr.Markdown("# Chat with PDF") with gr.Row(): with gr.Column(scale=1): gr.ChatInterface( fn=respond_to_message, additional_inputs=[session_id], retry_btn=None, undo_btn=None, clear_btn=None, ) with gr.Column(scale=1): pdf = PDF( label="Upload a PDF", ) pdf.upload(upload_pdf, [pdf, session_id], session_id) return mount_gradio_app(app=web_app, blocks=demo, path="/") def get_chatbot_message_with_image(message, image): return { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": message}, ], } def append_to_messages(message, session, user_type="user"): session.messages.append( { "role": user_type, "content": {"type": "text", "text": message}, }, )