Divyanshu04 committed
Commit 1d94dd5 · verified · 1 Parent(s): 2ec0cb8

Update src/streamlit_app.py

Files changed (1):
  1. src/streamlit_app.py +137 -136
src/streamlit_app.py CHANGED
@@ -7,8 +7,8 @@ from transformers import BertTokenizer, BertForSequenceClassification, AutoToken
 import asyncio
 import sys
 
-if sys.platform == "win32":
-    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+# if sys.platform == "win32":
+#     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
 
 # import zipfile
 # zip_path = "bert-base-uncased.zip"
@@ -27,139 +27,140 @@ def prepare_tokenizer_folder():
     for file in tokenizer_files:
         if os.path.exists(file):
             shutil.move(file, os.path.join(folder_name, file))
+prepare_tokenizer_folder()
 # -------------------- Part 1: Prediction --------------------
-
-@st.cache_resource
-def load_prediction_model():
-    prepare_tokenizer_folder()
-    # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
-    with open('label_encoder_new.pkl', 'rb') as f:
-        label_encoder = pickle.load(f)
-    id_to_class = {idx: class_name for idx, class_name in enumerate(label_encoder.classes_)}
-
-    # model = BertForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizer', num_labels=len(label_encoder.classes_))
-    model = AutoModelForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizer')
-    # model.load_state_dict(torch.load('Divyanshu04/Issue_categorizer', map_location=torch.device('cpu'))['model_state_dict'])
-    model.eval()
-    return tokenizer, model, id_to_class
-
-tokenizer_cls, model_cls, id_to_class = load_prediction_model()
-
-def preprocess_texts(texts):
-    return tokenizer_cls(texts, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
-
-def predict(text):
-    inputs = preprocess_texts(text)
-    with torch.no_grad():
-        outputs = model_cls(**inputs)
-        probabilities = F.softmax(outputs.logits, dim=1)
-        top3_probs, top3_classes = torch.topk(probabilities, k=3, dim=1)
-
-    top3_class_names = [id_to_class[idx.item()] for idx in top3_classes[0]]
-    top3_probs = top3_probs[0] * 100
-    top3_probs_np = top3_probs.cpu().numpy()
-    formatted_percentages = [f"{p:.4f}%" for p in top3_probs_np]
-
-    prediction = top3_class_names[0]
-    probability = top3_probs_np[0]
-
-    return top3_class_names, top3_probs_np, formatted_percentages, prediction, probability
-
-# -------------------- Part 2: Follow-up Generator --------------------
-
-@st.cache_resource
-def load_followup_model():
-    model_path = "Divyanshu04/Insurance_claim_followup_model"  # Adjust path as needed
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
-    model.eval()
-    return tokenizer, model
-
-tokenizer_seq, model_seq = load_followup_model()
-
-def generate_followup(context, condition=None, max_tokens=64):
-    prompt = f"Context: {context}"
-    if condition:
-        prompt += f"\nCondition: {condition}"
-    prompt += "\nFollow-up question:"
-    inputs = tokenizer_seq(prompt, return_tensors="pt", padding=True).to(model_seq.device)
-
-    outputs = model_seq.generate(
-        **inputs,
-        max_new_tokens=max_tokens,
-        do_sample=True,
-        top_k=50,
-        top_p=0.95,
-        temperature=0.9,
-        num_return_sequences=1
-    )
-    return tokenizer_seq.decode(outputs[0], skip_special_tokens=True)
-
-# -------------------- Streamlit UI --------------------
-
-st.title("Tenant-Landlord Query Classifier + Claim Assistant")
-
-if "prev_input" not in st.session_state:
-    st.session_state.prev_input = ""
-
-if "context" not in st.session_state:
-    st.session_state.context = ""
-
-user_input = st.text_input("Enter your query:")
-
-if user_input:
-    with st.spinner("Classifying your query..."):
-        combined_input = st.session_state.prev_input + " " + user_input if st.session_state.prev_input else user_input
-        top3_classes, top3_probs, formatted, prediction, probability = predict(combined_input)
-
-    if probability > 60:
-        st.success(f"Prediction: **{prediction}** with confidence **{probability:.2f}%**")
-        st.write("Top 3 predictions:")
-        for cls, prob in zip(top3_classes, formatted):
-            st.write(f"- {cls}: {prob}")
-        st.session_state.context = combined_input
-        st.session_state.prev_input = ""  # Reset
-    else:
-        st.warning("Confidence is low. Please elaborate your query more.")
-        st.session_state.prev_input = combined_input
-
-# -------------------- Ask to Make a Claim --------------------
-
-if st.session_state.context:
-    make_claim = st.radio("Do you want to make a claim?", ["Yes", "No"])
-
-    if make_claim == "No":
-        st.info("Thank you! No claim will be made.")
-        st.stop()
-
-    elif make_claim == "Yes":
-        st.subheader("Claim Assistant - Answer Follow-up Questions")
-
-        if "followup_count" not in st.session_state:
-            st.session_state.followup_count = 0
-        if "questions" not in st.session_state:
-            st.session_state.questions = []
-        if "responses" not in st.session_state:
-            st.session_state.responses = []
-
-        # Generate new follow-up question if needed
-        if len(st.session_state.questions) <= st.session_state.followup_count and st.session_state.followup_count < 5:
-            with st.spinner("Generating follow-up question..."):
-                new_question = generate_followup(st.session_state.context)
-                st.session_state.questions.append(new_question)
-
-        # Render follow-up questions and collect responses
-        for i in range(len(st.session_state.questions)):
-            st.markdown(f"**Follow-up Question {i+1}:** {st.session_state.questions[i]}")
-            response_key = f"response_input_{i}"
-            response = st.text_input(f"Your response to question {i+1}:", key=response_key)
-
-            if response and len(st.session_state.responses) <= i:
-                st.session_state.responses.append(response)
-                st.session_state.context += " " + response
-                st.session_state.followup_count += 1
-
-        if st.session_state.followup_count >= 5:
-            st.success("All follow-up questions answered. Your claim has been registered.")
-
+
+# @st.cache_resource
+# def load_prediction_model():
+#     prepare_tokenizer_folder()
+#     # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+#     tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+#     with open('label_encoder_new.pkl', 'rb') as f:
+#         label_encoder = pickle.load(f)
+#     id_to_class = {idx: class_name for idx, class_name in enumerate(label_encoder.classes_)}
+
+#     # model = BertForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizer', num_labels=len(label_encoder.classes_))
+#     model = AutoModelForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizer')
+#     # model.load_state_dict(torch.load('Divyanshu04/Issue_categorizer', map_location=torch.device('cpu'))['model_state_dict'])
+#     model.eval()
+#     return tokenizer, model, id_to_class
+
+# tokenizer_cls, model_cls, id_to_class = load_prediction_model()
+
+# def preprocess_texts(texts):
+#     return tokenizer_cls(texts, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
+
+# def predict(text):
+#     inputs = preprocess_texts(text)
+#     with torch.no_grad():
+#         outputs = model_cls(**inputs)
+#         probabilities = F.softmax(outputs.logits, dim=1)
+#         top3_probs, top3_classes = torch.topk(probabilities, k=3, dim=1)
+
+#     top3_class_names = [id_to_class[idx.item()] for idx in top3_classes[0]]
+#     top3_probs = top3_probs[0] * 100
+#     top3_probs_np = top3_probs.cpu().numpy()
+#     formatted_percentages = [f"{p:.4f}%" for p in top3_probs_np]
+
+#     prediction = top3_class_names[0]
+#     probability = top3_probs_np[0]
+
+#     return top3_class_names, top3_probs_np, formatted_percentages, prediction, probability
+
+# # -------------------- Part 2: Follow-up Generator --------------------
+
+# @st.cache_resource
+# def load_followup_model():
+#     model_path = "Divyanshu04/Insurance_claim_followup_model"  # Adjust path as needed
+#     tokenizer = AutoTokenizer.from_pretrained(model_path)
+#     model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+#     model.eval()
+#     return tokenizer, model

+# tokenizer_seq, model_seq = load_followup_model()
+
+# def generate_followup(context, condition=None, max_tokens=64):
+#     prompt = f"Context: {context}"
+#     if condition:
+#         prompt += f"\nCondition: {condition}"
+#     prompt += "\nFollow-up question:"
+#     inputs = tokenizer_seq(prompt, return_tensors="pt", padding=True).to(model_seq.device)
+
+#     outputs = model_seq.generate(
+#         **inputs,
+#         max_new_tokens=max_tokens,
+#         do_sample=True,
+#         top_k=50,
+#         top_p=0.95,
+#         temperature=0.9,
+#         num_return_sequences=1
+#     )
+#     return tokenizer_seq.decode(outputs[0], skip_special_tokens=True)
+
+# # -------------------- Streamlit UI --------------------
+
+# st.title("Tenant-Landlord Query Classifier + Claim Assistant")
+
+# if "prev_input" not in st.session_state:
+#     st.session_state.prev_input = ""
+
+# if "context" not in st.session_state:
+#     st.session_state.context = ""
+
+# user_input = st.text_input("Enter your query:")
+
+# if user_input:
+#     with st.spinner("Classifying your query..."):
+#         combined_input = st.session_state.prev_input + " " + user_input if st.session_state.prev_input else user_input
+#         top3_classes, top3_probs, formatted, prediction, probability = predict(combined_input)
+
+#     if probability > 60:
+#         st.success(f"Prediction: **{prediction}** with confidence **{probability:.2f}%**")
+#         st.write("Top 3 predictions:")
+#         for cls, prob in zip(top3_classes, formatted):
+#             st.write(f"- {cls}: {prob}")
+#         st.session_state.context = combined_input
+#         st.session_state.prev_input = ""  # Reset
+#     else:
+#         st.warning("Confidence is low. Please elaborate your query more.")
+#         st.session_state.prev_input = combined_input
+
+# # -------------------- Ask to Make a Claim --------------------
+
+# if st.session_state.context:
+#     make_claim = st.radio("Do you want to make a claim?", ["Yes", "No"])
+
+#     if make_claim == "No":
+#         st.info("Thank you! No claim will be made.")
+#         st.stop()
+
+#     elif make_claim == "Yes":
+#         st.subheader("Claim Assistant - Answer Follow-up Questions")
+
+#         if "followup_count" not in st.session_state:
+#             st.session_state.followup_count = 0
+#         if "questions" not in st.session_state:
+#             st.session_state.questions = []
+#         if "responses" not in st.session_state:
+#             st.session_state.responses = []
+
+#         # Generate new follow-up question if needed
+#         if len(st.session_state.questions) <= st.session_state.followup_count and st.session_state.followup_count < 5:
+#             with st.spinner("Generating follow-up question..."):
+#                 new_question = generate_followup(st.session_state.context)
+#                 st.session_state.questions.append(new_question)
+
+#         # Render follow-up questions and collect responses
+#         for i in range(len(st.session_state.questions)):
+#             st.markdown(f"**Follow-up Question {i+1}:** {st.session_state.questions[i]}")
+#             response_key = f"response_input_{i}"
+#             response = st.text_input(f"Your response to question {i+1}:", key=response_key)
+
+#             if response and len(st.session_state.responses) <= i:
+#                 st.session_state.responses.append(response)
+#                 st.session_state.context += " " + response
+#                 st.session_state.followup_count += 1
+
+#         if st.session_state.followup_count >= 5:
+#             st.success("All follow-up questions answered. Your claim has been registered.")
+
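
Note on the guard that the first hunk comments out: since Python 3.8, asyncio on Windows defaults to the Proactor event loop, which some async libraries in this stack do not handle well, so apps often pin the selector loop instead. A minimal standalone sketch of that pattern (it mirrors the pre-commit lines above; it is not part of the file's new state):

    import asyncio
    import sys

    # WindowsSelectorEventLoopPolicy only exists on Windows builds of CPython,
    # hence the platform guard; other platforms keep their default policy.
    if sys.platform == "win32":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())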