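# Streamlit demo: OCR for Hindi and English images with Qwen2-VL, wrapped in a
# byaldi-style index/search interface (see Qwen2Wrapper below).
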
import streamlit as st
import torch
from PIL import Image
import tempfile
import os
import time
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

class Qwen2Wrapper:
    """Byaldi-style index/search wrapper around Qwen2-VL for single-image OCR."""

    def __init__(self, model_name="Qwen/Qwen2-VL-7B-Instruct", device="cpu"):
        # float32 keeps CPU inference numerically safe; use float16/bfloat16 on GPU.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.float32, device_map=device
        )
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.device = device
        self.index = {}  # index_name -> {"image": PIL.Image, "extracted_text": str or None}

    def index_image(self, image_path, index_name, overwrite=True):
        if index_name in self.index and not overwrite:
            raise ValueError(f"Index {index_name} already exists. Use overwrite=True to replace.")

        # Defer OCR until the first search() call against this index.
        self.index[index_name] = {"image": Image.open(image_path), "extracted_text": None}

    def search(self, query, index_name, k=1):
        # query and k mirror byaldi's API; extraction always uses the fixed
        # prompt in _extract_text, and one image yields one result.
        if index_name not in self.index:
            raise ValueError(f"Index {index_name} does not exist.")

        image = self.index[index_name]["image"]

        # Run OCR lazily and cache the result on the index entry.
        if self.index[index_name]["extracted_text"] is None:
            self.index[index_name]["extracted_text"] = self._extract_text(image)

        return [{"metadata": {"ocr_text": self.index[index_name]["extracted_text"]}}]

    def _extract_text(self, image):
        # Kept as a (single-turn) conversation so follow-up prompts could
        # reuse the same history if needed.
        conversation_history = []

        def prompt_text(query):
            conversation_history.append({
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": query},
                ],
            })

            messages = conversation_history[:]
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, _ = process_vision_info(messages)
            inputs = self.processor(text=[text], images=image_inputs, padding=True, return_tensors="pt")
            inputs = inputs.to(self.device)

            generated_ids = self.model.generate(**inputs, max_new_tokens=128)  # raise for text-dense images
            # Strip the prompt tokens so only the newly generated text is decoded.
            generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
            output_text = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)

            # Record the reply as an "assistant" turn (not "system") with the
            # same list-of-parts content structure the chat template expects.
            conversation_history.append({
                "role": "assistant",
                "content": [{"type": "text", "text": output_text[0]}],
            })

            return output_text[0]

        return prompt_text("Extract all the text from this image. Return only the extracted text.")
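
# A minimal sketch of driving Qwen2Wrapper directly, outside Streamlit;
# "sample.jpg" is a hypothetical path used only for illustration:
#
#     wrapper = Qwen2Wrapper(device="cpu")
#     wrapper.index_image("sample.jpg", "demo")
#     results = wrapper.search("Extract all text from the image", "demo")
#     print(results[0]["metadata"]["ocr_text"])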

# Load the Qwen2 wrapper once and cache it across Streamlit reruns
@st.cache_resource
def load_qwen2_model():
    return Qwen2Wrapper(device="cpu")

# Streamlit Interface
st.title("OCR with Qwen2 (Byaldi-style implementation)")
st.write("Upload an image for OCR processing (supports Hindi and English text)")

# Image uploader
image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if image:
    img = Image.open(image)
    st.image(img, caption="Uploaded Image", use_column_width=True)

    # OCR Extraction
    st.write("Extracting text from image using Qwen2...")
    
    qwen2_model = load_qwen2_model()
    
    # Persist the upload to disk so index_image() can take a file path.
    # JPEG has no alpha channel, so convert RGBA/P uploads (e.g. PNGs) to RGB first.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        img.convert("RGB").save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Create a unique index name
    unique_index_name = f"temp_index_{int(time.time())}"
    
    # Index the image
    qwen2_model.index_image(temp_file_path, unique_index_name)
    
    # Perform search (which triggers text extraction)
    ocr_results = qwen2_model.search("Extract all text from the image", unique_index_name)
    
    extracted_text = ocr_results[0]["metadata"]["ocr_text"]

    # Remove the temporary file
    os.unlink(temp_file_path)
    
    # Display results
    st.subheader("Qwen2 OCR Result:")
    st.text(extracted_text)
    st.json({"extracted_text": extracted_text})

    # Keyword search
    st.subheader("Search in Extracted Text")
    keywords = st.text_input("Enter keywords to search (separate multiple keywords with commas)")
    if keywords:
        search_keywords = [k.strip() for k in keywords.split(',')]
        
        # Case-insensitive substring match: keep every word that contains any keyword.
        def search_text(text, keywords):
            words = text.split()
            return [word for word in words if any(keyword.lower() in word.lower() for keyword in keywords)]
        
        search_results = search_text(extracted_text, search_keywords)
        
        st.write("Qwen2 Search Results:")
        st.write(", ".join(search_results) if search_results else "No matches found.")
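
# To try this locally (assuming the script is saved as app.py, a name chosen
# here for illustration):
#
#     pip install streamlit torch transformers qwen-vl-utils pillow
#     streamlit run app.py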