import streamlit as st
import torch
from PIL import Image
import tempfile
import os
import time
import json
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

class Qwen2Wrapper:
    """Minimal Byaldi-style wrapper around Qwen2-VL for OCR-style text extraction."""

    def __init__(self, model_name="Qwen/Qwen2-VL-7B-Instruct", device="cpu"):
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.float32, device_map=device
        )
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.device = device
        self.index = {}

    def index_image(self, image_path, index_name, overwrite=True):
        if index_name in self.index and not overwrite:
            raise ValueError(f"Index {index_name} already exists. Use overwrite=True to replace.")
        self.index[index_name] = {"image": Image.open(image_path), "extracted_text": None}

    def search(self, query, index_name, k=1):
        if index_name not in self.index:
            raise ValueError(f"Index {index_name} does not exist.")
        image = self.index[index_name]["image"]
        # Lazily extract text the first time the index is queried, then cache it.
        if self.index[index_name]["extracted_text"] is None:
            self.index[index_name]["extracted_text"] = self._extract_text(image)
        return [{"metadata": {"ocr_text": self.index[index_name]["extracted_text"]}}]

    def _extract_text(self, image):
        conversation_history = []

        def prompt_text(query):
            conversation_history.append({
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": query},
                ],
            })
            messages = conversation_history[:]
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, _ = process_vision_info(messages)
            inputs = self.processor(text=[text], images=image_inputs, padding=True, return_tensors="pt")
            inputs = inputs.to(self.device)
            generated_ids = self.model.generate(**inputs, max_new_tokens=128)
            # Strip the prompt tokens so only the newly generated text is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            # Record the model's reply as an assistant turn so the history stays well-formed.
            conversation_history.append({
                "role": "assistant",
                "content": [{"type": "text", "text": output_text[0]}],
            })
            return output_text[0]

        return prompt_text("give me just the text extracted")
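
# A minimal usage sketch of the wrapper outside Streamlit, assuming a local image
# file "sample.jpg" (the file name and index name below are illustrative only):
#
#   wrapper = Qwen2Wrapper(device="cpu")
#   wrapper.index_image("sample.jpg", "demo_index")
#   results = wrapper.search("Extract all text from the image", "demo_index")
#   print(results[0]["metadata"]["ocr_text"])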

# Function to load the Qwen2 model once and cache it across Streamlit reruns
@st.cache_resource
def load_qwen2_model():
    return Qwen2Wrapper(device="cpu")
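
# Note (assumption): this Space appears to run on CPU, so the wrapper pins
# torch.float32 and device="cpu". On a GPU host the same cached loader could
# instead pass a CUDA device (forwarded to device_map), e.g.:
#
#   return Qwen2Wrapper(device="cuda")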

# Streamlit interface
st.title("OCR with Qwen2 (Byaldi-style implementation)")
st.write("Upload an image for OCR processing (supports Hindi and English text)")

# Image uploader
image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if image:
    img = Image.open(image)
    st.image(img, caption="Uploaded Image", use_column_width=True)

    # OCR extraction
    st.write("Extracting text from image using Qwen2...")
    qwen2_model = load_qwen2_model()

    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        # Convert to RGB first so PNG uploads with an alpha channel can be saved as JPEG.
        img.convert("RGB").save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Create a unique index name
    unique_index_name = f"temp_index_{int(time.time())}"

    # Index the image
    qwen2_model.index_image(temp_file_path, unique_index_name)

    # Perform search (which triggers text extraction)
    ocr_results = qwen2_model.search("Extract all text from the image", unique_index_name)
    extracted_text = ocr_results[0]["metadata"]["ocr_text"]

    # Remove the temporary file
    os.unlink(temp_file_path)

    # Display results
    st.subheader("Qwen2 OCR Result:")
    st.text(extracted_text)
    st.json(json.dumps({"extracted_text": extracted_text}, ensure_ascii=False, indent=2))

    # Keyword search
    st.subheader("Search in Extracted Text")
    keywords = st.text_input("Enter keywords to search (separate multiple keywords with commas)")

    if keywords:
        search_keywords = [k.strip() for k in keywords.split(',')]

        def search_text(text, keywords):
            # Return every word that contains any of the keywords (case-insensitive).
            words = text.split()
            results = [word for word in words if any(keyword.lower() in word.lower() for keyword in keywords)]
            return results

        search_results = search_text(extracted_text, search_keywords)
        st.write("Qwen2 Search Results:")
        st.write(", ".join(search_results) if search_results else "No matches found.")