# sonIA/utils.py — Davide Fiocco
# "Guess less boldly" (commit ad2c89a)
import datetime
import json
import streamlit as st
import tokenizers
import torch
from transformers import Pipeline, pipeline
def get_answer(input, context, engine, threshold=0.5):
    """Ask the QA engine a question and return its answer if confident enough.

    Parameters
    ----------
    input : str
        The user's question. (Name shadows the builtin ``input``; kept for
        backward compatibility with keyword callers.)
    context : str
        The passage the engine searches for an answer.
    engine : Pipeline
        A question-answering callable; invoked with a dict containing
        ``"question"`` and ``"context"`` and expected to return a dict
        with ``"score"`` and ``"answer"`` keys.
    threshold : float, optional
        Minimum confidence score required to trust the engine's answer.

    Returns
    -------
    str
        The engine's answer when its score exceeds ``threshold``,
        otherwise an Italian "I don't know" fallback message.
    """
    result = engine({"question": input, "context": context})
    fallback = "Non lo so, prova con un'altra domanda!"
    return result["answer"] if result["score"] > threshold else fallback
@st.cache
def get_context():
    """Load the QA context from ``context.json``, filling date placeholders.

    Reads the ``"info"`` field and substitutes ``[YEAR]``, ``[TODAY]``,
    ``[BIRTHYEAR]``, ``[AGE]`` and ``[OTHERAGE]``. Fix: the original
    called ``datetime.datetime.now()`` a second time for ``[TODAY]``,
    which could straddle midnight/New Year and make ``[TODAY]``
    inconsistent with ``[YEAR]``/``[AGE]``; all substitutions now share
    the single ``now`` timestamp.

    Returns
    -------
    str
        The context string with every placeholder replaced.
    """
    BIRTHYEAR = 1952
    OTHERBIRTHYEAR = 1984
    now = datetime.datetime.now()
    with open("context.json") as f:
        context = (
            json.load(f)["info"]
            .replace("[YEAR]", str(now.year))
            .replace("[TODAY]", f"{now:%d-%m-%Y}")
            .replace("[BIRTHYEAR]", str(BIRTHYEAR))
            .replace("[AGE]", str(now.year - BIRTHYEAR))
            .replace("[OTHERAGE]", str(now.year - OTHERBIRTHYEAR))
        )
    return context
@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None,
    },
    allow_output_mutation=True,
    show_spinner=False,
)
def load_engine() -> Pipeline:
    """Build and cache the Italian question-answering pipeline.

    The ``st.cache`` decorator treats model parameters and tokenizer
    objects as unhashable (hash funcs return ``None``) so Streamlit does
    not try to hash the heavyweight model on every rerun.

    Returns
    -------
    Pipeline
        A ``transformers`` question-answering pipeline backed by an
        Italian BERT model fine-tuned on SQuAD-it.
    """
    # Same checkpoint serves as both model and tokenizer.
    model_id = "mrm8488/bert-italian-finedtuned-squadv1-it-alfa"
    return pipeline("question-answering", model=model_id, tokenizer=model_id)