import torch
import torch.nn as nn
import timm

# Amharic character inventory (Fidel syllabary plus digits and punctuation).
amharic_chars = list(' ሀሁሂሃሄህሆለሉሊላሌልሎሐሑሒሓሔሕሖመሙሚማሜምሞሰሱሲሳስሶረሩሪራሬርሮሠሡሢሣሤሥሦሸሹሺሻሼሽሾቀቁቂቃቄቅቆበቡቢባቤብቦተቱቲታቴትቶቸቹቺቻቼችቾኀኃነኑኒናኔንኖኘኙኚኛኜኝኞአኡኢኣኤእኦከኩኪካኬክኮኸኹኺኻኼኽኾወዉዊዋዌውዎዐዑዒዓዔዕዖዘዙዚዛዜዝዞዠዡዢዣዤዥዦየዩዪያዬይዮደዱዲዳዴድዶጀጁጂጃጄጅጆገጉጊጋጌግጎጠጡጢጣጤጥጦጨጩጪጫጬጭጮጰጱጲጳጴጵጶጸጹጺጻጼጽጾፀፁፂፃፄፅፆፈፉፊፋፌፍፎፐፑፒፓፔፕፖቨቩቪቫቬቭቮ0123456789፥፣()-ሏሟሷሯሿቧቆቈቋቷቿኗኟዟዧዷጇጧጯጿፏኳኋኧቯጐጕጓ።')

# Character/index mappings. Index 0 is reserved for the CTC blank token,
# so real characters start at index 1 and <UNK> takes the last index.
char_to_idx = {char: idx + 1 for idx, char in enumerate(amharic_chars)}
char_to_idx['<UNK>'] = len(amharic_chars) + 1
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
idx_to_char[0] = '<blank>'

# Vocabulary size including the blank token at index 0.
vocab_size = len(char_to_idx) + 1
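
# Sketch (not part of the original module): how a ground-truth transcription
# would map to CTC target indices with this vocabulary. Characters outside
# the inventory fall back to the <UNK> index; index 0 is never emitted as a
# target because it is reserved for the CTC blank.
def encode_label(text):
    return [char_to_idx.get(ch, char_to_idx['<UNK>']) for ch in text]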


class ViTRecognitionModel(nn.Module):
    """ViT backbone with a linear head producing per-patch CTC log-probabilities."""

    def __init__(self, vocab_size, hidden_dim=768, max_length=20):
        super().__init__()
        # Pretrained ViT-Base/16 backbone. features_only=True returns spatial
        # feature maps; out_indices=(11,) selects the last transformer block.
        self.vit = timm.create_model(
            'vit_base_patch16_224',
            pretrained=True,
            num_classes=0,
            features_only=True,
            out_indices=(11,)
        )
        self.hidden_dim = hidden_dim
        # Per-position classifier over the character vocabulary.
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=2)
        self.max_length = max_length  # not used in the forward pass

    def forward(self, x):
        # features_only returns a list of feature maps; take the single requested one.
        features = self.vit(x)[0]

        if features.dim() == 3:
            # (batch, hidden_dim, num_patches): fold the patch axis back into a square grid.
            batch_size, hidden_dim, num_patches = features.shape
            grid_size = int(num_patches ** 0.5)
            if grid_size * grid_size != num_patches:
                raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
            H, W = grid_size, grid_size
            features = features.view(batch_size, hidden_dim, H, W)
        elif features.dim() == 4:
            # Already in (batch, hidden_dim, H, W) layout.
            batch_size, hidden_dim, H, W = features.shape
        else:
            raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")

        # Flatten the spatial grid into a sequence: (batch, H * W, hidden_dim).
        features = features.flatten(2).transpose(1, 2)
        logits = self.fc(features)
        log_probs = self.log_softmax(logits)
        # CTC losses/decoders expect (sequence_length, batch, vocab_size).
        log_probs = log_probs.transpose(0, 1)
        return log_probs


def load_model(model_path, device='cpu'):
    """Instantiate the recognizer and load trained weights from model_path."""
    model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)
    model.eval()
    return model
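

# Example usage (a minimal sketch, not taken from the original code): the
# checkpoint filename, the sample image path, and the preprocessing pipeline
# below are assumptions. The 224x224 resize and ImageNet-style normalization
# simply mirror common ViT-Base/16 defaults; adjust them to match whatever
# preprocessing the model was actually trained with.
if __name__ == '__main__':
    from PIL import Image
    from torchvision import transforms

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model('vit_recognition.pth', device=device)  # hypothetical checkpoint path

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open('sample_word.png').convert('RGB')  # hypothetical input image
    batch = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        log_probs = model(batch)  # (sequence_length, batch, vocab_size)

    # Greedy CTC decoding: take the argmax at each step, collapse repeats,
    # and drop the blank token (index 0).
    best_path = log_probs.argmax(dim=2)[:, 0].tolist()
    decoded, previous = [], 0
    for idx in best_path:
        if idx != 0 and idx != previous:
            decoded.append(idx_to_char.get(idx, '<UNK>'))
        previous = idx
    print(''.join(decoded))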