|
import streamlit as st |
|
import numpy as np |
|
import torch |
|
from transformers import DistilBertTokenizer, DistilBertForMaskedLM |
|
from qa_model import ReuseQuestionDistilBERT |
|
|
|
@st.cache_resource
def load_model():
    """Build the question-answering model and its tokenizer.

    The DistilBERT encoder is taken from the pretrained
    ``distilbert-base-uncased`` masked-LM checkpoint, wrapped in
    ``ReuseQuestionDistilBERT``, and the fine-tuned weights are loaded from
    ``distilbert_reuse.model`` onto the CPU. The tokenizer is read from the
    local ``qa_tokenizer`` directory.

    Returns:
        A ``(model, tokenizer)`` tuple, or ``(None, None)`` if anything fails
        while loading; in that case the error is shown with ``st.error``.

    NOTE(review): the failure case is *returned* rather than raised, so
    ``st.cache_resource`` presumably caches ``(None, None)`` as well until the
    cache is cleared -- confirm this is the intended behaviour.
    """
    try:
        backbone = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
        model = ReuseQuestionDistilBERT(backbone)
        # NOTE(review): torch.load unpickles the checkpoint; only ship files
        # from a trusted source next to this app.
        state_dict = torch.load("distilbert_reuse.model", map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        # Inference only: put the wrapper (and any dropout layers it owns)
        # into evaluation mode so repeated questions give repeatable answers.
        # Assumes ReuseQuestionDistilBERT is a torch.nn.Module (it exposes
        # load_state_dict) -- TODO confirm.
        model.eval()
        tokenizer = DistilBertTokenizer.from_pretrained("qa_tokenizer")
    except Exception as e:
        # Top-level UI boundary: report the problem instead of crashing the page.
        st.error(f"Error loading model: {e}")
        return None, None
    return model, tokenizer
|
|
|
def get_answer(question, text, tokenizer, model):
    """Return the answer span the model predicts for ``question`` in ``text``.

    Args:
        question: The question string; surrounding whitespace is ignored.
        text: The context string to search; surrounding whitespace is ignored.
        tokenizer: Callable tokenizer that also provides
            ``convert_ids_to_tokens`` and ``convert_tokens_to_string``.
        model: Callable returning a mapping with ``"start_logits"`` and
            ``"end_logits"`` entries.

    Returns:
        The decoded answer string, ``"No answer found."`` when there is
        nothing to decode (blank input, inverted or empty span), or a short
        error message when the model/tokenizer is missing or the model output
        lacks the expected entries.
    """
    if model is None or tokenizer is None:
        return "Model not loaded properly."

    question = question.strip()
    text = text.strip()
    if not question or not text:
        # Whitespace-only input leaves nothing to search; skip the forward
        # pass instead of decoding whatever span the model picks from padding.
        return "No answer found."

    # The tokenizer is given one-element batches (question first, context
    # second); only the context may be truncated to fit max_length.
    inputs = tokenizer(
        [question],
        [text],
        max_length=512,
        truncation="only_second",
        padding="max_length",
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            start_positions=None,
            end_positions=None,
        )

    if "start_logits" not in outputs or "end_logits" not in outputs:
        return "Error: Model output structure is incorrect."

    # Assumes logits are laid out as (batch, sequence) -- TODO confirm against
    # ReuseQuestionDistilBERT. Convert the single batch entry to plain ints so
    # the slice below does not rely on implicit tensor-to-index conversion.
    start = int(torch.argmax(outputs["start_logits"], dim=1)[0])
    end = int(torch.argmax(outputs["end_logits"], dim=1)[0])
    if end < start:
        # Inverted span: the slice would be empty anyway, so say so directly.
        return "No answer found."

    answer_ids = inputs["input_ids"][0, start:end + 1].tolist()
    answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids, skip_special_tokens=True)
    predicted = tokenizer.convert_tokens_to_string(answer_tokens)
    return predicted or "No answer found."
|
|
|
def main():
    """Render the question-answering page and respond to form submissions.

    Sets the page title/icon, obtains the model and tokenizer from
    ``load_model``, and shows a form with a context text area and a question
    field. On submit, either warns that both fields are required or displays
    the string returned by ``get_answer``.
    """
    st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")
    st.write("# Question Answering Tool")

    model, tokenizer = load_model()

    with st.form("qa_form"):
        text = st.text_area("Enter your text here")
        question = st.text_input("Enter your question here")
        submitted = st.form_submit_button("Submit")

        if submitted and not (text and question):
            # At least one of the two fields was left empty.
            st.warning("Please enter both text and a question.")
        elif submitted:
            st.text("Processing...")
            st.text(f"Answer: {get_answer(question, text, tokenizer, model)}")
|
|
|
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|