# Decider-MCP / app.py
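# A Modal "chat with PDF" app: PDF pages are rasterized to images, embedded
# with the ColQwen2 visual retriever, and the best-matching page is handed to
# Qwen2-VL to answer the user's question.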
from pathlib import Path
from urllib.request import urlopen
from uuid import uuid4
import modal
MINUTES = 60
app = modal.App("chat-with-pdf")
CACHE_DIR = "/hf-cache"
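# Inference image: pinned torch/transformers, the colpali engine at a fixed
# commit, and hf_transfer enabled for faster downloads into the HF cache.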
model_image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("git")
    .pip_install(
        [
            "transformers>=4.45.0",
            "torch==2.4.1",
            "torchvision==0.19.1",
            "git+https://github.com/illuin-tech/colpali.git@782edcd50108d1842d154730ad3ce72476a2d17d",
            "hf_transfer==0.1.8",
            "qwen-vl-utils==0.0.8",
        ]
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": CACHE_DIR})
)
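# These packages only exist inside the inference image, so import them lazily.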
with model_image.imports():
    import torch
    from colpali_engine.models import ColQwen2, ColQwen2Processor
    from qwen_vl_utils import process_vision_info
    from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
MODEL_REVISION = "aca78372505e6cb469c4fa6a35c60265b00ff5a4"
sessions = modal.Dict.from_name("colqwen-chat-sessions", create_if_missing=True)
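# Per-session chat state, persisted across containers via the modal.Dict above.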
class Session:
    def __init__(self):
        self.images = None
        self.messages = []
        self.pdf_embeddings = None
pdf_volume = modal.Volume.from_name("colqwen-chat-pdfs", create_if_missing=True)
PDF_ROOT = Path("/vol/pdfs/")
cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)
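# Pre-fetch the Qwen2-VL weights into the shared cache volume so containers
# don't pay the download cost at startup.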
@app.function(
    image=model_image, volumes={CACHE_DIR: cache_volume}, timeout=20 * MINUTES
)
def download_model():
    from huggingface_hub import snapshot_download

    result = snapshot_download(
        MODEL_NAME,
        revision=MODEL_REVISION,
        ignore_patterns=["*.pt", "*.bin"],
    )
    print(f"Downloaded model weights to {result}")
@app.cls(
    image=model_image,
    gpu="B200",
    scaledown_window=10 * MINUTES,
    volumes={"/vol/pdfs/": pdf_volume, CACHE_DIR: cache_volume},
)
class Model:
    @modal.enter()
    def load_models(self):
        import os

        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ["TRANSFORMERS_OFFLINE"] = "0"

        # Load ColQwen2 with explicit device mapping
        try:
            self.colqwen2_model = ColQwen2.from_pretrained(
                "vidore/colqwen2-v0.1",
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            print(f"Error loading ColQwen2: {e}")
            # Fallback: load onto CPU first, then move to the GPU
            self.colqwen2_model = ColQwen2.from_pretrained(
                "vidore/colqwen2-v0.1",
                torch_dtype=torch.bfloat16,
                device_map=None,
                trust_remote_code=True,
            )
            self.colqwen2_model = self.colqwen2_model.to("cuda:0")
        self.colqwen2_processor = ColQwen2Processor.from_pretrained(
            "vidore/colqwen2-v0.1"
        )

        # Load Qwen2-VL with explicit device mapping
        try:
            self.qwen2_vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            print(f"Error loading Qwen2VL: {e}")
            # Same fallback: load onto CPU first, then move to the GPU
            self.qwen2_vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=None,
            )
            self.qwen2_vl_model = self.qwen2_vl_model.to("cuda:0")
        self.qwen2_vl_processor = AutoProcessor.from_pretrained(
            MODEL_NAME, revision=MODEL_REVISION, trust_remote_code=True
        )
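    # Index a PDF: save each page as a JPEG for later retrieval, then embed the
    # pages in small batches with ColQwen2 and stash the embeddings in the session.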
    @modal.method()
    def index_pdf(self, session_id, target: bytes | list):
        session = sessions.get(session_id)
        if session is None:
            session = Session()

        if isinstance(target, bytes):
            images = convert_pdf_to_images.remote(target)
        else:
            images = target

        session_dir = PDF_ROOT / f"{session_id}"
        session_dir.mkdir(exist_ok=True, parents=True)
        for ii, image in enumerate(images):
            filename = session_dir / f"{str(ii).zfill(3)}.jpg"
            image.save(filename)

        BATCH_SZ = 4
        pdf_embeddings = []
        batches = [images[i : i + BATCH_SZ] for i in range(0, len(images), BATCH_SZ)]
        for batch in batches:
            batch_images = self.colqwen2_processor.process_images(batch).to(
                self.colqwen2_model.device
            )
            pdf_embeddings += list(self.colqwen2_model(**batch_images).to("cpu"))

        session.pdf_embeddings = pdf_embeddings
        sessions[session_id] = session
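    # Answer one chat message: retrieve the most relevant page, generate a
    # response conditioned on it, and append both turns to the session history.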
    @modal.method()
    def respond_to_message(self, session_id, message):
        session = sessions.get(session_id)
        if session is None:
            session = Session()

        pdf_volume.reload()
        images = (PDF_ROOT / str(session_id)).glob("*.jpg")
        images = list(sorted(images, key=lambda p: int(p.stem)))
        if not images:
            return "Please upload a PDF first"
        elif session.pdf_embeddings is None:
            return "Indexing PDF..."

        relevant_image = self.get_relevant_image(message, session, images)
        output_text = self.generate_response(message, session, relevant_image)

        append_to_messages(message, session, user_type="user")
        append_to_messages(output_text, session, user_type="assistant")
        sessions[session_id] = session

        return output_text
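    # Late-interaction retrieval (ColPali-style): score the query's multi-vector
    # embedding against every page's embedding and return the top-scoring page.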
    def get_relevant_image(self, message, session, images):
        import PIL.Image

        batch_queries = self.colqwen2_processor.process_queries([message]).to(
            self.colqwen2_model.device
        )
        query_embeddings = self.colqwen2_model(**batch_queries)
        scores = self.colqwen2_processor.score_multi_vector(
            query_embeddings, session.pdf_embeddings
        )[0]
        max_index = max(range(len(scores)), key=lambda index: scores[index])
        return PIL.Image.open(images[max_index])
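    # Build a multimodal prompt from the chat history plus the retrieved page
    # image, then decode only the newly generated tokens.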
    def generate_response(self, message, session, image):
        chatbot_message = get_chatbot_message_with_image(message, image)
        query = self.qwen2_vl_processor.apply_chat_template(
            [*session.messages, chatbot_message],
            tokenize=False,
            add_generation_prompt=True,
        )
        image_inputs, _ = process_vision_info([chatbot_message])
        inputs = self.qwen2_vl_processor(
            text=[query],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cuda:0")

        generated_ids = self.qwen2_vl_model.generate(**inputs, max_new_tokens=512)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.qwen2_vl_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return output_text
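# Lightweight image for PDF rasterization; poppler-utils provides the pdftoppm
# backend that pdf2image shells out to.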
pdf_image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("poppler-utils")
    .pip_install("pdf2image==1.17.0", "pillow==10.4.0")
)
@app.function(image=pdf_image)
def convert_pdf_to_images(pdf_bytes):
    from pdf2image import convert_from_bytes

    images = convert_from_bytes(pdf_bytes, fmt="jpeg")
    return images
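# CLI entrypoint. A typical invocation (flag names follow Modal's convention
# of dashing the parameter names) would be:
#   modal run app.py --question "What is attention?" --pdf-path https://arxiv.org/pdf/1706.03762
# Passing --session-id resumes an existing chat instead of indexing a new PDF.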
@app.local_entrypoint()
def main(question: str = None, pdf_path: str = None, session_id: str = None):
    model = Model()

    if session_id is None:
        session_id = str(uuid4())
        print("Starting a new session with id", session_id)

        if pdf_path is None:
            pdf_path = "https://arxiv.org/pdf/1706.03762"

        if pdf_path.startswith("http"):
            pdf_bytes = urlopen(pdf_path).read()
        else:
            pdf_path = Path(pdf_path)
            pdf_bytes = pdf_path.read_bytes()

        print("Indexing PDF from", pdf_path)
        model.index_pdf.remote(session_id, pdf_bytes)
    else:
        if pdf_path is not None:
            raise ValueError("Start a new session to chat with a new PDF")
        print("Resuming session with id", session_id)

    if question is None:
        question = "What is this document about?"

    print("QUESTION:", question)
    print(model.respond_to_message.remote(session_id, question))
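# The web UI runs in its own image: Gradio for the interface, pdf2image to
# rasterize uploads before sending them to the model for indexing.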
web_image = pdf_image.pip_install(
    "fastapi[standard]==0.115.4",
    "pydantic==2.9.2",
    "starlette==0.41.2",
    "gradio==4.44.1",
    "pillow==10.4.0",
    "gradio-pdf==0.0.15",
    "pdf2image==1.17.0",
)
@app.function(
    image=web_image,
    max_containers=1,
)
@modal.concurrent(max_inputs=1000)
@modal.asgi_app()
def ui():
    import uuid

    import gradio as gr
    from fastapi import FastAPI
    from gradio.routes import mount_gradio_app
    from gradio_pdf import PDF
    from pdf2image import convert_from_path

    web_app = FastAPI()
    model = Model()

    def upload_pdf(path, session_id):
        if session_id == "" or session_id is None:
            session_id = str(uuid.uuid4())
        images = convert_from_path(path)
        model.index_pdf.remote(session_id, images)
        return session_id

    def respond_to_message(message, _, session_id):
        return model.respond_to_message.remote(session_id, message)

    with gr.Blocks(theme="soft") as demo:
        session_id = gr.State("")
        gr.Markdown("# Chat with PDF")
        with gr.Row():
            with gr.Column(scale=1):
                gr.ChatInterface(
                    fn=respond_to_message,
                    additional_inputs=[session_id],
                    retry_btn=None,
                    undo_btn=None,
                    clear_btn=None,
                )
            with gr.Column(scale=1):
                pdf = PDF(
                    label="Upload a PDF",
                )
                pdf.upload(upload_pdf, [pdf, session_id], session_id)

    return mount_gradio_app(app=web_app, blocks=demo, path="/")
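# Message-formatting helpers shared by Model's methods above.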
def get_chatbot_message_with_image(message, image):
    return {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": message},
        ],
    }


def append_to_messages(message, session, user_type="user"):
    session.messages.append(
        {
            "role": user_type,
            "content": {"type": "text", "text": message},
        },
    )