nanoCLIP / text_encoder.py
import torch
import torch.nn as nn
from transformers import AutoModel


class TextEncoder(nn.Module):
    """Pretrained MiniLM sentence encoder with a trainable linear projection head."""

    def __init__(self, output_dim=64, lang_model="sentence-transformers/all-MiniLM-L6-v2", unfreeze_n_blocks=4):
        super().__init__()
        self.lang_model = lang_model
        self.encoder = AutoModel.from_pretrained(lang_model)

        # freeze all parameters
        for param in self.encoder.parameters():
            param.requires_grad = False

        # unfreeze the last few encoder layers
        for layer in self.encoder.encoder.layer[-unfreeze_n_blocks:]:
            for param in layer.parameters():
                param.requires_grad = True

        # unfreeze the pooler layer
        for param in self.encoder.pooler.parameters():
            param.requires_grad = True

        # project the transformer hidden size down to the output embedding dimension
        self.fc = nn.Linear(self.encoder.config.hidden_size, output_dim)

    def forward(self, input_ids, attention_mask=None):
        # take the [CLS] token representation and project it to output_dim
        x = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        x = self.fc(x)
        return x
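

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): it assumes the matching
    # Hugging Face tokenizer for the default MiniLM checkpoint can be downloaded,
    # and the example captions are purely illustrative.
    from transformers import AutoTokenizer

    model = TextEncoder(output_dim=64)
    tokenizer = AutoTokenizer.from_pretrained(model.lang_model)

    captions = ["a dog playing in the snow", "a red car parked on the street"]
    batch = tokenizer(captions, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        embeddings = model(batch["input_ids"], batch["attention_mask"])

    print(embeddings.shape)  # torch.Size([2, 64])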