Gizachew committed on
Commit 84982b3 · verified · 1 Parent(s): 9d93df6

Update app.py

Files changed (1)
  1. app.py +45 -176
app.py CHANGED
@@ -1,201 +1,70 @@
  # app.py
  
- import os
+ import gradio as gr
  import torch
- import torch.nn as nn
- from torchvision import transforms
  from PIL import Image
- import numpy as np
- import gradio as gr
- import timm
- import matplotlib.pyplot as plt
- import matplotlib.patches as patches
-
- # Optional: If integrating OCR
- # import pytesseract
-
- # Define the Detection Model Architecture
- class ViTDetectionModel(nn.Module):
-     def __init__(self, num_queries=100, hidden_dim=768):
-         """
-         Initializes the ViTDetectionModel.
-
-         Args:
-             num_queries (int, optional): Number of detection queries. Defaults to 100.
-             hidden_dim (int, optional): Hidden dimension size. Defaults to 768.
-         """
-         super(ViTDetectionModel, self).__init__()
-         # Configure the ViT model to output features only
-         self.vit = timm.create_model(
-             'vit_base_patch16_224',
-             pretrained=False,    # Set to False since we are loading a trained model
-             num_classes=0,       # Disable classification head
-             features_only=True,  # Return feature maps
-             out_indices=(11,)    # Get the last feature map
-         )
-         self.query_embed = nn.Embedding(num_queries, hidden_dim)
-         self.fc_bbox = nn.Linear(hidden_dim, 8)   # 4 points (x, y) for quadrilateral
-         self.fc_class = nn.Linear(hidden_dim, 1)  # Binary classification
-
-     def forward(self, x):
-         """
-         Forward pass of the detection model.
-
-         Args:
-             x (Tensor): Input images [batch, 3, H, W].
-
-         Returns:
-             Tuple[Tensor, Tensor]: Predicted bounding boxes and class scores.
-         """
-         # Retrieve the feature map
-         features = self.vit(x)[0]  # [batch, hidden_dim, H*W]
-
-         if features.dim() == 3:
-             batch_size, hidden_dim, num_patches = features.shape
-             grid_size = int(np.sqrt(num_patches))
-             if grid_size * grid_size != num_patches:
-                 raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
-             H, W = grid_size, grid_size
-             features = features.view(batch_size, hidden_dim, H, W)
-         elif features.dim() == 4:
-             batch_size, hidden_dim, H, W = features.shape
-         else:
-             raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")
-
-         # Flatten the spatial dimensions
-         features = features.flatten(2).transpose(1, 2)  # [batch, H*W, hidden_dim]
+ from model import load_model
+ from utils import preprocess_image, decode_predictions
+ import os
  
-         # Prepare query embeddings
-         queries = self.query_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch, num_queries, hidden_dim]
+ # Load the model (ensure the path is correct)
+ MODEL_PATH = "saved_models/finetuned/finetuned_recog_model.pth"
+ FONT_PATH = "fonts/NotoSansEthiopic-Regular.ttf"  # Update the path to your font
  
-         # Compute attention weights
-         attn = torch.matmul(features, queries.transpose(-1, -2))  # [batch, H*W, num_queries]
-         attn = torch.softmax(attn, dim=1)  # Softmax over patches
+ # Check if model file exists
+ if not os.path.exists(MODEL_PATH):
+     raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. Please provide the correct path.")
  
-         # Aggregate features based on attention
-         output = torch.matmul(attn.transpose(-1, -2), features)  # [batch, num_queries, hidden_dim]
+ # Check if font file exists
+ if not os.path.exists(FONT_PATH):
+     raise FileNotFoundError(f"Font file not found at {FONT_PATH}. Please provide the correct path.")
  
-         # Predict bounding boxes and classes
-         bboxes = self.fc_bbox(output)    # [batch, num_queries, 8]
-         classes = self.fc_class(output)  # [batch, num_queries, 1]
+ # Load the model
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = load_model(MODEL_PATH, device=device)
  
-         return bboxes, classes
+ # Load the font for rendering Amharic text
+ from matplotlib import font_manager as fm
+ import matplotlib.pyplot as plt
  
- # Function to Load the Trained Model
- def load_model(model_path, device):
-     """
-     Loads the trained detection model.
-
-     Args:
-         model_path (str): Path to the saved model state dictionary.
-         device (torch.device): Device to load the model on.
-
-     Returns:
-         nn.Module: Loaded detection model.
-     """
-     model = ViTDetectionModel(num_queries=100, hidden_dim=768).to(device)
-     model.load_state_dict(torch.load(model_path, map_location=device))
-     model.eval()
-     return model
+ ethiopic_font = fm.FontProperties(fname=FONT_PATH, size=15)
  
- # Function to Perform Text Detection on an Image
- def detect_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
+ def recognize_text(image: Image.Image):
      """
-     Detects text in the input image using the detection model.
-
-     Args:
-         image (PIL Image): Input image.
-         model (nn.Module): Trained detection model.
-         device (torch.device): Device to run the model on.
-         max_boxes (int, optional): Maximum number of bounding boxes to return. Defaults to 100.
-         confidence_threshold (float, optional): Threshold to filter detections. Defaults to 0.5.
-
-     Returns:
-         PIL Image: Image with detected bounding boxes drawn.
+     Function to recognize text from an image.
      """
-     # Define transformation
-     transform = transforms.Compose([
-         transforms.Resize((224, 224)),
-         transforms.ToTensor(),
-     ])
-
      # Preprocess the image
-     input_tensor = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]
+     input_tensor = preprocess_image(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]
  
-     # Perform detection
+     # Perform inference
      with torch.no_grad():
-         pred_bboxes, pred_classes = model(input_tensor)  # [1, num_queries, 8], [1, num_queries, 1]
-
-     # Process predictions
-     pred_bboxes = pred_bboxes.squeeze(0)    # [num_queries, 8]
-     pred_classes = pred_classes.squeeze(0)  # [num_queries, 1]
-     pred_classes_sigmoid = torch.sigmoid(pred_classes)
-     high_conf_indices = (pred_classes_sigmoid > confidence_threshold).squeeze(1).nonzero(as_tuple=False).squeeze(1)
-     selected_indices = high_conf_indices[:max_boxes]
-     selected_bboxes = pred_bboxes[selected_indices]  # [selected, 8]
-
-     # Denormalize bounding boxes to original image size
-     width, height = image.size
-     scale_x = width / 224
-     scale_y = height / 224
-     boxes = selected_bboxes.cpu().numpy() * np.array([scale_x, scale_y] * 4)  # [selected, 8]
-
-     # Draw bounding boxes on the image
-     fig, ax = plt.subplots(1, figsize=(12, 12))
-     ax.imshow(image)
-
-     for box in boxes:
-         polygon = patches.Polygon(box.reshape(-1, 2), linewidth=2, edgecolor='r', facecolor='none')
-         ax.add_patch(polygon)
+         log_probs = model(input_tensor)  # [H*W, 1, vocab_size]
  
-     plt.axis('off')
-     # Convert Matplotlib figure to PIL Image
-     fig.canvas.draw()
-     img_with_boxes = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
-     plt.close(fig)
+     # Decode predictions
+     recognized_texts = decode_predictions(log_probs)
  
-     return img_with_boxes
-
- # Optional: If integrating OCR with pytesseract
- # def detect_and_recognize_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
- #     # Similar to detect_text but includes OCR steps
- #     pass
-
- # Initialize the model
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- model_path = "finetuned_recog_model.pth"  # Ensure this path matches where the model is stored
- model = load_model(model_path, device)
- print("Model loaded successfully.")
+     return recognized_texts[0]
  
- # Define the Gradio Interface Function
- def gradio_detect(image):
+ def display_image_with_text(image: Image.Image, recognized_text: str):
      """
-     Gradio interface function for text detection.
-
-     Args:
-         image (PIL Image): Uploaded image.
-
-     Returns:
-         PIL Image: Image with detected bounding boxes.
+     Function to display the image with recognized text.
      """
-     result_image = detect_text(image, model, device)
-     return result_image
+     plt.figure(figsize=(6,6))
+     plt.imshow(image)
+     plt.axis('off')
+     plt.title(f"Recognized Text: {recognized_text}", fontproperties=ethiopic_font)
+     plt.show()
+     return plt
  
- # Create Gradio Interface
+ # Define Gradio Interface
  iface = gr.Interface(
-     fn=gradio_detect,
-     inputs=gr.Image(type="pil"),
-     outputs=gr.Image(type="pil"),
-     title="Text Detection with ViT",
-     description="Upload an image, and the model will detect and highlight text regions.",
-     examples=[
-         # You can add URLs or paths to example images here
-         # "https://example.com/image1.jpg",
-         # "https://example.com/image2.jpg",
-     ],
-     allow_flagging="never"
+     fn=recognize_text,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Textbox(),
+     title="Amharic Text Recognition",
+     description="Upload an image containing Amharic text, and the model will recognize and display the text."
  )
  
- # Launch the Gradio App (Optional for local testing)
- # if __name__ == "__main__":
- #     iface.launch()
+ # Launch the Gradio app
+ if __name__ == "__main__":
+     iface.launch()