import base64
import pathlib

import gradio as gr
import spaces
import torch
from colpali_engine.models import ColPali, ColPaliProcessor
from pdf2image import convert_from_path
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available

# Local snapshot paths for the models, assuming they are already present in
# the default Hugging Face cache.
PIXTRAL_MODEL_ID = "mistral-community--pixtral-12b"
PIXTRAL_MODEL_SNAPSHOT = "c2756cbbb9422eba9f6c5c439a214b0392dfc998"
PIXTRAL_MODEL_PATH = (
    pathlib.Path.home()
    / f".cache/huggingface/hub/models--{PIXTRAL_MODEL_ID}/snapshots/{PIXTRAL_MODEL_SNAPSHOT}"
)

COLPALI_GEMMA_MODEL_ID = "vidore--colpaligemma-3b-pt-448-base"
COLPALI_GEMMA_MODEL_SNAPSHOT = "30ab955d073de4a91dc5a288e8c97226647e3e5a"
COLPALI_GEMMA_MODEL_PATH = (
    pathlib.Path.home()
    / f".cache/huggingface/hub/models--{COLPALI_GEMMA_MODEL_ID}/snapshots/{COLPALI_GEMMA_MODEL_SNAPSHOT}"
)

COLPALI_MODEL_ID = "vidore--colpali-v1.3"
COLPALI_MODEL_SNAPSHOT = "1b5c8929330df1a66de441a9b5409a878f0de5b0"
COLPALI_MODEL_PATH = (
    pathlib.Path.home()
    / f".cache/huggingface/hub/models--{COLPALI_MODEL_ID}/snapshots/{COLPALI_MODEL_SNAPSHOT}"
)
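# The hard-coded paths above assume each snapshot was downloaded beforehand.
# A sketch of a more robust alternative using huggingface_hub, which resolves
# the snapshot directory and fetches it if missing (note the repo ids use "/"
# rather than the cache directory's "--" separator):
#
#     from huggingface_hub import snapshot_download
#
#     PIXTRAL_MODEL_PATH = snapshot_download(
#         "mistral-community/pixtral-12b", revision=PIXTRAL_MODEL_SNAPSHOT
#     )
#     COLPALI_GEMMA_MODEL_PATH = snapshot_download(
#         "vidore/colpaligemma-3b-pt-448-base", revision=COLPALI_GEMMA_MODEL_SNAPSHOT
#     )
#     COLPALI_MODEL_PATH = snapshot_download(
#         "vidore/colpali-v1.3", revision=COLPALI_MODEL_SNAPSHOT
#     )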
def image_to_base64(image_path):
    """Encode an image file as a base64 data URL."""
    with open(image_path, "rb") as img:
        encoded_string = base64.b64encode(img.read()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded_string}"


@spaces.GPU(duration=120)
def pixtral_inference(images, text):
    """Answer `text` with Pixtral, conditioned on the retrieved pages.

    `images` arrives from the Gallery as a list of (filepath, caption) tuples.
    """
    if len(images) == 0:
        raise gr.Error("No images for generation")
    if text == "":
        raise gr.Error("No query for generation")

    model = LlavaForConditionalGeneration.from_pretrained(
        PIXTRAL_MODEL_PATH, device_map="cuda"
    )
    processor = AutoProcessor.from_pretrained(PIXTRAL_MODEL_PATH, use_fast=True)

    # Pass the pages as data URLs so the processor can load them directly.
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": image_to_base64(i[0])} for i in images
            ]
            + [{"type": "text", "text": text}],
        }
    ]
    inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    generate_ids = model.generate(**inputs, max_new_tokens=256)
    output = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return output


@spaces.GPU(duration=120)
def retrieve(query: str, ds, images, k):
    """Return the top-k pages for `query` using ColPali late-interaction scores."""
    if len(images) == 0:
        raise gr.Error("No docs/images for retrieval")
    if query == "":
        raise gr.Error("No query for retrieval")

    model = ColPali.from_pretrained(
        COLPALI_GEMMA_MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        attn_implementation=(
            "flash_attention_2" if is_flash_attn_2_available() else None
        ),
    ).eval()
    model.load_adapter(COLPALI_MODEL_PATH)
    processor = ColPaliProcessor.from_pretrained(COLPALI_MODEL_PATH, use_fast=True)

    # Embed the query.
    qs = []
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to("cuda")
        embeddings_query = model.forward(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # MaxSim scoring of the query against every indexed page, best first.
    scores = processor.score_multi_vector(qs, ds).numpy()
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]

    results = [(images[idx], f"Score {scores[0][idx]:.2f}") for idx in top_k_indices]

    # Release GPU memory before the next call.
    del model
    del processor
    torch.cuda.empty_cache()
    return results


def index(files, ds):
    images = convert_files(files)
    return index_gpu(images, ds)


def convert_files(files):
    """Rasterize the uploaded PDFs into one PIL image per page."""
    images = []
    for f in files:
        images.extend(convert_from_path(f, thread_count=4))
        if len(images) >= 150:
            raise gr.Error(
                "The number of images in the dataset should be less than 150."
            )
    return images


@spaces.GPU(duration=120)
def index_gpu(images, ds):
    """Embed the page images with ColPali and append them to the index `ds`."""
    model = ColPali.from_pretrained(
        COLPALI_GEMMA_MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    ).eval()
    model.load_adapter(COLPALI_MODEL_PATH)
    processor = ColPaliProcessor.from_pretrained(COLPALI_MODEL_PATH, use_fast=True)

    # Run inference over the document pages in batches.
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_images(x),
    )
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to("cuda") for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

    del model
    del processor
    torch.cuda.empty_cache()
    return f"Uploaded and converted {len(images)} pages", ds, images


def get_example():
    return [
        [["plants_and_people.pdf"], "What is the global population in 2050?"],
        [["plants_and_people.pdf"], "Where was teosinte domesticated?"],
    ]


css = """
#title-container {
    margin: 0 auto;
    max-width: 800px;
    text-align: center;
}
#col-container {
    margin: 0 auto;
    max-width: 600px;
}
"""

# Created outside the Blocks context so they can be referenced by gr.Examples
# and rendered later at the desired position in the layout.
file = gr.File(
    file_types=[".pdf"], type="filepath", file_count="multiple", label="PDFs"
)
query = gr.Textbox("", placeholder="Enter your query here", label="Query")

with gr.Blocks(
    title="Document Question Answering with ColPali & Pixtral",
    theme=gr.themes.Soft(),
    css=css,
) as demo:
    with gr.Row(elem_id="title-container"):
        gr.Markdown("""# Document Question Answering with ColPali & Pixtral""")
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            gr.Examples(
                examples=get_example(),
                inputs=[file, query],
            )
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("## Index PDFs")
                file.render()
                convert_button = gr.Button("🔄 Run", variant="primary")
                message = gr.Textbox("Files not yet uploaded", label="Status")
                embeds = gr.State(value=[])
                imgs = gr.State(value=[])
                img_chunk = gr.State(value=[])
            with gr.Column(scale=3):
                gr.Markdown("## Retrieve with ColPali and answer with Pixtral")
                query.render()
                k = gr.Slider(
                    minimum=1,
                    maximum=4,
                    step=1,
                    label="Number of docs to retrieve",
                    value=1,
                )
                answer_button = gr.Button("🏃 Run", variant="primary")

        output_gallery = gr.Gallery(
            label="Retrieved docs", height=400, show_label=True, interactive=False
        )
        output = gr.Textbox(label="Answer", lines=2, interactive=False)

    convert_button.click(
        index, inputs=[file, embeds], outputs=[message, embeds, imgs]
    )
    # Retrieve first, then answer over the retrieved pages.
    answer_button.click(
        retrieve, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
    ).then(pixtral_inference, inputs=[output_gallery, query], outputs=[output])

if __name__ == "__main__":
    demo.queue(max_size=10).launch()
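# A minimal headless sanity check of the indexing and retrieval path (a sketch
# that bypasses the UI; assumes a CUDA device, the cached model snapshots, and
# a hypothetical sample.pdf in the working directory):
#
#     ds = []
#     status, ds, images = index(["sample.pdf"], ds)
#     for page, caption in retrieve("Where was teosinte domesticated?", ds, images, 1):
#         print(caption)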