# model.py
import torch
import timm
import torch.nn as nn
# Character-to-Index Mapping (should be the same as used during training)
amharic_chars = list(' ሀሁሂሃሄህሆለሉሊላሌልሎሐሑሒሓሔሕሖመሙሚማሜምሞሰሱሲሳስሶረሩሪራሬርሮሠሡሢሣሤሥሦሸሹሺሻሼሽሾቀቁቂቃቄቅቆበቡቢባቤብቦተቱቲታቴትቶቸቹቺቻቼችቾኀኃነኑኒናኔንኖኘኙኚኛኜኝኞአኡኢኣኤእኦከኩኪካኬክኮኸኹኺኻኼኽኾወዉዊዋዌውዎዐዑዒዓዔዕዖዘዙዚዛዜዝዞዠዡዢዣዤዥዦየዩዪያዬይዮደዱዲዳዴድዶጀጁጂጃጄጅጆገጉጊጋጌግጎጠጡጢጣጤጥጦጨጩጪጫጬጭጮጰጱጲጳጴጵጶጸጹጺጻጼጽጾፀፁፂፃፄፅፆፈፉፊፋፌፍፎፐፑፒፓፔፕፖቨቩቪቫቬቭቮ0123456789፥፣()-ሏሟሷሯሿቧቆቈቋቷቿኗኟዟዧዷጇጧጯጿፏኳኋኧቯጐጕጓ።')
char_to_idx = {char: idx + 1 for idx, char in enumerate(amharic_chars)} # Start indexing from 1
char_to_idx['<UNK>'] = len(amharic_chars) + 1 # Unknown characters
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
idx_to_char[0] = '<blank>' # CTC blank token
vocab_size = len(char_to_idx) + 1 # +1 for the blank token at index 0

class ViTRecognitionModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim=768, max_length=20):
        super(ViTRecognitionModel, self).__init__()
        self.vit = timm.create_model(
            'vit_base_patch16_224',
            pretrained=True,
            num_classes=0,        # Disable classification head
            features_only=True,   # Return feature maps
            out_indices=(11,)     # Get the last feature map
        )
        self.hidden_dim = hidden_dim
        self.fc = nn.Linear(hidden_dim, vocab_size)  # Map hidden_dim to vocab_size
        self.log_softmax = nn.LogSoftmax(dim=2)
        self.max_length = max_length

    def forward(self, x):
        # Last feature map; shape is [batch, hidden_dim, num_patches] or
        # [batch, hidden_dim, H, W] depending on the timm version.
        features = self.vit(x)[0]
        if features.dim() == 3:
            batch_size, hidden_dim, num_patches = features.shape
            grid_size = int(num_patches ** 0.5)
            if grid_size * grid_size != num_patches:
                raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
            H, W = grid_size, grid_size
            features = features.view(batch_size, hidden_dim, H, W)
        elif features.dim() == 4:
            batch_size, hidden_dim, H, W = features.shape
        else:
            raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")
        features = features.flatten(2).transpose(1, 2)  # [batch, H*W, hidden_dim]
        logits = self.fc(features)                      # [batch, H*W, vocab_size]
        log_probs = self.log_softmax(logits)            # [batch, H*W, vocab_size]
        log_probs = log_probs.transpose(0, 1)           # [H*W, batch, vocab_size] for CTC
        return log_probs

def load_model(model_path, device='cpu'):
    model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
    # Set weights_only=True to address the FutureWarning
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)
    model.eval()
    return model
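
# Usage sketch (illustrative only): loads a checkpoint and runs greedy CTC decoding
# on a dummy input. The checkpoint path 'model.pth', the 224x224 RGB input size, and
# the greedy best-path decode below are assumptions, not defined elsewhere in this file.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model('model.pth', device=device)

    # Dummy batch of one image; replace with a preprocessed text-line image tensor.
    dummy = torch.randn(1, 3, 224, 224, device=device)
    with torch.no_grad():
        log_probs = model(dummy)  # [T, batch, vocab_size]

    # Greedy CTC decode: argmax per time step, collapse repeats, drop the blank (index 0).
    best_path = log_probs.argmax(dim=2)[:, 0].tolist()
    decoded, prev = [], None
    for idx in best_path:
        if idx != prev and idx != 0:
            decoded.append(idx_to_char.get(idx, '<UNK>'))
        prev = idx
    print(''.join(decoded))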