import streamlit as st
import torch
from PIL import Image
import tempfile
import os
import time
import json
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
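
# Byaldi-style wrapper around Qwen2-VL (see the app title below): it loosely
# mirrors an index/search workflow, but here "indexing" simply stores the PIL
# image and "search" lazily runs OCR the first time it is called for an index.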
class Qwen2Wrapper:
    def __init__(self, model_name="Qwen/Qwen2-VL-7B-Instruct", device="cpu"):
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.float32, device_map=device
        )
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.device = device
        self.index = {}

    def index_image(self, image_path, index_name, overwrite=True):
        if index_name in self.index and not overwrite:
            raise ValueError(f"Index {index_name} already exists. Use overwrite=True to replace.")
        self.index[index_name] = {"image": Image.open(image_path), "extracted_text": None}

    def search(self, query, index_name, k=1):
        if index_name not in self.index:
            raise ValueError(f"Index {index_name} does not exist.")
        image = self.index[index_name]["image"]
        if self.index[index_name]["extracted_text"] is None:
            self.index[index_name]["extracted_text"] = self._extract_text(image)
        return [{"metadata": {"ocr_text": self.index[index_name]["extracted_text"]}}]
    def _extract_text(self, image):
        conversation_history = []

        def prompt_text(query):
            conversation_history.append({
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": query},
                ],
            })
            messages = conversation_history[:]
            # Build the chat prompt and the vision inputs expected by Qwen2-VL
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, _ = process_vision_info(messages)
            inputs = self.processor(text=[text], images=image_inputs, padding=True, return_tensors="pt")
            inputs = inputs.to(self.device)
            generated_ids = self.model.generate(**inputs, max_new_tokens=128)
            # Strip the prompt tokens so only the newly generated text is decoded
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            # Record the model reply so the history stays consistent if reused
            conversation_history.append({
                "role": "assistant",
                "content": [{"type": "text", "text": output_text[0]}],
            })
            return output_text[0]

        return prompt_text("Give me just the text extracted from the image.")
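
# Minimal usage sketch of the wrapper outside Streamlit (the image path below
# is hypothetical):
#   model = Qwen2Wrapper(device="cpu")
#   model.index_image("sample.jpg", "demo")
#   result = model.search("Extract all text from the image", "demo")
#   print(result[0]["metadata"]["ocr_text"])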

# Function to load Qwen2 model
@st.cache_resource
def load_qwen2_model():
    return Qwen2Wrapper(device="cpu")
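# Note: st.cache_resource above keeps the Qwen2-VL weights in memory across
# Streamlit reruns, so the model is only loaded once per session.
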
# Streamlit Interface
st.title("OCR with Qwen2 (Byaldi-style implementation)")
st.write("Upload an image for OCR processing (supports Hindi and English text)")
# Image uploader
image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
if image:
    img = Image.open(image)
    st.image(img, caption="Uploaded Image", use_column_width=True)

    # OCR Extraction
    st.write("Extracting text from image using Qwen2...")
    qwen2_model = load_qwen2_model()

    # Save the upload to a temporary JPEG; convert to RGB first so PNGs with
    # an alpha channel can still be written as JPEG
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        img.convert("RGB").save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Create a unique index name
    unique_index_name = f"temp_index_{int(time.time())}"

    # Index the image
    qwen2_model.index_image(temp_file_path, unique_index_name)

    # Perform search (which triggers text extraction)
    ocr_results = qwen2_model.search("Extract all text from the image", unique_index_name)
    extracted_text = ocr_results[0]["metadata"]["ocr_text"]

    # Remove the temporary file
    os.unlink(temp_file_path)

    # Display results
    st.subheader("Qwen2 OCR Result:")
    st.text(extracted_text)
    st.json(json.dumps({"extracted_text": extracted_text}, ensure_ascii=False, indent=2))
    # Keyword search
    st.subheader("Search in Extracted Text")
    keywords = st.text_input("Enter keywords to search (separate multiple keywords with commas)")

    if keywords:
        search_keywords = [k.strip() for k in keywords.split(',')]

        def search_text(text, keywords):
            # Case-insensitive substring match over whitespace-separated words,
            # e.g. the keyword "total" matches the word "Total:"
            words = text.split()
            return [word for word in words if any(keyword.lower() in word.lower() for keyword in keywords)]

        search_results = search_text(extracted_text, search_keywords)
        st.write("Qwen2 Search Results:")
        st.write(", ".join(search_results) if search_results else "No matches found.")