Gizachew committed (verified)
Commit a99a5cc · 1 Parent(s): 852aa43

Create model.py

Files changed (1): model.py (+58, -0)
model.py ADDED

# model.py

import torch
import timm
import torch.nn as nn

# Character-to-index mapping (must match the mapping used during training)
amharic_chars = list(' ሀሁሂሃሄህሆለሉሊላሌልሎሐሑሒሓሔሕሖመሙሚማሜምሞሰሱሲሳስሶረሩሪራሬርሮሠሡሢሣሤሥሦሸሹሺሻሼሽሾቀቁቂቃቄቅቆበቡቢባቤብቦተቱቲታቴትቶቸቹቺቻቼችቾኀኃነኑኒናኔንኖኘኙኚኛኜኝኞአኡኢኣኤእኦከኩኪካኬክኮኸኹኺኻኼኽኾወዉዊዋዌውዎዐዑዒዓዔዕዖዘዙዚዛዜዝዞዠዡዢዣዤዥዦየዩዪያዬይዮደዱዲዳዴድዶጀጁጂጃጄጅጆገጉጊጋጌግጎጠጡጢጣጤጥጦጨጩጪጫጬጭጮጰጱጲጳጴጵጶጸጹጺጻጼጽጾፀፁፂፃፄፅፆፈፉፊፋፌፍፎፐፑፒፓፔፕፖቨቩቪቫቬቭቮ0123456789፥፣()-ሏሟሷሯሿቧቆቈቋቷቿኗኟዟዧዷጇጧጯጿፏኳኋኧቯጐጕጓ።')

char_to_idx = {char: idx + 1 for idx, char in enumerate(amharic_chars)}  # start indexing from 1; index 0 is reserved for the CTC blank
char_to_idx['<UNK>'] = len(amharic_chars) + 1  # unknown characters
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
idx_to_char[0] = '<blank>'  # CTC blank token

vocab_size = len(char_to_idx) + 1  # +1 for the blank token at index 0

class ViTRecognitionModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim=768, max_length=20):
        super().__init__()
        self.vit = timm.create_model(
            'vit_base_patch16_224',
            pretrained=True,
            num_classes=0,       # disable the classification head
            features_only=True,  # return feature maps instead of pooled output
            out_indices=(11,)    # take only the last transformer block's features
        )
        self.hidden_dim = hidden_dim
        self.fc = nn.Linear(hidden_dim, vocab_size)  # project hidden_dim to vocab_size
        self.log_softmax = nn.LogSoftmax(dim=2)
        self.max_length = max_length

    def forward(self, x):
        # timm returns a list of feature maps; with out_indices=(11,) it has one entry
        features = self.vit(x)[0]  # [batch, hidden_dim, H, W] or [batch, hidden_dim, H*W]

        if features.dim() == 3:
            # Flattened patch sequence: recover a square H x W grid
            batch_size, hidden_dim, num_patches = features.shape
            grid_size = int(num_patches ** 0.5)
            if grid_size * grid_size != num_patches:
                raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
            H, W = grid_size, grid_size
            features = features.view(batch_size, hidden_dim, H, W)
        elif features.dim() == 4:
            batch_size, hidden_dim, H, W = features.shape
        else:
            raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")

        features = features.flatten(2).transpose(1, 2)  # [batch, H*W, hidden_dim]
        logits = self.fc(features)                      # [batch, H*W, vocab_size]
        log_probs = self.log_softmax(logits)            # [batch, H*W, vocab_size]
        log_probs = log_probs.transpose(0, 1)           # [H*W, batch, vocab_size], the layout nn.CTCLoss expects
        return log_probs

def load_model(model_path, device='cpu'):
    model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model
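
Since forward returns log-probabilities in the [T, N, C] layout that nn.CTCLoss expects, the model can be trained directly with CTC. A minimal training-step sketch under stated assumptions: the encode_text helper is hypothetical (it is not part of model.py), and the images are dummy 224x224 tensors standing in for preprocessed text-line crops.

# Sketch: one CTC training step. encode_text is a hypothetical helper; the
# images are random stand-ins for preprocessed 224x224 text-line crops.
import torch
import torch.nn as nn
from model import ViTRecognitionModel, char_to_idx, vocab_size

def encode_text(text):
    # Map characters to indices, falling back to <UNK> for unseen characters.
    return [char_to_idx.get(ch, char_to_idx['<UNK>']) for ch in text]

model = ViTRecognitionModel(vocab_size=vocab_size)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)  # blank token sits at index 0

images = torch.randn(2, 3, 224, 224)
texts = ['ሀሁ', 'ለሉ ሊ']
targets = torch.tensor(sum((encode_text(t) for t in texts), []))
target_lengths = torch.tensor([len(t) for t in texts])

log_probs = model(images)  # [T, N, C]; T = 196 patches for a 224x224 input
input_lengths = torch.full((log_probs.size(1),), log_probs.size(0), dtype=torch.long)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
loss.backward()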
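
For inference, the usual companion to a CTC head is a greedy decode: take the argmax at each time step, collapse consecutive repeats, and drop blanks. A minimal sketch, assuming a trained checkpoint at 'model.pth' (a placeholder path) and an already preprocessed input tensor:

# Sketch: load a checkpoint and greedy-decode one image ('model.pth' is a placeholder).
import torch
from model import load_model, idx_to_char

model = load_model('model.pth', device='cpu')
image = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed text-line crop

with torch.no_grad():
    log_probs = model(image)                   # [T, 1, vocab_size]
    best = log_probs.argmax(dim=2).squeeze(1)  # [T], best class index per time step

decoded, prev = [], 0
for idx in best.tolist():
    if idx != 0 and idx != prev:  # skip CTC blanks and collapsed repeats
        decoded.append(idx_to_char.get(idx, '<UNK>'))
    prev = idx
print(''.join(decoded))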