# vatrpp/models/unifont_module.py
# vittoriopippi
# Change imports
# af10767
import math
import os
import cv2
from huggingface_hub import hf_hub_download
import torch
import pickle
import numpy as np
def gauss(x, sigma=1.0):
    """Return the zero-mean normal probability density at ``x``.

    Computes ``exp(-x**2 / (2 * sigma**2)) / (sigma * sqrt(2 * pi))``.

    Args:
        x: Point at which the density is evaluated (scalar).
        sigma: Standard deviation; must be positive.

    Raises:
        ValueError: if ``sigma`` is not positive.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    # Parenthesised on purpose: sigma belongs in the denominator of the
    # normaliser (the unparenthesised form multiplied by sigma instead).
    norm = 1.0 / (math.sqrt(2.0 * math.pi) * sigma)
    return norm * math.exp(-x**2 / (2.0 * sigma**2))
class UnifontModule(torch.nn.Module):
    """Embed character indices using per-character glyph bitmaps.

    For every character of ``alphabet`` a flattened glyph is looked up in a
    symbol pickle downloaded from the ``blowing-up-groundhogs/vatrpp``
    Hugging Face repo. Row 0 of each table is an all-zero vector, so index 0
    maps to an empty glyph and ``alphabet[i]`` maps to row ``i + 1``.

    Args:
        out_dim: Output size of the projection (used only when
            ``projection == 'linear'``).
        alphabet: Characters to embed, in index order (starting at 1).
        device: Device the symbol tables are moved to.
        input_type: Name of the symbol file (``files/<input_type>.pickle``)
            whose table is indexed in :meth:`forward`.
        projection: ``'linear'`` for a learned ``torch.nn.Linear``; any other
            value uses ``torch.nn.Identity``. ``'cnn'`` additionally changes
            the output layout of :meth:`forward`.
        cache_dir: Local cache directory passed to ``hf_hub_download``.
    """

    def __init__(self, out_dim, alphabet, device='cuda', input_type='unifont', projection='linear',
                 cache_dir="./hf_cache"):
        super().__init__()
        self.projection_type = projection
        self.device = device
        self.alphabet = alphabet
        # The unifont table is always loaded (not used in this file; presumably
        # read by callers). `symbols_repr` is the table indexed in forward().
        self.symbols = self.get_symbols('unifont', cache_dir=cache_dir)
        self.symbols_repr = self.get_symbols(input_type, cache_dir=cache_dir)
        if projection == 'linear':
            self.linear = torch.nn.Linear(self.symbols_repr.shape[1], out_dim)
        else:
            self.linear = torch.nn.Identity()

    def get_symbols(self, input_type, cache_dir="./hf_cache"):
        """Download ``files/<input_type>.pickle`` and build the table for ``self.alphabet``.

        Returns:
            A float tensor on ``self.device`` with ``len(self.alphabet) + 1``
            rows: row 0 is all zeros and row ``i + 1`` is the flattened glyph
            of ``self.alphabet[i]``.

        Raises:
            KeyError: if a character of the alphabet has no glyph in the file.
            ValueError: if the alphabet is empty.
        """
        file_path = hf_hub_download(
            repo_id="blowing-up-groundhogs/vatrpp",
            filename=f"files/{input_type}.pickle",
            cache_dir=cache_dir,  # Optional: local cache folder for downloaded files
        )
        # NOTE(security): unpickling can execute arbitrary code embedded in the
        # file; this trusts the contents of the Hugging Face repo above.
        with open(file_path, "rb") as f:
            records = pickle.load(f)
        # Keyed by sym['idx'][0], which must be the code point for the
        # ord(char) lookup in _stack_symbols to succeed. float32 conversion is
        # deferred so only glyphs actually in the alphabet are converted.
        glyphs = {sym['idx'][0]: sym['mat'] for sym in records}
        table = self._stack_symbols(glyphs, self.alphabet)
        return torch.from_numpy(table).float().to(self.device)

    @staticmethod
    def _stack_symbols(glyphs, alphabet):
        """Return a float32 array: a zero row, then one flattened glyph per character.

        Args:
            glyphs: Mapping from code point to an array-like glyph bitmap.
            alphabet: Characters to look up, in output order.

        Raises:
            ValueError: if ``alphabet`` is empty (glyph size is then unknown).
            KeyError: if a character has no entry in ``glyphs``.
        """
        if not alphabet:
            raise ValueError("alphabet must contain at least one character")
        rows = []
        for char in alphabet:
            try:
                glyph = glyphs[ord(char)]
            except KeyError:
                raise KeyError(
                    f"character {char!r} (U+{ord(char):04X}) has no glyph in the symbol file"
                ) from None
            rows.append(np.asarray(glyph, dtype=np.float32).flatten())
        # Index 0 is reserved for an all-zero glyph of the same size.
        rows.insert(0, np.zeros_like(rows[0]))
        return np.stack(rows)

    def forward(self, QR):
        """Embed the integer index tensor ``QR``.

        Indices select rows of ``self.symbols_repr`` (0 = empty glyph). When
        ``projection_type != 'cnn'`` the projection is applied to the gathered
        rows directly. For ``'cnn'`` each element along dim 0 gets an extra
        dimension inserted at position 1, is projected, and the results are
        stacked back along dim 0.
        """
        symbols = self.symbols_repr[QR]
        if self.projection_type != 'cnn':
            return self.linear(symbols)
        # assumes QR is (batch, seq) — TODO confirm; each item becomes (seq, 1, dim)
        return torch.stack([self.linear(torch.unsqueeze(item, dim=1)) for item in symbols])
class LearnableModule(torch.nn.Module):
    """Input-independent learned embedding.

    A single learnable vector of size ``embed_dim`` (initialised to zeros) is
    projected to ``out_dim`` and repeated along dim 0 to match ``QR``; the
    values in ``QR`` are ignored, only ``QR.shape[0]`` is used.

    Args:
        out_dim: Output feature size.
        device: Device the learnable vector is created on.
        embed_dim: Size of the learnable vector (default 256, the previously
            hard-coded value).
    """

    def __init__(self, out_dim, device='cuda', embed_dim=256):
        super().__init__()
        self.device = device
        # Shape (1, 1, embed_dim) so the projection yields (1, 1, out_dim).
        self.param = torch.nn.Parameter(torch.zeros(1, 1, embed_dim, device=device))
        # NOTE(review): the Linear is created on the default device, not
        # `device`; presumably the model is moved with .to() later — confirm.
        self.linear = torch.nn.Linear(embed_dim, out_dim)

    def forward(self, QR):
        """Return the projected vector repeated to shape ``(QR.shape[0], 1, out_dim)``."""
        return self.linear(self.param).repeat((QR.shape[0], 1, 1))
if __name__ == "__main__":
    # Smoke test: build a CNN-projection module on CPU for a toy alphabet.
    # Constructing it downloads the symbol pickles from the Hugging Face Hub.
    demo_module = UnifontModule(out_dim=512, alphabet="bluuuuurp", device='cpu', projection='cnn')