from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch


def loadSqueeze():
    # Load the SqueezeBERT tokenizer and QA model fine-tuned on SQuAD v2
    tokenizer = AutoTokenizer.from_pretrained("ALOQAS/squeezebert-uncased-finetuned-squad-v2")
    model = AutoModelForQuestionAnswering.from_pretrained("ALOQAS/squeezebert-uncased-finetuned-squad-v2")
    return tokenizer, model
def squeezebert(context, question, model, tokenizer):
    # Tokenize the input question-context pair
    inputs = tokenizer.encode_plus(question, context, max_length=512, truncation=True, padding=True, return_tensors='pt')

    # Send inputs to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        # Forward pass, get model outputs
        outputs = model(**inputs)

    # Extract the start and end positions of the answer in the tokens
    answer_start_scores, answer_end_scores = outputs.start_logits, outputs.end_logits

    # Calculate probabilities from logits
    answer_start_prob = torch.softmax(answer_start_scores, dim=-1)
    answer_end_prob = torch.softmax(answer_end_scores, dim=-1)

    # Find the most likely start and end positions
    answer_start_index = torch.argmax(answer_start_prob)  # Most likely start of answer
    answer_end_index = torch.argmax(answer_end_prob) + 1  # Most likely end of answer; +1 for inclusive slicing

    # Extract the highest probability scores
    start_score = answer_start_prob.max().item()  # Highest probability of start
    end_score = answer_end_prob.max().item()      # Highest probability of end

    # Combine the scores into a single score
    combined_score = (start_score * end_score) ** 0.5  # Geometric mean of start and end scores

    # Convert token indices to the actual answer text
    answer_tokens = inputs['input_ids'][0, answer_start_index:answer_end_index]
    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

    # Return the answer, its positions, and the combined score
    return {
        "answer": answer,
        "start": answer_start_index.item(),
        "end": answer_end_index.item(),
        "score": combined_score
    }
def bert(context, question, pip):
    # Delegate to a question-answering pipeline (e.g. a BERT checkpoint)
    return pip(context=context, question=question)


def deberta(context, question, pip):
    # Delegate to a question-answering pipeline (e.g. a DeBERTa checkpoint)
    return pip(context=context, question=question)
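

# Minimal usage sketch (not part of the original module): it assumes `pip` is a
# Hugging Face "question-answering" pipeline, and the checkpoint passed to
# pipeline() below is an illustrative assumption, not one named in this code.
if __name__ == "__main__":
    from transformers import pipeline

    context = "The Eiffel Tower was completed in 1889 and is located in Paris."
    question = "When was the Eiffel Tower completed?"

    # Manual SqueezeBERT path: load tokenizer/model, then score the answer span directly
    tokenizer, model = loadSqueeze()
    print(squeezebert(context, question, model, tokenizer))

    # Pipeline path used by bert()/deberta(): any QA pipeline can be passed in
    qa_pipe = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
    print(bert(context, question, qa_pipe))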