Gizachew committed on
Commit
ae0e027
·
verified ·
1 Parent(s): 67c67e8

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +45 -0
utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils.py
2
+
3
+ import torch
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+
7
# Character <-> index tables for the recognizer's output classes.
# The alphabet (including its order) should be the same as used during
# training, so the literal below must not be edited or reordered.
amharic_chars = list(' ሀሁሂሃሄህሆለሉሊላሌልሎሐሑሒሓሔሕሖመሙሚማሜምሞሰሱሲሳስሶረሩሪራሬርሮሠሡሢሣሤሥሦሸሹሺሻሼሽሾቀቁቂቃቄቅቆበቡቢባቤብቦተቱቲታቴትቶቸቹቺቻቼችቾኀኃነኑኒናኔንኖኘኙኚኛኜኝኞአኡኢኣኤእኦከኩኪካኬክኮኸኹኺኻኼኽኾወዉዊዋዌውዎዐዑዒዓዔዕዖዘዙዚዛዜዝዞዠዡዢዣዤዥዦየዩዪያዬይዮደዱዲዳዴድዶጀጁጂጃጄጅጆገጉጊጋጌግጎጠጡጢጣጤጥጦጨጩጪጫጬጭጮጰጱጲጳጴጵጶጸጹጺጻጼጽጾፀፁፂፃፄፅፆፈፉፊፋፌፍፎፐፑፒፓፔፕፖቨቩቪቫቬቭቮ0123456789፥፣()-ሏሟሷሯሿቧቆቈቋቷቿኗኟዟዧዷጇጧጯጿፏኳኋኧቯጐጕጓ።')

# Character -> class index. Indices start at 1 because 0 is reserved for the
# CTC blank. If a character occurs more than once in the alphabet, its last
# position wins (unchanged from the original comprehension).
char_to_idx = {char: idx for idx, char in enumerate(amharic_chars, start=1)}
char_to_idx['<UNK>'] = len(amharic_chars) + 1  # Unknown characters

# Class index -> character, built from list positions rather than by
# inverting char_to_idx: inverting drops the earlier index of any repeated
# character, which would then decode as '<UNK>'.
# NOTE(review): 'ቆ' appears to occur twice in the alphabet -- confirm against
# the training-time character list.
idx_to_char = dict(enumerate(amharic_chars, start=1))
idx_to_char[char_to_idx['<UNK>']] = '<UNK>'
idx_to_char[0] = '<blank>'  # CTC blank token
14
+
15
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Preprocess the input image for the model.

    The image is converted to RGB if needed, resized to 224x224, turned into
    a tensor, and normalized per channel with the mean/std values below.

    Args:
        image: Input PIL image, in any mode.

    Returns:
        The transformed image tensor. No batch dimension is added here.
    """
    # Normalize is configured with three per-channel statistics, so inputs
    # that are grayscale, palette-based, or carry an alpha channel must be
    # brought to three channels first or the normalization step fails.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform(image)
26
+
27
def decode_predictions(preds: torch.Tensor) -> "list[str]":
    """Decode the model's predictions using best-path (greedy) CTC decoding.

    At every step the highest-scoring class is chosen; consecutive repeats
    are then collapsed and blank tokens (index 0) are dropped.

    Args:
        preds: Score tensor whose third dimension indexes the classes. The
            first two dimensions are swapped before decoding, so the layout
            is presumably [steps, batch, classes] -- confirm against the
            model's output.

    Returns:
        One decoded string per sequence in the batch. Class indices missing
        from ``idx_to_char`` are rendered as '<UNK>'.
    """
    # Greedy class choice per step, then move the batch dimension first.
    best_path = torch.argmax(preds, dim=2).transpose(0, 1)
    decoded_texts = []

    for sequence in best_path:
        chars = []
        previous = 0  # Assuming blank index is 0
        # tolist() yields plain Python ints regardless of device, so the
        # dictionary lookup below does not rely on NumPy scalar hashing.
        for index in sequence.tolist():
            # Emit a character only when the class changes and is not blank.
            if index != previous and index != 0:
                chars.append(idx_to_char.get(index, '<UNK>'))
            previous = index
        decoded_texts.append(''.join(chars))

    return decoded_texts