abdullahalmunem committed
Commit f81cfe2 · 1 Parent(s): ad3fa03

model added

Dockerfile ADDED
@@ -0,0 +1,18 @@
+ FROM pytorch/pytorch:latest
+
+ WORKDIR /app
+
+ # Install dependencies
+ RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt ./
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy your source code
+ COPY . .
+
+ # Expose port 7860 (Hugging Face Spaces default)
+ EXPOSE 7860
+
+ # Run both API and Gradio
+ CMD ["bash", "entrypoint.sh"]
api_onnx.py ADDED
@@ -0,0 +1,85 @@
+ import os
+ import re
+ from fastapi import FastAPI, Request
+ from pydantic import BaseModel
+ from inference_onnx import get_transcription
+ import torch
+ import onnxruntime as ort
+ from config import *
+ from contextlib import asynccontextmanager
+
+ # Global session object (attached to app.state)
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     print("🔧 Loading model...")
+
+     app.state.device = torch.device('cpu')
+     app.state.tokenizer = MODELS["./distilbert-base-multilingual-cased"][1].from_pretrained("./distilbert-base-multilingual-cased")
+     app.state.token_style = MODELS["./distilbert-base-multilingual-cased"][3]
+
+     onnx_model_path = "./poc_onnx_model_punctuation_batch.onnx"
+     providers = ['CPUExecutionProvider']
+     # providers = ["CUDAExecutionProvider"]
+     # providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+     sess_options = ort.SessionOptions()
+     app.state.session = ort.InferenceSession(onnx_model_path, providers=providers)
+
+     print("✅ ONNX model loaded into memory.")
+     yield
+     print("🧹 Shutting down...")
+
+ app = FastAPI(lifespan=lifespan)
+
+ punc_dict = {
+     '!': 'EXCLAMATION',
+     '?': 'QUESTION',
+     ',': 'COMMA',
+     ';': 'SEMICOLON',
+     ':': 'COLON',
+     '-': 'HYPHEN',
+     '।': 'DARI',
+ }
+ allowed_punctuations = set(punc_dict.keys())
+
+ def clean_and_normalize_text(text, remove_punctuations=False):
+     """Clean and normalize Bangla text with correct spacing"""
+     if remove_punctuations:
+         # Remove all allowed punctuations
+         cleaned_text = re.sub(f"[{re.escape(''.join(allowed_punctuations))}]", "", text)
+         # Normalize spaces
+         cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
+         return cleaned_text
+     else:
+         # Keep only allowed punctuations and Bangla letters/digits
+         chunks = re.split(f"([{re.escape(''.join(allowed_punctuations))}])", text)
+         filtered_chunks = []
+
+         for chunk in chunks:
+             if chunk in allowed_punctuations:
+                 filtered_chunks.append(chunk)
+             else:
+                 # Clean text and preserve word boundaries
+                 clean_chunk = re.sub(rf"[^\u0980-\u09FF\u09E6-\u09EF\s]", "", chunk)
+                 clean_chunk = re.sub(r'\s+', ' ', clean_chunk)  # Normalize internal spacing
+                 clean_chunk = clean_chunk.strip()
+                 if clean_chunk:
+                     filtered_chunks.append(' ' + clean_chunk)  # Add space before word chunks
+
+         # Join and clean up spacing
+         result = ''.join(filtered_chunks)
+         result = re.sub(r'\s+', ' ', result).strip()
+         return result
+
+ class TextInput(BaseModel):
+     text: str
+
+ @app.post("/punctuate")
+ async def punctuate_text(data: TextInput):
+     input_normalized = clean_and_normalize_text(data.text)
+     input_normalized = clean_and_normalize_text(input_normalized, remove_punctuations=True)
+     restored_text = get_transcription(input_normalized, app.state.session, app.state.tokenizer, app.state.device, app.state.token_style)
+     return {"restored_text": restored_text}
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("api_onnx:app", host="0.0.0.0", port=5685, workers=1)
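The /punctuate endpoint accepts a JSON body with a "text" field and responds with "restored_text". A minimal client sketch (an illustration, not part of this commit), assuming the API is reachable on localhost:5685 as configured above:

import requests

resp = requests.post(
    "http://localhost:5685/punctuate",
    json={"text": "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২ এবং নারীর সংখ্যা ৪৮ শহরের সাক্ষরতার হার কত"},
)
resp.raise_for_status()
# restored_text is a list with one punctuated string, because get_transcription batches internally
print(resp.json()["restored_text"])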
app.py ADDED
@@ -0,0 +1,406 @@
+ import os
+ import gradio as gr
+ import requests
+ import re
+ import time
+ import pandas as pd
+ from typing import Dict, Tuple, List, Optional
+
+ # Configuration
+ API_URL = "http://localhost:5685/punctuate"
+
+
+ punc_dict = {
+     '!': 'EXCLAMATION',
+     '?': 'QUESTION',
+     ',': 'COMMA',
+     ';': 'SEMICOLON',
+     ':': 'COLON',
+     '-': 'HYPHEN',
+     '।': 'DARI',
+ }
+
+ allowed_punctuations = set(punc_dict.keys())
+
+ def clean_and_normalize_text(text, remove_punctuations=False):
+     """Clean and normalize Bangla text with correct spacing"""
+     if remove_punctuations:
+         # Remove all allowed punctuations
+         cleaned_text = re.sub(f"[{re.escape(''.join(allowed_punctuations))}]", "", text)
+         # Normalize spaces
+         cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
+         return cleaned_text
+     else:
+         # Keep only allowed punctuations and Bangla letters/digits
+         chunks = re.split(f"([{re.escape(''.join(allowed_punctuations))}])", text)
+         filtered_chunks = []
+
+         for chunk in chunks:
+             if chunk in allowed_punctuations:
+                 filtered_chunks.append(chunk)
+             else:
+                 # Clean text and preserve word boundaries
+                 clean_chunk = re.sub(rf"[^\u0980-\u09FF\u09E6-\u09EF\s]", "", chunk)
+                 clean_chunk = re.sub(r'\s+', ' ', clean_chunk)  # Normalize internal spacing
+                 clean_chunk = clean_chunk.strip()
+                 if clean_chunk:
+                     filtered_chunks.append(' ' + clean_chunk)  # Add space before word chunks
+
+         # Join and clean up spacing
+         result = ''.join(filtered_chunks)
+         result = re.sub(r'\s+', ' ', result).strip()
+         return result
+
+ def restore_punctuation(text):
+     """Call the punctuation restoration API"""
+     try:
+         payload = {"text": text}
+         start_time = time.time()
+         response = requests.post(API_URL, json=payload)
+         end_time = time.time()
+
+         api_time = end_time - start_time
+
+         if response.status_code == 200:
+             restored_text = response.json().get("restored_text")
+             return restored_text, api_time
+         else:
+             return f"API Error: {response.status_code} - {response.text}", api_time
+     except Exception as e:
+         return f"Connection Error: {str(e)}", 0.0
+
+ def dummy_restore_punctuation(text):
+     """Dummy API call for demonstration when real API is not available"""
+     time.sleep(0.5)  # Simulate API delay
+
+     # Simple dummy logic - add some punctuations randomly for demo
+     words = text.split()
+     if len(words) > 5:
+         words[2] = words[2] + ','
+         words[-1] = words[-1] + '?'
+     elif len(words) > 2:
+         words[-1] = words[-1] + '!'
+
+     return ' '.join(words), 0.5
+
+ def tokenize_with_punctuation(text):
+     """Tokenize text keeping punctuation separate using chunk-based approach"""
+     tokens = []
+     chunks = re.split(f"([{re.escape(''.join(allowed_punctuations))}])", text)
+
+     for chunk in chunks:
+         if not chunk.strip():
+             continue
+
+         if chunk in allowed_punctuations:
+             # This is a punctuation
+             tokens.append(chunk)
+         else:
+             # This is text, split into words
+             words = chunk.strip().split()
+             for word in words:
+                 if word.strip():
+                     tokens.append(word.strip())
+
+     return tokens
+
+ def compare_texts(ground_truth, predicted):
+     """Compare ground truth and predicted text token by token with proper alignment"""
+     gt_tokens = tokenize_with_punctuation(ground_truth)
+     pred_tokens = tokenize_with_punctuation(predicted)
+
+     comparison_result = []
+     correct_puncs = {}
+     wrong_puncs = {}
+     gt_punc_counts = {}
+
+     # Count punctuations in ground truth
+     for token in gt_tokens:
+         if token in allowed_punctuations:
+             punc_name = punc_dict[token]
+             gt_punc_counts[punc_name] = gt_punc_counts.get(punc_name, 0) + 1
+
+     # Separate words and punctuations for better alignment
+     gt_words = [token for token in gt_tokens if token not in allowed_punctuations]
+     pred_words = [token for token in pred_tokens if token not in allowed_punctuations]
+
+     # Create position maps for punctuations
+     gt_punct_map = {}  # word_index -> [punctuations after this word]
+     pred_punct_map = {}  # word_index -> [punctuations after this word]
+
+     # Build ground truth punctuation map
+     word_idx = -1
+     for i, token in enumerate(gt_tokens):
+         if token not in allowed_punctuations:
+             word_idx += 1
+         else:
+             if word_idx not in gt_punct_map:
+                 gt_punct_map[word_idx] = []
+             gt_punct_map[word_idx].append(token)
+
+     # Build predicted punctuation map
+     word_idx = -1
+     for i, token in enumerate(pred_tokens):
+         if token not in allowed_punctuations:
+             word_idx += 1
+         else:
+             if word_idx not in pred_punct_map:
+                 pred_punct_map[word_idx] = []
+             pred_punct_map[word_idx].append(token)
+
+     # Compare words and punctuations
+     max_words = max(len(gt_words), len(pred_words))
+
+     for i in range(max_words):
+         # Add word
+         if i < len(gt_words) and i < len(pred_words):
+             if gt_words[i] == pred_words[i]:
+                 comparison_result.append((gt_words[i], "correct", "black"))
+             else:
+                 comparison_result.append((f"{gt_words[i]}→{pred_words[i]}", "word_diff", "orange"))
+         elif i < len(gt_words):
+             comparison_result.append((f"{gt_words[i]}→''", "missing_word", "red"))
+         elif i < len(pred_words):
+             comparison_result.append((f"''→{pred_words[i]}", "extra_word", "red"))
+
+         # Compare punctuations after this word
+         gt_puncs = gt_punct_map.get(i, [])
+         pred_puncs = pred_punct_map.get(i, [])
+
+         # Handle punctuation comparison
+         max_puncs = max(len(gt_puncs), len(pred_puncs))
+
+         for j in range(max_puncs):
+             if j < len(gt_puncs) and j < len(pred_puncs):
+                 gt_punc = gt_puncs[j]
+                 pred_punc = pred_puncs[j]
+
+                 if gt_punc == pred_punc:
+                     punc_name = punc_dict[gt_punc]
+                     correct_puncs[punc_name] = correct_puncs.get(punc_name, 0) + 1
+                     comparison_result.append((gt_punc, "correct", "green"))
+                 else:
+                     # Wrong punctuation
+                     punc_name = punc_dict[gt_punc]
+                     wrong_puncs[punc_name] = wrong_puncs.get(punc_name, 0) + 1
+                     comparison_result.append((f"{gt_punc}→{pred_punc}", "wrong_punct", "red"))
+
+             elif j < len(gt_puncs):
+                 # Missing punctuation
+                 gt_punc = gt_puncs[j]
+                 punc_name = punc_dict[gt_punc]
+                 wrong_puncs[punc_name] = wrong_puncs.get(punc_name, 0) + 1
+                 comparison_result.append((f"{gt_punc}→''", "missing_punct", "red"))
+
+             elif j < len(pred_puncs):
+                 # Extra punctuation (not counted in wrong_puncs since it's not in GT)
+                 pred_punc = pred_puncs[j]
+                 comparison_result.append((f"''→{pred_punc}", "extra_punct", "red"))
+
+     return comparison_result, correct_puncs, wrong_puncs, gt_punc_counts
+
+ def create_evaluation_table(correct_puncs, wrong_puncs, gt_punc_counts):
+     """Create evaluation table"""
+     table_data = []
+
+     for punc_name in gt_punc_counts.keys():
+         correct_count = correct_puncs.get(punc_name, 0)
+         wrong_count = wrong_puncs.get(punc_name, 0)
+         total_count = gt_punc_counts[punc_name]
+
+         table_data.append([
+             punc_name,
+             correct_count,
+             wrong_count,
+             total_count
+         ])
+
+     df = pd.DataFrame(table_data, columns=[
+         "Punctuation Name",
+         "Correctly Classified",
+         "Wrongly Classified",
+         "Count in Ground Truth"
+     ])
+
+     return df
+
+ def format_comparison_html(comparison_result):
+     """Format comparison result as HTML with improved display"""
+     html = "<div style='font-family: monospace; font-size: 16px; line-height: 1.8; padding: 20px; border: 1px solid #ddd; border-radius: 5px;'>"
+
+     for token, status, color in comparison_result:
+         if status == "correct" and color == "green":
+             # Correct punctuation
+             html += f"<span style='background-color: #d4edda; color: #155724; padding: 2px 4px; margin: 1px; border-radius: 3px; font-weight: bold;'>{token}</span>"
+         elif color == "red":
+             # Incorrect, missing, or extra punctuation/word
+             if "→''" in token:
+                 # Missing punctuation or word
+                 missing_item = token.split("→")[0]
+                 html += f"<span style='background-color: #f8d7da; color: #721c24; padding: 2px 4px; margin: 1px; border-radius: 3px; font-weight: bold;'>{missing_item}→∅</span>"
+             elif "''→" in token:
+                 # Extra punctuation or word
+                 extra_item = token.split("→")[1]
+                 html += f"<span style='background-color: #f8d7da; color: #721c24; padding: 2px 4px; margin: 1px; border-radius: 3px; font-weight: bold;'>∅→{extra_item}</span>"
+             else:
+                 # Wrong punctuation/word
+                 html += f"<span style='background-color: #f8d7da; color: #721c24; padding: 2px 4px; margin: 1px; border-radius: 3px; font-weight: bold;'>{token}</span>"
+         elif color == "orange":
+             # Word difference
+             html += f"<span style='background-color: #fff3cd; color: #856404; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token}</span>"
+         else:
+             # Correct word
+             html += f"<span style='padding: 2px 4px; margin: 1px;'>{token}</span>"
+
+         # Add space after each token
+         html += " "
+
+     html += "</div>"
+
+     # Add legend
+     html += """
+     <div style='margin-top: 15px; padding: 10px; background-color: #f8f9fa; border-radius: 5px; font-size: 14px;'>
+         <strong>Legend:</strong><br>
+         <span style='background-color: #d4edda; color: #155724; padding: 1px 3px; border-radius: 2px; margin: 2px;'>✓</span> Correct punctuation &nbsp;
+         <span style='background-color: #f8d7da; color: #721c24; padding: 1px 3px; border-radius: 2px; margin: 2px;'>✗</span> Wrong/Missing/Extra punctuation &nbsp;
+         <span style='background-color: #fff3cd; color: #856404; padding: 1px 3px; border-radius: 2px; margin: 2px;'>~</span> Word difference &nbsp;
+         <span style='padding: 1px 3px; margin: 2px;'>◦</span> Correct word<br>
+         <strong>∅</strong> = Empty/Missing
+     </div>
+     """
+
+     return html
+
+ def process_punctuation_restoration(input_text, ground_truth=""):
+     """Main processing function"""
+     if not input_text.strip():
+         return "Please enter input text", "", "", None, ""
+
+     # Make API call (using dummy for demonstration)
+     try:
+         # Try real API first
+         predicted_text, api_time = restore_punctuation(input_text)
+         if "Error" in str(predicted_text):
+             # Fall back to dummy API
+             # predicted_text, api_time = dummy_restore_punctuation(input_text)
+             predicted_text, api_time = f"Error : {input_text}", 999999
+     except:
+         # Fall back to dummy API
+         # predicted_text, api_time = dummy_restore_punctuation(input_text)
+         predicted_text, api_time = f"Error : {input_text}", 999999
+
+     time_info = f"API call completed in {api_time:.3f} seconds"
+
+     predicted_text = predicted_text[0] if isinstance(predicted_text, list) else predicted_text
+
+     print(f"input_text: {input_text}", flush=True)
+     print(f"predicted_text: {predicted_text}", flush=True)
+     if not ground_truth.strip():
+         return predicted_text, "", time_info, None, ""
+
+     # Normalize ground truth
+     ground_truth_normalized = clean_and_normalize_text(ground_truth)
+
+     # Compare texts
+     comparison_result, correct_puncs, wrong_puncs, gt_punc_counts = compare_texts(
+         ground_truth_normalized, predicted_text
+     )
+
+     # Create comparison HTML
+     comparison_html = format_comparison_html(comparison_result)
+
+     # Create evaluation table
+     eval_table = create_evaluation_table(correct_puncs, wrong_puncs, gt_punc_counts)
+
+     return predicted_text, comparison_html, time_info, eval_table, f"Normalized Ground Truth: {ground_truth_normalized}"
+
+ # Create Gradio interface
+ def create_interface():
+     with gr.Blocks(title="Punctuation Restoration Evaluator", theme=gr.themes.Soft()) as app:
+         gr.Markdown("# 🔤 Punctuation Restoration Evaluator")
+         gr.Markdown("Enter text to restore punctuation. Optionally provide ground truth for evaluation.")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 input_text = gr.Textbox(
+                     label="Input Text (without punctuation)",
+                     placeholder="পুরুষের সংখ্যা মোট জনসংখ্যার ৫২ এবং নারীর সংখ্যা ৪৮ শহরের সাক্ষরতার হার কত",
+                     lines=4
+                 )
+
+                 ground_truth = gr.Textbox(
+                     label="Ground Truth (optional)",
+                     placeholder="পুরুষের সংখ্যা মোট জনসংখ্যার ৫২, এবং নারীর সংখ্যা ৪৮। শহরের সাক্ষরতার হার কত?",
+                     lines=4
+                 )
+
+                 submit_btn = gr.Button("🚀 Restore Punctuation", variant="primary")
+
+             with gr.Column(scale=2):
+                 api_time = gr.Textbox(label="⏱️ API Response Time", interactive=False)
+
+                 predicted_output = gr.Textbox(
+                     label="📝 Predicted Output",
+                     lines=3,
+                     interactive=False
+                 )
+
+                 normalized_gt = gr.Textbox(
+                     label="📋 Normalized Ground Truth",
+                     lines=2,
+                     interactive=False
+                 )
+
+         comparison_output = gr.HTML(
+             label="🔍 Token-wise Comparison",
+             value="<p>Comparison will appear here after processing with ground truth.</p>"
+         )
+
+         evaluation_table = gr.Dataframe(
+             label="📊 Punctuation Evaluation Metrics",
+             headers=["Punctuation Name", "Correctly Classified", "Wrongly Classified", "Count in Ground Truth"],
+             interactive=False
+         )
+
+         # Legend
+         gr.Markdown("""
+         ### 🎨 Color Legend:
+         - 🟢 **Green**: Correctly predicted punctuation
+         - 🔴 **Red**: Incorrectly predicted, missing, or extra punctuation/word
+         - 🟡 **Orange**: Word-level differences
+         - ⚫ **Black**: Correct words/tokens
+         - **∅**: Empty/Missing (instead of showing word→word or punct→word)
+         """)
+
+         submit_btn.click(
+             fn=process_punctuation_restoration,
+             inputs=[input_text, ground_truth],
+             outputs=[predicted_output, comparison_output, api_time, evaluation_table, normalized_gt]
+         )
+
+         # Example section
+         gr.Markdown("### 📚 Example")
+         gr.Examples(
+             examples=[
+                 [
+                     "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২ এবং নারীর সংখ্যা ৪৮ শহরের সাক্ষরতার হার কত",
+                     "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২, এবং নারীর সংখ্যা ৪৮। শহরের সাক্ষরতার হার কত?"
+                 ],
+                 [
+                     "ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান",
+                     ""
+                 ]
+             ],
+             inputs=[input_text, ground_truth]
+         )
+
+     return app
+
+ if __name__ == "__main__":
+     app = create_interface()
+     app.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         debug=True
+     )
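The evaluation logic above aligns words first and then compares the punctuation attached to each word position; a mark that differs from the ground truth is tallied in wrong_puncs under the ground-truth label. A small illustrative check (hypothetical usage, assuming app.py and its dependencies are importable from the working directory):

from app import compare_texts

gt = "আমি ভাত খাই।"
pred = "আমি ভাত খাই?"
tokens, correct, wrong, gt_counts = compare_texts(gt, pred)
print(wrong)      # {'DARI': 1}: the dari (।) was predicted as a question mark
print(gt_counts)  # {'DARI': 1}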
config.py ADDED
@@ -0,0 +1,75 @@
+ from transformers import *
+
+ # special tokens indices in different models available in transformers
+ TOKEN_IDX = {
+     'bert': {
+         'START_SEQ': 101,
+         'PAD': 0,
+         'END_SEQ': 102,
+         'UNK': 100
+     },
+     'xlm': {
+         'START_SEQ': 0,
+         'PAD': 2,
+         'END_SEQ': 1,
+         'UNK': 3
+     },
+     'roberta': {
+         'START_SEQ': 0,
+         'PAD': 1,
+         'END_SEQ': 2,
+         'UNK': 3
+     },
+     'albert': {
+         'START_SEQ': 2,
+         'PAD': 0,
+         'END_SEQ': 3,
+         'UNK': 1
+     },
+ }
+
+ # 'O' -> No punctuation
+ punctuation_dict = {
+     '0': 0,
+     "DARI": 1,
+     "COMMA": 2,
+     "SEMICOLON": 3,
+     "QUESTION": 4,
+     "EXCLAMATION": 5,
+     "COLON": 6,
+     "HYPHEN": 7,
+ }
+
+ punctuation_map = {
+     0: "",
+     1: '।',  # 'DARI'
+     2: ',',  # 'COMMA'
+     3: ';',  # 'SEMICOLON'
+     4: '?',  # 'QUESTION'
+     5: '!',  # 'EXCLAMATION'
+     6: ':',  # 'COLON'
+     7: '-',  # 'HYPHEN'
+ }
+
+ # pretrained model name: (model class, model tokenizer, output dimension, token style)
+ MODELS = {
+     'bert-base-uncased': (BertModel, BertTokenizer, 768, 'bert'),
+     'bert-large-uncased': (BertModel, BertTokenizer, 1024, 'bert'),
+     'bert-base-multilingual-cased': (BertModel, BertTokenizer, 768, 'bert'),
+     'bert-base-multilingual-uncased': (BertModel, BertTokenizer, 768, 'bert'),
+     'sagorsarker/bangla-bert-base': (BertModel, BertTokenizer, 768, 'bert'),
+     # 'distilbert-base-multilingual-cased': (AutoModelForMaskedLM, AutoTokenizer, 768, 'bert'),
+     'xlm-mlm-en-2048': (XLMModel, XLMTokenizer, 2048, 'xlm'),
+     'xlm-mlm-100-1280': (XLMModel, XLMTokenizer, 1280, 'xlm'),
+     'roberta-base': (RobertaModel, RobertaTokenizer, 768, 'roberta'),
+     'roberta-large': (RobertaModel, RobertaTokenizer, 1024, 'roberta'),
+     'neuralspace-reverie/indic-transformers-bn-roberta': (RobertaModel, RobertaTokenizer, 768, 'roberta'),
+     'distilbert-base-uncased': (DistilBertModel, DistilBertTokenizer, 768, 'bert'),
+     'distilbert-base-multilingual-cased': (DistilBertModel, DistilBertTokenizer, 768, 'bert'),
+     './distilbert-base-multilingual-cased': (DistilBertModel, DistilBertTokenizer, 768, 'bert'),
+     'xlm-roberta-base': (XLMRobertaModel, XLMRobertaTokenizer, 768, 'roberta'),
+     'xlm-roberta-large': (XLMRobertaModel, XLMRobertaTokenizer, 1024, 'roberta'),
+     'albert-base-v1': (AlbertModel, AlbertTokenizer, 768, 'albert'),
+     'albert-base-v2': (AlbertModel, AlbertTokenizer, 768, 'albert'),
+     'albert-large-v2': (AlbertModel, AlbertTokenizer, 1024, 'albert'),
+ }
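Each MODELS entry is a (model class, tokenizer class, hidden size, token style) tuple, and callers index into it to pick the tokenizer and the special-token scheme, as api_onnx.py does for the local DistilBERT folder. A short sketch of that lookup (illustrative only):

from config import MODELS, TOKEN_IDX

model_cls, tokenizer_cls, hidden_dim, token_style = MODELS['./distilbert-base-multilingual-cased']
tokenizer = tokenizer_cls.from_pretrained('./distilbert-base-multilingual-cased')  # local folder shipped in this commit
print(token_style, TOKEN_IDX[token_style]['PAD'])  # bert 0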
distilbert-base-multilingual-cased/config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "activation": "gelu",
+   "architectures": [
+     "DistilBertForMaskedLM"
+   ],
+   "attention_dropout": 0.1,
+   "dim": 768,
+   "dropout": 0.1,
+   "hidden_dim": 3072,
+   "initializer_range": 0.02,
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "output_past": true,
+   "pad_token_id": 0,
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": false,
+   "tie_weights_": true,
+   "vocab_size": 119547
+ }
distilbert-base-multilingual-cased/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
distilbert-base-multilingual-cased/tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"do_lower_case": false, "model_max_length": 512}
distilbert-base-multilingual-cased/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
entrypoint.sh ADDED
@@ -0,0 +1,10 @@
+ #!/bin/bash
+
+ # Start the API (port 5685) in background
+ python api_onnx.py &
+
+ # Wait briefly to make sure API is up
+ sleep 5
+
+ # Start the Gradio UI (on port 7860)
+ python app.py
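The fixed sleep 5 assumes the FastAPI process binds port 5685 within five seconds; if model loading is slower, the first UI request will fail. One alternative is to poll the API before starting the UI. A sketch of such a readiness check (a hypothetical helper, not part of this commit):

import time
import requests

def wait_for_api(url="http://localhost:5685/punctuate", timeout=60):
    """Poll until the punctuation API answers; any HTTP response counts as ready."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            requests.post(url, json={"text": "পরীক্ষা"}, timeout=2)
            return True
        except requests.RequestException:
            time.sleep(1)
    return False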
inference_onnx.py ADDED
@@ -0,0 +1,179 @@
+ import torch
+ import numpy as np
+ from typing import List, Union, Dict, Any
+ from config import *
+
+ def get_encoded_input_single(text, tokenizer, token_style, sequence_len=256):
+     """Process a single text sequence - matches your conversion code logic"""
+     words = text.split()
+     word_pos = 0
+
+     x = [TOKEN_IDX[token_style]['START_SEQ']]
+     y_mask = [0]
+
+     while len(x) < sequence_len and word_pos < len(words):
+         tokens = tokenizer.tokenize(words[word_pos])
+         if len(tokens) + len(x) >= sequence_len:
+             break
+         else:
+             for i in range(len(tokens) - 1):
+                 x.append(tokenizer.convert_tokens_to_ids(tokens[i]))
+                 y_mask.append(0)
+             x.append(tokenizer.convert_tokens_to_ids(tokens[-1]))
+             y_mask.append(1)
+             word_pos += 1
+
+     x.append(TOKEN_IDX[token_style]['END_SEQ'])
+     y_mask.append(0)
+
+     # Pad to sequence_len
+     if len(x) < sequence_len:
+         x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))]
+         y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))]
+
+     attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x]
+
+     return {
+         'input_values': x,
+         'attention_mask': attn_mask,
+         'y_mask': y_mask
+     }
+
+ def get_encoded_input_batch(texts, tokenizer, token_style, sequence_len=256):
+     """Process a batch of text sequences - matches your conversion code logic"""
+     batch_data = []
+
+     for text in texts:
+         encoded = get_encoded_input_single(text, tokenizer, token_style, sequence_len)
+         batch_data.append(encoded)
+
+     # Stack all sequences into batch tensors
+     batch_input_values = torch.tensor([item['input_values'] for item in batch_data])
+     batch_attention_mask = torch.tensor([item['attention_mask'] for item in batch_data])
+     batch_y_mask = torch.tensor([item['y_mask'] for item in batch_data])
+
+     encoded_input = {
+         'input_values': batch_input_values,
+         'attention_mask': batch_attention_mask,
+         'y_mask': batch_y_mask
+     }
+
+     return encoded_input
+
+ def run_onnx_inference(input_values, attention_mask, session):
+     """Run ONNX inference with the unified model"""
+     # Get input/output names
+     input_values_name = session.get_inputs()[0].name
+     attention_mask_name = session.get_inputs()[1].name
+     output_name = session.get_outputs()[0].name
+
+     # Prepare inputs for ONNX (convert to numpy)
+     inputs = {
+         input_values_name: input_values.cpu().numpy(),
+         attention_mask_name: attention_mask.cpu().numpy()
+     }
+
+     # Run inference
+     output = session.run([output_name], inputs)
+     predictions = torch.tensor(output[0])  # Shape: [batch_size, seq_len, num_classes]
+     predictions = torch.argmax(predictions, dim=2)  # Shape: [batch_size, seq_len]
+
+     return predictions
+
+ def get_transcription_batch(texts, session, tokenizer, device, token_style):
+     """Process multiple texts and return punctuated results"""
+
+     # Prepare batch data
+     encoded_batch = get_encoded_input_batch(texts, tokenizer, token_style)
+
+     # Move to device
+     input_values = encoded_batch['input_values'].to(device)
+     attention_mask = encoded_batch['attention_mask'].to(device)
+     y_masks = encoded_batch['y_mask']
+
+     # Run batch inference
+     predictions = run_onnx_inference(input_values, attention_mask, session)
+
+     # Post-process results for each text
+     results = []
+     for text_idx, text in enumerate(texts):
+         words_original_case = text.split()
+         y_mask = y_masks[text_idx]
+         y_predict = predictions[text_idx]
+
+         result = ""
+         decode_idx = 0
+
+         for i in range(y_mask.shape[0]):
+             if y_mask[i] == 1 and decode_idx < len(words_original_case):
+                 result += words_original_case[decode_idx] + punctuation_map[y_predict[i].item()] + ' '
+                 decode_idx += 1
+
+         results.append(result.strip())
+
+     return results
+
+ def get_transcription(text_or_texts, session, tokenizer, device, token_style):
+     """
+     Main function that handles both single text and batch processing
+     Uses the unified ONNX model for both cases
+
+     Args:
+         text_or_texts: Single text string or list of text strings
+
+     Returns:
+         Single punctuated string or list of punctuated strings
+     """
+     if isinstance(text_or_texts, str):
+         return get_transcription_batch([text_or_texts], session, tokenizer, device, token_style)
+     elif isinstance(text_or_texts, list):
+         return get_transcription_batch(text_or_texts, session, tokenizer, device, token_style)
+     else:
+         raise ValueError("Input must be either a string or a list of strings")
+
+
+ if __name__ == '__main__':
+     import time
+     import onnxruntime as ort
+
+     # The smoke test needs a loaded session/tokenizer; mirror the setup used in api_onnx.py
+     device = torch.device('cpu')
+     tokenizer = MODELS["./distilbert-base-multilingual-cased"][1].from_pretrained("./distilbert-base-multilingual-cased")
+     token_style = MODELS["./distilbert-base-multilingual-cased"][3]
+     session = ort.InferenceSession("./poc_onnx_model_punctuation_batch.onnx", providers=['CPUExecutionProvider'])
+
+     test_text = 'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট'
+
+     print("Testing single text processing:")
+     print("=" * 50)
+
+     # Test single text processing
+     for i in range(3):
+         start_time = time.time()
+         result = get_transcription(test_text, session, tokenizer, device, token_style)
+         end_time = time.time()
+         print(f"Run {i+1}: {end_time - start_time:.4f}s")
+
+     print(f"\nSingle result: {result[0][:100]}...")
+
+     print("\nTesting batch text processing:")
+     print("=" * 50)
+
+     # Test batch processing
+     batch_texts = [
+         'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট',
+         'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট',
+     ]
+
+     start_time = time.time()
+     batch_results = get_transcription(batch_texts, session, tokenizer, device, token_style)
+     end_time = time.time()
+
+     print(f"Batch processing time: {end_time - start_time:.4f}s")
+     print(f"Processed {len(batch_texts)} texts")
+     print(f"Average time per text: {(end_time - start_time) / len(batch_texts):.4f}s")
+
+     for i, result in enumerate(batch_results):
+         print(f"Text {i+1}: {result[:50]}...")
+
+
poc_onnx_model_punctuation_batch.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:72f36708c26dee2494269930d59e64e09f142ee2749082806b6fc5fb6d13e511
+ size 576918507
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ transformers==4.20.1
+ gradio
+ requests
+ pandas
+ fastapi
+ uvicorn
+ onnxruntime-gpu
+ numpy
+ sacremoses==0.1.1