# multimodal-rag / app.py
# Author: deepakkarkala
# Last commit: 0e9898e "Logging to UI" (file size 4.39 kB)
import hashlib
import io
import logging
import os
import shutil
import uuid

import streamlit as st
import torch
from byaldi import RAGMultiModalModel
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import (AutoModelForVision2Seq, AutoProcessor,
                          BitsAndBytesConfig)
from transformers.image_utils import load_image
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Capture logs so they can be shown in the UI.
# Streamlit reruns this whole script on every widget interaction, so a
# StringIO created at module level would be replaced on each rerun. Keep one
# buffer per browser session in session_state instead.
if "log_stream" not in st.session_state:
    st.session_state["log_stream"] = io.StringIO()
log_stream = st.session_state["log_stream"]
# basicConfig does nothing once the root logger already has a handler, so
# without force=True every later run would keep logging into the first run's
# (no longer displayed) buffer. force=True swaps the root handler for one
# writing to this session's buffer; StreamHandler.close() leaves the
# StringIO itself open.
# NOTE(review): the root logger is process-wide; with several concurrent
# sessions, records go to whichever session reran most recently.
logging.basicConfig(stream=log_stream, level=logging.INFO, force=True)

if "session_id" not in st.session_state:
    st.session_state["session_id"] = str(uuid.uuid4())  # Generate unique session ID
@st.cache_resource  # Streamlit caching decorator
def load_model_embedding():
    """Return the ColPali retriever used to index and search PDF page images.

    Streamlit caches the returned object, so reruns of this script reuse the
    same instance instead of reloading the checkpoint. Checkpoints tried
    previously: "vidore/colsmolvlm-alpha" and "vidore/colqwen2-v1.0".
    """
    return RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")


model_embedding = load_model_embedding()
@st.cache_resource  # Streamlit caching decorator
def load_model_vlm():
    """Load the SmolVLM-Instruct vision-language model and its processor.

    8-bit bitsandbytes quantization is only requested when ``DEVICE`` is
    ``"cuda"``. The transformers bitsandbytes integration requires a GPU, so
    on a CPU-only host the model is loaded unquantized instead of failing at
    startup.

    Returns:
        tuple: ``(model, processor)`` for the SmolVLM-Instruct checkpoint.
    """
    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
    processor = AutoProcessor.from_pretrained(checkpoint)
    if DEVICE == "cuda":
        model = AutoModelForVision2Seq.from_pretrained(
            checkpoint,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )
    else:
        model = AutoModelForVision2Seq.from_pretrained(checkpoint)
    return model, processor


model_vlm, processor_vlm = load_model_vlm()
def save_images_to_local(dataset, output_folder="data/"):
    """Write every image of *dataset* into *output_folder* as PNG files.

    The folder is created if needed. Files are named ``image_<n>.png``, where
    ``n`` is the image's position in *dataset* (starting at 0); existing files
    with the same name are overwritten.

    Args:
        dataset: Iterable of objects exposing ``save(path, format=...)``
            (e.g. PIL images).
        output_folder: Destination directory.
    """
    os.makedirs(output_folder, exist_ok=True)
    for page_number, page_image in enumerate(dataset):
        target_path = os.path.join(output_folder, f"image_{page_number}.png")
        page_image.save(target_path, format="PNG")
# Home page UI
with st.sidebar:
    "[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"

st.title("📝 Image Q&A with VLM")
# Reserve the log area here but fill it at the end of this section, after
# indexing has run, so this run's log records are shown now rather than only
# after the next rerun.
log_placeholder = st.empty()
uploaded_pdf = st.file_uploader("Upload PDF file", type=["pdf"])
query = st.text_input(
    "Ask something about the image",
    placeholder="Can you describe me the image ?",
    disabled=not uploaded_pdf,
)

images = []
images_folder = "data/" + st.session_state["session_id"] + "/"
index_name = "index_" + st.session_state["session_id"]

if uploaded_pdf:
    pdf_bytes = uploaded_pdf.getvalue()
    # Identify the upload by its content: a different PDF uploaded later in
    # the same session gets re-indexed, while plain reruns reuse the index.
    pdf_key = hashlib.sha256(pdf_bytes).hexdigest()
    if st.session_state.get("indexed_pdf_key") != pdf_key:
        images = convert_from_bytes(pdf_bytes)
        # Remove page images of a previous upload so they are not indexed
        # together with the new PDF's pages.
        shutil.rmtree(images_folder, ignore_errors=True)
        save_images_to_local(images, output_folder=images_folder)
        # index documents using the document retrieval model
        # NOTE(review): model_embedding comes from st.cache_resource and is
        # shared by all sessions; confirm whether indexing here affects what
        # other sessions' searches see.
        model_embedding.index(
            input_path=images_folder,
            index_name=index_name,
            store_collection_with_index=False,
            overwrite=True,
        )
        logging.info("%d number of images extracted from PDF and indexed", len(images))
        # Streamlit reruns the whole script on every interaction and the
        # line above only runs once per upload, so keep the page images in
        # session_state for later runs (e.g. the one that submits a query).
        st.session_state["images"] = images
        st.session_state["indexed_pdf_key"] = pdf_key
        st.session_state["is_index_complete"] = True
    else:
        images = st.session_state.get("images", [])

log_placeholder.text_area("Logs:", log_stream.getvalue(), height=200)
def _generate_answer(question, image):
    """Ask the VLM *question* about *image* and return only the answer text.

    Args:
        question: The user's question.
        image: The page image given to the processor as visual context.

    Returns:
        str: Decoded text of the newly generated tokens, whitespace-stripped.
    """
    system_prompt = "You are an AI assistant. Your task is reply to user questions based on the provided image context."
    chat_template = [
        # Use the same list-of-typed-parts content shape as the user turn; a
        # bare string is not guaranteed to be rendered by a chat template
        # that iterates over typed content parts.
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        },
    ]
    prompt = processor_vlm.apply_chat_template(chat_template, add_generation_prompt=True)
    inputs = processor_vlm(text=prompt, images=[image], return_tensors="pt")
    # Put the inputs on the device the model weights were loaded on instead
    # of assuming it equals DEVICE.
    inputs = inputs.to(model_vlm.device)
    generated_ids = model_vlm.generate(**inputs, max_new_tokens=500)
    # For this decoder-only model, generate() returns the prompt tokens
    # followed by the continuation; decode only the continuation so the
    # echoed prompt is not shown as part of the answer.
    prompt_length = inputs["input_ids"].shape[1]
    generated_texts = processor_vlm.batch_decode(
        generated_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return generated_texts[0].strip()


if uploaded_pdf and query:
    docs_retrieved = model_embedding.search(query, k=1)
    logging.info("%d number of images retrieved as relevant to query", len(docs_retrieved))
    if not docs_retrieved:
        st.error("No relevant page was found for this question.")
    else:
        image_id = docs_retrieved[0]["doc_id"]
        logging.info("Image id:%s retrieved", image_id)
        # NOTE(review): doc_id is used as a position in `images` (PDF page
        # order). That assumes the retriever numbered the saved
        # image_<n>.png files in page order when indexing the folder --
        # confirm against byaldi's folder-indexing order.
        if 0 <= image_id < len(images):
            st.write(_generate_answer(query, images[image_id]))
        else:
            # e.g. the page images are not available in this run; report it
            # instead of raising IndexError.
            logging.warning(
                "Retrieved image id %s is outside the %d available page images",
                image_id,
                len(images),
            )
            st.error("The retrieved page is not available. Please re-upload the PDF.")