# model.py

import torch
import timm
import torch.nn as nn

# Character-to-Index Mapping (should be the same as used during training)
amharic_chars = list(' ሀሁሂሃሄህሆለሉሊላሌልሎሐሑሒሓሔሕሖመሙሚማሜምሞሰሱሲሳስሶረሩሪራሬርሮሠሡሢሣሤሥሦሸሹሺሻሼሽሾቀቁቂቃቄቅቆበቡቢባቤብቦተቱቲታቴትቶቸቹቺቻቼችቾኀኃነኑኒናኔንኖኘኙኚኛኜኝኞአኡኢኣኤእኦከኩኪካኬክኮኸኹኺኻኼኽኾወዉዊዋዌውዎዐዑዒዓዔዕዖዘዙዚዛዜዝዞዠዡዢዣዤዥዦየዩዪያዬይዮደዱዲዳዴድዶጀጁጂጃጄጅጆገጉጊጋጌግጎጠጡጢጣጤጥጦጨጩጪጫጬጭጮጰጱጲጳጴጵጶጸጹጺጻጼጽጾፀፁፂፃፄፅፆፈፉፊፋፌፍፎፐፑፒፓፔፕፖቨቩቪቫቬቭቮ0123456789፥፣()-ሏሟሷሯሿቧቆቈቋቷቿኗኟዟዧዷጇጧጯጿፏኳኋኧቯጐጕጓ።')

char_to_idx = {char: idx + 1 for idx, char in enumerate(amharic_chars)}  # Start indexing from 1
char_to_idx['<UNK>'] = len(amharic_chars) + 1  # Unknown characters
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
idx_to_char[0] = '<blank>'  # CTC blank token

vocab_size = len(char_to_idx) + 1  # +1 for the blank token at index 0
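
# Illustrative note (not in the original file): index 0 is reserved for the CTC
# blank, the listed characters are indexed from 1 in list order, and '<UNK>'
# takes the final index.  For example, char_to_idx[' '] == 1,
# char_to_idx['ሀ'] == 2, and idx_to_char[0] == '<blank>'.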

class ViTRecognitionModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim=768, max_length=20):
        super(ViTRecognitionModel, self).__init__()
        self.vit = timm.create_model(
            'vit_base_patch16_224',
            pretrained=True,
            num_classes=0,               # Disable classification head
            features_only=True,          # Return feature maps
            out_indices=(11,)            # Get the last feature map
        )
        self.hidden_dim = hidden_dim
        self.fc = nn.Linear(hidden_dim, vocab_size)  # Map hidden_dim to vocab_size
        self.log_softmax = nn.LogSoftmax(dim=2)
        self.max_length = max_length

    def forward(self, x):
        # features_only returns a list; depending on the timm version, the last
        # entry is either [batch, hidden_dim, num_patches] or [batch, hidden_dim, H, W],
        # so both layouts are handled below.
        features = self.vit(x)[0]
        
        if features.dim() == 3:
            batch_size, hidden_dim, num_patches = features.shape
            grid_size = int(num_patches ** 0.5)
            if grid_size * grid_size != num_patches:
                raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
            H, W = grid_size, grid_size
            features = features.view(batch_size, hidden_dim, H, W)
        elif features.dim() == 4:
            batch_size, hidden_dim, H, W = features.shape
        else:
            raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")
        
        features = features.flatten(2).transpose(1, 2)  # [batch, H*W, hidden_dim]
        logits = self.fc(features)  # [batch, H*W, vocab_size]
        log_probs = self.log_softmax(logits)  # [batch, H*W, vocab_size]
        log_probs = log_probs.transpose(0, 1)  # [H*W, batch, vocab_size]
        return log_probs
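
# --- Illustrative sketch (not part of the original file) ---
# A minimal greedy CTC decoder for the [T, batch, vocab_size] log-probs this
# model returns: take the argmax at each time step, collapse repeats, and drop
# the blank (index 0).  The original pipeline may use a different decoder.
def ctc_greedy_decode(log_probs):
    # log_probs: [T, batch, vocab_size] -> list of decoded strings, one per batch item
    best_path = log_probs.argmax(dim=2).transpose(0, 1)  # [batch, T]
    texts = []
    for seq in best_path.tolist():
        chars = []
        prev = 0
        for idx in seq:
            if idx != prev and idx != 0:  # collapse repeats, skip the CTC blank
                chars.append(idx_to_char.get(idx, '<UNK>'))
            prev = idx
        texts.append(''.join(chars))
    return texts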

def load_model(model_path, device='cpu'):
    model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
    # Set weights_only=True to address the FutureWarning
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)
    model.eval()
    return model
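
# --- Illustrative usage sketch (not part of the original file) ---
# "amharic_vit_ctc.pth" is a hypothetical checkpoint path, and the random
# tensor stands in for a preprocessed 224x224 RGB image batch.
if __name__ == "__main__":
    model = load_model("amharic_vit_ctc.pth", device="cpu")
    dummy_batch = torch.randn(1, 3, 224, 224)  # [batch, channels, H, W]
    with torch.no_grad():
        log_probs = model(dummy_batch)          # [T, batch, vocab_size]
    print(ctc_greedy_decode(log_probs))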