Question Answering
sanjudebnath committed on
Commit
ceee800
verified
1 Parent(s): bc565b0

Update application.py

Browse files
Files changed (1) hide show
  1. application.py +40 -36
application.py CHANGED
@@ -2,22 +2,25 @@ import streamlit as st
2
  import numpy as np
3
  import torch
4
  from transformers import DistilBertTokenizer, DistilBertForMaskedLM
5
-
6
  from qa_model import ReuseQuestionDistilBERT
7
 
8
- @st.cache(allow_output_mutation=True)
9
  def load_model():
10
- mod = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
11
- m = ReuseQuestionDistilBERT(mod)
12
- m.load_state_dict(torch.load("distilbert_reuse.model", map_location=torch.device('cpu')))
13
- model = m
14
- del mod
15
- del m
16
- tokenizer = DistilBertTokenizer.from_pretrained('qa_tokenizer')
17
- return model, tokenizer
18
-
 
19
 
20
  def get_answer(question, text, tokenizer, model):
 
 
 
21
  question = [question.strip()]
22
  text = [text.strip()]
23
 
@@ -27,44 +30,45 @@ def get_answer(question, text, tokenizer, model):
27
  max_length=512,
28
  truncation="only_second",
29
  padding="max_length",
 
30
  )
31
- input_ids = torch.tensor(inputs['input_ids'])
32
- outputs = model(input_ids, attention_mask=torch.tensor(inputs['attention_mask']), start_positions=None, end_positions=None)
33
 
34
- start = torch.argmax(outputs['start_logits'])
35
- end = torch.argmax(outputs['end_logits'])
 
 
 
 
 
36
 
37
- ans_tokens = input_ids[0][start: end + 1]
 
38
 
 
 
 
 
39
  answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
40
  predicted = tokenizer.convert_tokens_to_string(answer_tokens)
41
- return predicted
42
-
43
 
44
  def main():
45
  st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")
46
-
47
- st.write("# Question Answering Tool \n"
48
- "This tool will help you find answers to your questions about the text you provide. \n"
49
- "Please enter your question and the text you want to search in the boxes below.")
50
  model, tokenizer = load_model()
51
-
52
  with st.form("qa_form"):
53
- # define a streamlit textarea
54
- text = st.text_area("Enter your text here", on_change=None)
55
-
56
- # define a streamlit input
57
  question = st.text_input("Enter your question here")
58
 
59
  if st.form_submit_button("Submit"):
60
- data_load_state = st.text('Let me think about that...')
61
- # call the function to get the answer
62
- answer = get_answer(question, text, tokenizer, model)
63
- # display the answer
64
- if answer == "":
65
- data_load_state.text("Sorry but I don't know the answer to that question")
66
  else:
67
- data_load_state.text(answer)
68
-
 
69
 
70
- main()
 
 
2
  import numpy as np
3
  import torch
4
  from transformers import DistilBertTokenizer, DistilBertForMaskedLM
 
5
  from qa_model import ReuseQuestionDistilBERT
6
 
@st.cache_resource
def _load_model_or_raise():
    """Build the QA model and tokenizer, raising on any failure.

    Kept separate from ``load_model`` so that ``st.cache_resource`` only ever
    stores a successfully built pair: when this function raises, nothing is
    cached and the next rerun retries the load.
    """
    backbone = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
    model = ReuseQuestionDistilBERT(backbone)
    # NOTE(review): torch.load unpickles the file, so the checkpoint must come
    # from a trusted source.
    state_dict = torch.load("distilbert_reuse.model", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    # Inference only: switch to eval mode so dropout-style layers are disabled.
    # Assumes ReuseQuestionDistilBERT is a torch.nn.Module (it already exposes
    # load_state_dict) -- TODO confirm.
    model.eval()
    tokenizer = DistilBertTokenizer.from_pretrained("qa_tokenizer")
    return model, tokenizer


def load_model():
    """Return ``(model, tokenizer)`` for the QA app, or ``(None, None)`` on failure.

    A successful load is cached by ``_load_model_or_raise``. A failed load is
    reported through ``st.error`` and is not cached, so a transient problem
    (for example a download error) does not persist across reruns.
    """
    try:
        return _load_model_or_raise()
    except Exception as e:  # UI boundary: report the failure instead of crashing the page
        st.error(f"Error loading model: {e}")
        return None, None


# Keep the ``load_model.clear()`` hook that ``st.cache_resource`` used to expose.
load_model.clear = _load_model_or_raise.clear
19
 
20
  def get_answer(question, text, tokenizer, model):
21
+ if model is None or tokenizer is None:
22
+ return "Model not loaded properly."
23
+
24
  question = [question.strip()]
25
  text = [text.strip()]
26
 
 
30
  max_length=512,
31
  truncation="only_second",
32
  padding="max_length",
33
+ return_tensors="pt"
34
  )
 
 
35
 
36
+ with torch.no_grad():
37
+ outputs = model(
38
+ inputs["input_ids"],
39
+ attention_mask=inputs["attention_mask"],
40
+ start_positions=None,
41
+ end_positions=None
42
+ )
43
 
44
+ if "start_logits" not in outputs or "end_logits" not in outputs:
45
+ return "Error: Model output structure is incorrect."
46
 
47
+ start = torch.argmax(outputs["start_logits"], dim=1)
48
+ end = torch.argmax(outputs["end_logits"], dim=1)
49
+
50
+ ans_tokens = inputs["input_ids"][0, start:end + 1]
51
  answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
52
  predicted = tokenizer.convert_tokens_to_string(answer_tokens)
53
+ return predicted or "No answer found."
 
54
 
def main():
    """Render the Streamlit question-answering page.

    Loads the model and tokenizer through ``load_model``, shows a form with a
    text area and a question input, and on submit displays the answer returned
    by ``get_answer``. Inputs that are empty or contain only whitespace are
    rejected with a warning instead of being sent to the model.
    """
    st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")
    st.write("# Question Answering Tool")

    model, tokenizer = load_model()

    with st.form("qa_form"):
        text = st.text_area("Enter your text here")
        question = st.text_input("Enter your question here")

        if st.form_submit_button("Submit"):
            # get_answer strips both inputs, so whitespace-only input is
            # effectively empty and must be rejected here as well.
            if not text.strip() or not question.strip():
                st.warning("Please enter both text and a question.")
            else:
                # Keep a handle on the status line so the answer replaces it
                # rather than leaving "Processing..." on screen.
                status = st.text("Processing...")
                answer = get_answer(question, text, tokenizer, model)
                status.text(f"Answer: {answer}")


if __name__ == "__main__":
    main()