# model.py
import torch
import torch.nn as nn
import timm

# Character-to-index mapping. This must match the mapping used during training
# exactly: note that list() keeps duplicate characters, but the dict
# comprehension below keeps only the last index for a repeated character,
# which also affects len(char_to_idx) and therefore vocab_size.
amharic_chars = list(' ሀሁሂሃሄህሆለሉሊላሌልሎሐሑሒሓሔሕሖመሙሚማሜምሞሰሱሲሳስሶረሩሪራሬርሮሠሡሢሣሤሥሦሸሹሺሻሼሽሾቀቁቂቃቄቅቆበቡቢባቤብቦተቱቲታቴትቶቸቹቺቻቼችቾኀኃነኑኒናኔንኖኘኙኚኛኜኝኞአኡኢኣኤእኦከኩኪካኬክኮኸኹኺኻኼኽኾወዉዊዋዌውዎዐዑዒዓዔዕዖዘዙዚዛዜዝዞዠዡዢዣዤዥዦየዩዪያዬይዮደዱዲዳዴድዶጀጁጂጃጄጅጆገጉጊጋጌግጎጠጡጢጣጤጥጦጨጩጪጫጬጭጮጰጱጲጳጴጵጶጸጹጺጻጼጽጾፀፁፂፃፄፅፆፈፉፊፋፌፍፎፐፑፒፓፔፕፖቨቩቪቫቬቭቮ0123456789፥፣()-ሏሟሷሯሿቧቆቈቋቷቿኗኟዟዧዷጇጧጯጿፏኳኋኧቯጐጕጓ።')
char_to_idx = {char: idx + 1 for idx, char in enumerate(amharic_chars)}  # indices start at 1; 0 is reserved for the CTC blank
char_to_idx[''] = len(amharic_chars) + 1  # empty-string key used as the fallback index for unknown characters
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
idx_to_char[0] = ''  # the CTC blank token decodes to the empty string
vocab_size = len(char_to_idx) + 1  # +1 for the blank token at index 0


class ViTRecognitionModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim=768, max_length=20):
        super().__init__()
        # ViT-B/16 backbone configured to return spatial feature maps
        # rather than a pooled classification embedding.
        self.vit = timm.create_model(
            'vit_base_patch16_224',
            pretrained=True,
            num_classes=0,       # disable the classification head
            features_only=True,  # return feature maps
            out_indices=(11,),   # keep only the last transformer block's output
        )
        self.hidden_dim = hidden_dim
        self.fc = nn.Linear(hidden_dim, vocab_size)  # project each patch embedding onto the vocabulary
        self.log_softmax = nn.LogSoftmax(dim=2)
        self.max_length = max_length  # stored for compatibility with training; not used in forward()

    def forward(self, x):
        # features_only models return a list of feature maps; take the only one requested.
        features = self.vit(x)[0]
        if features.dim() == 3:
            # [batch, hidden_dim, num_patches]: fold the patch axis back into a square grid
            batch_size, hidden_dim, num_patches = features.shape
            grid_size = int(num_patches ** 0.5)
            if grid_size * grid_size != num_patches:
                raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
            features = features.view(batch_size, hidden_dim, grid_size, grid_size)
        elif features.dim() != 4:
            raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")

        features = features.flatten(2).transpose(1, 2)  # [batch, H*W, hidden_dim]
        logits = self.fc(features)                      # [batch, H*W, vocab_size]
        log_probs = self.log_softmax(logits)            # [batch, H*W, vocab_size]
        return log_probs.transpose(0, 1)                # [H*W, batch, vocab_size], time-major as nn.CTCLoss expects


def load_model(model_path, device='cpu'):
    model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model
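
# The forward pass returns time-major CTC log-probabilities, so inference
# needs a decoder. The helper below is a minimal sketch (not part of the
# original training code) of greedy, best-path CTC decoding: take the argmax
# at every time step, collapse consecutive repeats, then drop blanks (index 0).
def ctc_greedy_decode(log_probs):
    """Decode [T, batch, vocab_size] log-probabilities into a list of strings."""
    best_path = log_probs.argmax(dim=2).transpose(0, 1)  # [batch, T]
    texts = []
    for seq in best_path.tolist():
        chars = []
        prev = 0  # initialising to the blank ensures the first character is emitted
        for idx in seq:
            if idx != prev and idx != 0:
                chars.append(idx_to_char.get(idx, ''))
            prev = idx
        texts.append(''.join(chars))
    return texts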
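
# Minimal usage sketch. The checkpoint path below is hypothetical, and the
# 224x224 input size is assumed from the 'vit_base_patch16_224' backbone.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model('amharic_vit_ctc.pth', device=device)  # hypothetical checkpoint path
    dummy = torch.randn(1, 3, 224, 224, device=device)        # one RGB image at the ViT's native resolution
    with torch.no_grad():
        log_probs = model(dummy)  # [196, 1, vocab_size] for a 14x14 patch grid
    print(ctc_greedy_decode(log_probs))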