import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config
from einops import rearrange

transformers.logging.set_verbosity_error()

def exists(val):
    return val is not None

# config
MAX_LENGTH = 256
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
T5_CONFIGS = {}
# singleton globals
def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name)
    return tokenizer

def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

def get_model_and_tokenizer(name):
    global T5_CONFIGS
    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
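
# Note: get_model_and_tokenizer caches per model name in the module-level
# T5_CONFIGS dict, so multiple T5Encoder instances built with the same `name`
# reuse the same weights and tokenizer instead of loading them again.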

def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoids loading the model if we only want to get the dim
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model
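
# Example (illustrative): get_encoded_dim only loads the T5Config, so it is
# cheap to call before any weights are downloaded; for 'google/t5-v1_1-base'
# it should return the encoder width d_model (768 for that config).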

class T5Encoder(torch.nn.Module):
    def __init__(self, name=DEFAULT_T5_NAME, max_length=MAX_LENGTH, padding='longest', masked_mean=True):
        super().__init__()
        self.name = name
        self.t5, self.tokenizer = get_model_and_tokenizer(name)
        self.max_length = max_length
        self.output_size = get_encoded_dim(name)
        self.padding = padding
        # when True (the default), pooling averages only over non-padding tokens
        self.masked_mean = masked_mean
    def forward(self, x, return_only_pooled=True):
        encoded = self.tokenizer.batch_encode_plus(
            x,
            return_tensors="pt",
            padding=self.padding,
            max_length=self.max_length,
            truncation=True
        )
        device = next(self.t5.parameters()).device
        input_ids = encoded.input_ids.to(device)
        attn_mask = encoded.attention_mask.to(device).bool()
        output = self.t5(input_ids=input_ids, attention_mask=attn_mask)
        encoded_text = output.last_hidden_state.detach()
        if self.masked_mean:
            pooled = masked_mean(encoded_text, dim=1, mask=attn_mask)
        else:
            pooled = encoded_text.mean(dim=1)
        if return_only_pooled:
            return pooled
        return pooled, encoded_text, attn_mask

def masked_mean(t, *, dim, mask=None):
    # mean over `dim` that ignores positions where `mask` is False
    if not exists(mask):
        return t.mean(dim=dim)
    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)
    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
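
# Usage sketch (illustrative): encodes a small batch of captions and prints the
# pooled / per-token embedding shapes. Assumes the 'google/t5-v1_1-base' weights
# can be downloaded on first use.
if __name__ == '__main__':
    encoder = T5Encoder(name=DEFAULT_T5_NAME, max_length=MAX_LENGTH)
    captions = ['a photo of a cat', 'an astronaut riding a horse on the moon']
    pooled, encoded_text, attn_mask = encoder(captions, return_only_pooled=False)
    # expected: pooled (2, 768), encoded_text (2, seq_len, 768), attn_mask (2, seq_len)
    print(pooled.shape, encoded_text.shape, attn_mask.shape)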