# Question Answering
# Numini / application.py
# Commit ceee800 (verified) by sanjudebnath: "Update application.py"
import streamlit as st
import numpy as np
import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM
from qa_model import ReuseQuestionDistilBERT
@st.cache_resource
def load_model():
    """Load the fine-tuned question-answering model and its tokenizer.

    The DistilBERT encoder is taken from the pretrained
    ``distilbert-base-uncased`` masked-LM checkpoint. It is wrapped in
    ``ReuseQuestionDistilBERT`` and then overwritten with the weights stored in
    ``distilbert_reuse.model``, mapped onto the CPU. The tokenizer is read from
    the ``qa_tokenizer`` directory. The result is cached with
    ``st.cache_resource`` so reruns reuse the same objects.

    Returns:
        ``(model, tokenizer)`` on success. On any failure the error is shown in
        the Streamlit UI and ``(None, None)`` is returned.
    """
    try:
        encoder = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
        model = ReuseQuestionDistilBERT(encoder)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load("distilbert_reuse.model", map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        # A freshly constructed module starts in training mode, and the
        # torch.no_grad() used at inference time does not disable dropout.
        # Switch to inference mode so answers are deterministic.
        # (Assumes ReuseQuestionDistilBERT is a torch.nn.Module, as its
        # load_state_dict suggests.)
        model.eval()
        tokenizer = DistilBertTokenizer.from_pretrained('qa_tokenizer')
        return model, tokenizer
    except Exception as e:
        # NOTE(review): because the exception is swallowed here, the
        # (None, None) result is what st.cache_resource stores. A transient
        # failure therefore persists until the cache is cleared or the server
        # restarts.
        st.error(f"Error loading model: {e}")
        return None, None
def _context_token_mask(input_ids, attention_mask, sep_token_id):
    """Return a bool mask over one encoded pair selecting the context tokens.

    For a DistilBERT sentence pair the layout is
    ``[CLS] question [SEP] context [SEP] [PAD]...``. The context is therefore
    the run of attended, non-[SEP] tokens after exactly one [SEP] has been seen.
    """
    attended = attention_mask.bool()
    if sep_token_id is None:
        # The segment boundary cannot be located; fall back to every real token.
        return attended
    is_sep = input_ids == sep_token_id
    # Count of [SEP] tokens up to and including each position:
    # 0 over [CLS] + question, 1 from the first [SEP] through the context,
    # 2+ from the closing [SEP] onwards.
    segment = torch.cumsum(is_sep.long(), dim=0)
    return attended & (segment == 1) & ~is_sep


def _best_span(start_logits, end_logits, mask):
    """Return the ``(start, end)`` indices maximising ``start_logit + end_logit``.

    Only pairs with ``start <= end`` and both positions inside ``mask`` are
    considered. Returns ``None`` when ``mask`` selects no position.
    """
    if not bool(mask.any()):
        return None
    neg_inf = float("-inf")
    starts = start_logits.masked_fill(~mask, neg_inf)
    ends = end_logits.masked_fill(~mask, neg_inf)
    # scores[i, j] = score of the span that starts at i and ends at j.
    scores = starts[:, None] + ends[None, :]
    positions = torch.arange(scores.shape[0], device=scores.device)
    # Rule out spans that end before they start.
    scores = scores.masked_fill(positions[:, None] > positions[None, :], neg_inf)
    best = int(torch.argmax(scores))
    width = scores.shape[1]
    return best // width, best % width


def get_answer(question, text, tokenizer, model):
    """Extract the answer to ``question`` from ``text`` with an extractive QA model.

    The question/context pair is tokenized, truncating only the context, to a
    fixed length of 512 and passed through ``model``. The answer is the
    highest-scoring span (start logit + end logit) that lies inside the context
    and ends at or after its start. It is decoded back to text without special
    tokens.

    Args:
        question: The question to answer.
        text: The context passage to search for the answer.
        tokenizer: Tokenizer used to encode the pair and decode the answer
            (``None`` if loading failed).
        model: QA model returning ``start_logits``/``end_logits`` of shape
            (batch, seq_len) (``None`` if loading failed).

    Returns:
        The answer string, or one of the messages "Model not loaded properly.",
        "Error: Model output structure is incorrect." or "No answer found.".
    """
    if model is None or tokenizer is None:
        return "Model not loaded properly."

    # Batched (single-element) inputs; only the context is truncated so the
    # question is always kept intact.
    inputs = tokenizer(
        [question.strip()],
        [text.strip()],
        max_length=512,
        truncation="only_second",
        padding="max_length",
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            start_positions=None,
            end_positions=None,
        )

    if "start_logits" not in outputs or "end_logits" not in outputs:
        return "Error: Model output structure is incorrect."

    input_ids = inputs["input_ids"][0]
    context_mask = _context_token_mask(
        input_ids,
        inputs["attention_mask"][0],
        getattr(tokenizer, "sep_token_id", None),
    )
    # Choosing start and end jointly (instead of two independent argmaxes)
    # guarantees a non-inverted span that never leaks into the question or
    # padding.
    span = _best_span(outputs["start_logits"][0], outputs["end_logits"][0], context_mask)
    if span is None:
        return "No answer found."

    start, end = span
    answer_tokens = tokenizer.convert_ids_to_tokens(
        input_ids[start:end + 1].tolist(), skip_special_tokens=True
    )
    predicted = tokenizer.convert_tokens_to_string(answer_tokens)
    return predicted or "No answer found."
def main():
    """Render the Streamlit question-answering page.

    Shows a form with a context text area and a question input. On submit, it
    checks that both are non-blank, runs ``get_answer`` and displays the result.
    """
    # Page configuration is issued before any other Streamlit command.
    st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")
    st.write("# Question Answering Tool")
    model, tokenizer = load_model()

    with st.form("qa_form"):
        text = st.text_area("Enter your text here")
        question = st.text_input("Enter your question here")
        if st.form_submit_button("Submit"):
            # Treat whitespace-only input as empty; get_answer strips it anyway.
            if not text.strip() or not question.strip():
                st.warning("Please enter both text and a question.")
            else:
                # A spinner disappears once inference finishes, unlike a static
                # "Processing..." line that would stay on screen next to the answer.
                with st.spinner("Processing..."):
                    answer = get_answer(question, text, tokenizer, model)
                st.text(f"Answer: {answer}")
if __name__ == "__main__":
    # Script entry point: build the page.
    main()