Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from transformers import AutoModel | |
class TextEncoder(nn.Module):
    """Pretrained transformer encoder followed by a linear projection.

    The backbone is loaded with ``AutoModel.from_pretrained``. All of its
    parameters are frozen, then the last ``unfreeze_n_blocks`` entries of
    ``encoder.encoder.layer`` and the pooler (when present) are made
    trainable again. A single ``nn.Linear`` maps the backbone's
    ``config.hidden_size`` features to ``output_dim``.

    Args:
        output_dim: Number of output features of the projection layer.
        lang_model: Model identifier handed to ``AutoModel.from_pretrained``.
        unfreeze_n_blocks: How many trailing transformer blocks to leave
            trainable. ``0`` keeps every block frozen; a value larger than
            the number of blocks unfreezes all of them.

    Raises:
        ValueError: If ``unfreeze_n_blocks`` is negative.
    """

    def __init__(self, output_dim=64, lang_model="sentence-transformers/all-MiniLM-L6-v2", unfreeze_n_blocks=4):
        super().__init__()
        self.lang_model = lang_model
        self.encoder = AutoModel.from_pretrained(lang_model)

        # Start from a fully frozen backbone.
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Re-enable training for the trailing transformer blocks.
        # NOTE: `encoder.encoder.layer` is the BERT-style layout; other
        # architectures may expose their blocks under a different attribute.
        self._unfreeze_last_blocks(self.encoder.encoder.layer, unfreeze_n_blocks)

        # Re-enable training for the pooler when the backbone has one.
        # NOTE: forward() reads last_hidden_state rather than the pooler
        # output, so these parameters receive no gradient from this module's
        # own forward pass; they are kept trainable for backward compatibility.
        pooler = getattr(self.encoder, "pooler", None)
        if pooler is not None:
            for param in pooler.parameters():
                param.requires_grad = True

        self.fc = nn.Linear(self.encoder.config.hidden_size, output_dim)

    @staticmethod
    def _unfreeze_last_blocks(blocks, n_blocks):
        """Set ``requires_grad=True`` on the last ``n_blocks`` of ``blocks``."""
        if n_blocks < 0:
            raise ValueError(
                f"unfreeze_n_blocks must be non-negative, got {n_blocks}"
            )
        if n_blocks == 0:
            # `blocks[-0:]` is `blocks[0:]`, i.e. *every* block, so the
            # zero case must be short-circuited rather than sliced.
            return
        for block in blocks[-n_blocks:]:
            for param in block.parameters():
                param.requires_grad = True

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids and project the first-position hidden state.

        Args:
            input_ids: Token ids forwarded to the backbone as ``input_ids``.
            attention_mask: Optional mask forwarded to the backbone.

        Returns:
            The result of applying ``self.fc`` to ``last_hidden_state[:, 0]``
            (presumably the [CLS] position for BERT-style models, giving a
            tensor of shape ``(batch, output_dim)`` -- confirm for other
            backbones).
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = outputs.last_hidden_state[:, 0]
        return self.fc(first_token_state)