import sys
import json

import torch
import transformers

def vectorize_with_pretrained_embeddings(sentences):
    """
    Produces an embedding for each sentence in the dataset or in a batch
    using a pretrained BERT model.

    Args:
        sentences: List of n sentences.
    Returns:
        embeddings: A 2D numpy array containing an embedding for each of the
            n sentences (n x d), where d = 768 (BERT's hidden size).
    """
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
    pretrained_model = transformers.BertModel.from_pretrained(
        'bert-base-cased', output_hidden_states=False)
    pretrained_model.eval()
    embeddings = []
    for sentence in sentences:
        # Add BERT's special classification and separator tokens, then
        # truncate to the model's 512-token limit (for very long inputs
        # this can drop the trailing [SEP]).
        with_tags = "[CLS] " + sentence + " [SEP]"
        tokenized_sentence = tokenizer.tokenize(with_tags)
        tokenized_sentence = tokenized_sentence[:512]
        indices_from_tokens = tokenizer.convert_tokens_to_ids(
            tokenized_sentence)
        # Every token should be attended to, so the mask is all ones.
        segments_ids = [1] * len(indices_from_tokens)
        tokens_tensor = torch.tensor([indices_from_tokens])
        segments_tensors = torch.tensor([segments_ids])
        with torch.no_grad():
            # The first element of the model's output is the last hidden
            # state, of shape 1 x sentence_length x 768. In current versions
            # of transformers, the second positional argument of
            # BertModel.forward is the attention mask, so it is passed by
            # keyword here to make the intent explicit.
            outputs = pretrained_model(
                tokens_tensor, attention_mask=segments_tensors)[0]
        # Average across the sequence-length dimension to produce a
        # constant-sized 1 x 768 tensor per sentence.
        embeddings.append(torch.mean(outputs, dim=1))
    embeddings = torch.cat(embeddings, dim=0)
    return embeddings.cpu().detach().numpy()
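

# A minimal sanity check for the function above (illustrative only; the
# sample sentences are placeholders). It is left as a comment so that it
# does not trigger a model download at import time:
#
#     vecs = vectorize_with_pretrained_embeddings(
#         ["First sentence.", "Second sentence."])
#     assert vecs.shape == (2, 768)  # one 768-dim row per input sentence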


def main():
    # Step 1: Read JSON input from stdin.
    input_json = sys.stdin.read()
    inputs = json.loads(input_json)
    # Step 2: Extract the expected fields, falling back to empty strings.
    passage = inputs.get("Passage", "")
    question = inputs.get("QuestionText", "")
    distractors = inputs.get("Distractors", "")
    # Combine the fields into a single input string for embedding.
    combined_input = [f"{question}\n{distractors}\n{passage}"]
    embedding = vectorize_with_pretrained_embeddings(combined_input)
    embedding_flat = embedding.flatten()  # Flatten to a 1D array
    embedding_str = ",".join(map(str, embedding_flat))
    print(embedding_str)


if __name__ == "__main__":
    main()
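
# Example invocation (assumed usage; the filename embed_passage.py is a
# placeholder for wherever this script is saved, and the field values are
# placeholders for real inputs):
#
#     echo '{"Passage": "...", "QuestionText": "...", "Distractors": "..."}' | python embed_passage.py
#
# The script prints a single comma-separated line of 768 floats to stdout.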