import sys
import json

import transformers
import torch


def vectorize_with_pretrained_embeddings(sentences):
    """
    Produces a fixed-size BERT embedding for each sentence in the dataset or
    batch.

    Args:
        sentences: List of n sentences.

    Returns:
        embeddings: A 2D numpy array of shape (n, d), where d = 768, holding
            one embedding per sentence.
    """
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
    pretrained_model = transformers.BertModel.from_pretrained(
        'bert-base-cased', output_hidden_states=False)
    pretrained_model.eval()  # disable dropout for deterministic inference
    embeddings = []
    for sentence in sentences:
        # Add BERT's special classifier/separator tokens, then truncate to
        # the model's 512-token maximum input length.
        with_tags = "[CLS] " + sentence + " [SEP]"
        tokenized_sentence = tokenizer.tokenize(with_tags)
        tokenized_sentence = tokenized_sentence[:512]
        indices_from_tokens = tokenizer.convert_tokens_to_ids(
            tokenized_sentence)
        segments_ids = [1] * len(indices_from_tokens)
        tokens_tensor = torch.tensor([indices_from_tokens])
        segments_tensors = torch.tensor([segments_ids])
        with torch.no_grad():
            # The first element of the output is the last hidden state, of
            # shape 1 x sentence_length x 768. Pass the segment ids by
            # keyword: in current transformers, the second positional
            # argument of BertModel is attention_mask, not token_type_ids.
            outputs = pretrained_model(
                tokens_tensor, token_type_ids=segments_tensors)[0]
        # Average across the sequence-length dimension (dim=1) to produce a
        # constant-sized 1 x 768 tensor for each sentence.
        embeddings.append(torch.mean(outputs, dim=1))
    embeddings = torch.cat(embeddings, dim=0)  # shape: n x 768
    return embeddings.detach().cpu().numpy()
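
# Usage sketch for the function above (the sentence is illustrative, not
# taken from the Space's data):
#
#     vecs = vectorize_with_pretrained_embeddings(["The cat sat on the mat."])
#     vecs.shape  # -> (1, 768)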


def main():
    # Step 1: Read JSON input from stdin
    input_json = sys.stdin.read()
    inputs = json.loads(input_json)

    # Step 2: Extract inputs
    passage = inputs.get("Passage", "")
    question = inputs.get("QuestionText", "")
    distractors = inputs.get("Distractors", "")
    if isinstance(distractors, list):
        # "Distractors" may plausibly arrive as a list of strings; join them
        # so the combined input below is a single block of text.
        distractors = "\n".join(distractors)

    # Step 3: Combine inputs into a single text block
    combined_input = [f"{question}\n{distractors}\n{passage}"]
    embedding = vectorize_with_pretrained_embeddings(combined_input)
    embedding_flat = embedding.flatten()  # Flatten to a 1D array of d floats
    embedding_str = ",".join(map(str, embedding_flat))
    print(embedding_str)


if __name__ == "__main__":
    main()
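
# Invocation sketch (the script name and field values are assumptions, not
# taken from the Space's configuration):
#
#     echo '{"Passage": "p", "QuestionText": "q", "Distractors": "d"}' | python app.py
#
# The script prints one comma-separated line of 768 floats to stdout.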