Divyanshu04 committed on
Commit
07dbf6e
·
verified ·
1 Parent(s): 6b858a5

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +143 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,145 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ import os
4
+ import pickle
5
+ import torch.nn.functional as F
6
+ from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM
7
+ import asyncio
8
+ import sys
9
 
10
# On Windows, switch asyncio to the selector-based event loop policy.
# NOTE(review): presumably a workaround for event-loop incompatibilities
# between Streamlit/torch and the default Windows policy -- confirm.
_RUNNING_ON_WINDOWS = sys.platform == "win32"
if _RUNNING_ON_WINDOWS:
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
12
+
13
+ # -------------------- Part 1: Prediction --------------------
14
+
15
@st.cache_resource
def load_prediction_model():
    """Load the issue classifier, its tokenizer and the label mapping.

    The label encoder is read from ``label_encoder.pkl``. The current working
    directory is tried first; if the file is not there, the directory that
    contains this script is used instead, so the app also starts when it is
    launched from the repository root (e.g. ``streamlit run src/streamlit_app.py``).

    Returns:
        A ``(tokenizer, model, id_to_class)`` tuple where ``id_to_class`` maps
        each class index to the class name stored in the label encoder. The
        model is put in evaluation mode before being returned.

    Raises:
        FileNotFoundError: if ``label_encoder.pkl`` is found in neither location.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    encoder_path = 'label_encoder.pkl'
    if not os.path.exists(encoder_path):
        # Fall back to the script's own directory when the CWD differs.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        encoder_path = os.path.join(script_dir, 'label_encoder.pkl')

    # NOTE(security): pickle.load can execute arbitrary code; only ship a
    # label_encoder.pkl that comes from a trusted source.
    with open(encoder_path, 'rb') as f:
        label_encoder = pickle.load(f)
    id_to_class = dict(enumerate(label_encoder.classes_))

    model = BertForSequenceClassification.from_pretrained(
        'Divyanshu04/Issue_categorizer',
        num_labels=len(label_encoder.classes_),
    )
    model.eval()
    return tokenizer, model, id_to_class


tokenizer_cls, model_cls, id_to_class = load_prediction_model()
28
+
29
def preprocess_texts(texts):
    """Tokenize ``texts`` with the classifier tokenizer.

    Inputs are padded to and truncated at 128 tokens, and the encoding is
    returned as PyTorch tensors.
    """
    tokenizer_options = {
        "padding": 'max_length',
        "truncation": True,
        "max_length": 128,
        "return_tensors": 'pt',
    }
    return tokenizer_cls(texts, **tokenizer_options)
31
+
32
def predict(text):
    """Classify ``text`` and return the most likely classes.

    Only the first row of the model output is inspected, so the result
    describes a single input.

    Returns:
        A 5-tuple ``(top_class_names, top_probs, formatted, prediction,
        probability)``:

        * ``top_class_names`` -- class names ordered by decreasing probability
          (up to three; fewer when the model has fewer than three labels),
        * ``top_probs`` -- NumPy array with the matching probabilities in percent,
        * ``formatted`` -- the same probabilities as ``"12.3456%"`` strings,
        * ``prediction`` -- the most likely class name,
        * ``probability`` -- its probability in percent.
    """
    inputs = preprocess_texts(text)
    with torch.no_grad():
        outputs = model_cls(**inputs)
        probabilities = F.softmax(outputs.logits, dim=1)
        # torch.topk raises if k exceeds the number of classes, so clamp it.
        k = min(3, probabilities.shape[1])
        top3_probs, top3_classes = torch.topk(probabilities, k=k, dim=1)

    top3_class_names = [id_to_class[idx.item()] for idx in top3_classes[0]]
    top3_probs_np = (top3_probs[0] * 100).cpu().numpy()
    formatted_percentages = [f"{p:.4f}%" for p in top3_probs_np]

    prediction = top3_class_names[0]
    probability = top3_probs_np[0]

    return top3_class_names, top3_probs_np, formatted_percentages, prediction, probability
48
+
49
+ # -------------------- Part 2: Follow-up Generator --------------------
50
+
51
@st.cache_resource
def load_followup_model():
    """Load the seq2seq follow-up question generator and its tokenizer.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is in evaluation mode.
    """
    # Hub repository (or local path) of the fine-tuned model; adjust as needed.
    model_path = "Divyanshu04/Insurance_claim_followup_model"
    seq_tokenizer = AutoTokenizer.from_pretrained(model_path)
    seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    seq_model.eval()
    return seq_tokenizer, seq_model


tokenizer_seq, model_seq = load_followup_model()
60
+
61
def generate_followup(context, condition=None, max_tokens=64):
    """Generate one follow-up question for ``context``.

    Args:
        context: Conversation text placed after ``Context:`` in the prompt.
        condition: Optional text added on its own ``Condition:`` line.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded first generated sequence with special tokens removed.
        Sampling is enabled, so repeated calls may return different text.
    """
    prompt_lines = [f"Context: {context}"]
    if condition:
        prompt_lines.append(f"Condition: {condition}")
    prompt_lines.append("Follow-up question:")
    prompt = "\n".join(prompt_lines)

    encoded = tokenizer_seq(prompt, return_tensors="pt", padding=True)
    encoded = encoded.to(model_seq.device)

    sampling_options = {
        "max_new_tokens": max_tokens,
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 0.9,
        "num_return_sequences": 1,
    }
    generated = model_seq.generate(**encoded, **sampling_options)
    return tokenizer_seq.decode(generated[0], skip_special_tokens=True)
78
+
79
+ # -------------------- Streamlit UI --------------------
80
+
81
st.title("Tenant-Landlord Query Classifier + Claim Assistant")

# Minimum top-1 confidence (in percent) required to accept a classification.
CONFIDENCE_THRESHOLD = 60
# Number of follow-up questions asked before the claim counts as registered.
MAX_FOLLOWUPS = 5

# Session-state slots used by the query classifier.
if "prev_input" not in st.session_state:
    st.session_state.prev_input = ""  # low-confidence text awaiting elaboration
if "context" not in st.session_state:
    st.session_state.context = ""  # accepted query plus follow-up answers
if "last_query" not in st.session_state:
    st.session_state.last_query = None  # most recent query already classified
if "last_result" not in st.session_state:
    st.session_state.last_result = None  # classification outcome for last_query

user_input = st.text_input("Enter your query:")

# Streamlit re-executes this script on every widget interaction. Classify only
# when the query text actually changed; otherwise each rerun would reset
# ``context`` (discarding the follow-up answers appended below) and would
# re-append the same low-confidence query to ``prev_input``.
if user_input and user_input != st.session_state.last_query:
    with st.spinner("Classifying your query..."):
        pending = st.session_state.prev_input
        combined_input = pending + " " + user_input if pending else user_input
        top3_classes, _, formatted, prediction, probability = predict(combined_input)

    confident = bool(probability > CONFIDENCE_THRESHOLD)
    st.session_state.last_query = user_input
    st.session_state.last_result = {
        "confident": confident,
        "prediction": prediction,
        "probability": float(probability),
        "top3": list(zip(top3_classes, formatted)),
    }
    if confident:
        st.session_state.context = combined_input
        st.session_state.prev_input = ""  # Reset
    else:
        # Keep the text so the next elaboration is classified together with it.
        st.session_state.prev_input = combined_input

# Render the stored outcome so it stays visible across reruns.
result = st.session_state.last_result
if user_input and result is not None:
    if result["confident"]:
        st.success(
            f"Prediction: **{result['prediction']}** with confidence **{result['probability']:.2f}%**"
        )
        st.write("Top 3 predictions:")
        for cls, prob in result["top3"]:
            st.write(f"- {cls}: {prob}")
    else:
        st.warning("Confidence is low. Please elaborate your query more.")

# -------------------- Ask to Make a Claim --------------------

if st.session_state.context:
    make_claim = st.radio("Do you want to make a claim?", ["Yes", "No"])

    if make_claim == "No":
        st.info("Thank you! No claim will be made.")
        st.stop()

    elif make_claim == "Yes":
        st.subheader("Claim Assistant - Answer Follow-up Questions")

        if "followup_count" not in st.session_state:
            st.session_state.followup_count = 0
        if "questions" not in st.session_state:
            st.session_state.questions = []
        if "responses" not in st.session_state:
            st.session_state.responses = []

        questions = st.session_state.questions
        responses = st.session_state.responses

        # Walk the questions in order. A question is generated only once every
        # earlier one has an answer, and in the same run in which that answer
        # is recorded, so the next question appears immediately.
        index = 0
        while index < MAX_FOLLOWUPS:
            if index >= len(questions):
                with st.spinner("Generating follow-up question..."):
                    questions.append(generate_followup(st.session_state.context))

            st.markdown(f"**Follow-up Question {index+1}:** {questions[index]}")
            response = st.text_input(
                f"Your response to question {index+1}:",
                key=f"response_input_{index}",
            )

            if index >= len(responses):
                if not response:
                    break  # wait for the user to answer this question
                # Record each answer once and feed it into the context used
                # for generating later questions.
                responses.append(response)
                st.session_state.context += " " + response
                st.session_state.followup_count += 1

            index += 1

        if st.session_state.followup_count >= MAX_FOLLOWUPS:
            st.success("All follow-up questions answered. Your claim has been registered.")
145
+