import streamlit as st
import torch
import os
import pickle
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM
import asyncio
import sys
# On Windows, the default Proactor event loop can break libraries that expect
# a selector-based loop, so fall back to the selector policy.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# Optional: authenticate with the Hugging Face Hub for gated/private models.
# from huggingface_hub import login
# token = os.getenv("HF_TOKEN")  # read from the environment
# if token:
#     login(token=token)
# else:
#     raise ValueError("HF_TOKEN environment variable is not set.")
# -------------------- Part 1: Prediction --------------------
@st.cache_resource
def load_prediction_model():
    # from_pretrained expects a model ID, not the full Hub URL
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    with open('src/label_encoder.pkl', 'rb') as f:
        label_encoder = pickle.load(f)
    id_to_class = {idx: class_name for idx, class_name in enumerate(label_encoder.classes_)}
    model = BertForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizerv2')
    # Alternative: load fine-tuned weights from a local checkpoint instead
    # model.load_state_dict(torch.load('Divyanshu04/Issue_categorizer', map_location=torch.device('cpu'))['model_state_dict'])
    model.eval()  # inference only: disable dropout
    return tokenizer, model, id_to_class
tokenizer_cls, model_cls, id_to_class = load_prediction_model()
def preprocess_texts(texts):
    return tokenizer_cls(texts, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
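# Note: the tokenizer returns a dict of tensors (input_ids, attention_mask,
# token_type_ids), each shaped (batch_size, 128) because of padding='max_length'.
# Hypothetical sanity check, not part of the app flow:
#   batch = preprocess_texts(["leaky faucet in the bathroom"])
#   assert batch["input_ids"].shape == (1, 128)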
def predict(text):
    """Classify a query and return the top-3 classes with confidence percentages."""
    inputs = preprocess_texts(text)
    with torch.no_grad():
        outputs = model_cls(**inputs)
    probabilities = F.softmax(outputs.logits, dim=1)
    top3_probs, top3_classes = torch.topk(probabilities, k=3, dim=1)
    top3_class_names = [id_to_class[idx.item()] for idx in top3_classes[0]]
    top3_probs = top3_probs[0] * 100  # convert to percentages
    top3_probs_np = top3_probs.cpu().numpy()
    formatted_percentages = [f"{p:.4f}%" for p in top3_probs_np]
    prediction = top3_class_names[0]
    probability = top3_probs_np[0]
    return top3_class_names, top3_probs_np, formatted_percentages, prediction, probability
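# Example (hypothetical query) of how predict() is consumed by the UI below:
#   classes, probs, formatted, top, conf = predict("My landlord won't return my deposit")
#   `top` is the best class name; `conf` is its confidence on a 0-100 scale.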
# -------------------- Part 2: Follow-up Generator --------------------
@st.cache_resource
def load_followup_model():
    model_path = "Divyanshu04/Insurance_claim_followup_model"  # Hub repo ID, not the full URL; adjust as needed
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    model.eval()
    return tokenizer, model
tokenizer_seq, model_seq = load_followup_model()
def generate_followup(context, condition=None, max_tokens=64):
    prompt = f"Context: {context}"
    if condition:
        prompt += f"\nCondition: {condition}"
    prompt += "\nFollow-up question:"
    inputs = tokenizer_seq(prompt, return_tensors="pt", padding=True).to(model_seq.device)
    outputs = model_seq.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,  # sample rather than greedy-decode so repeated questions vary
        top_k=50,
        top_p=0.95,
        temperature=0.9,
        num_return_sequences=1,
    )
    return tokenizer_seq.decode(outputs[0], skip_special_tokens=True)
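# Sketch of a standalone call (hypothetical context string, assuming the
# seq2seq model was trained on this "Context:/Condition:" prompt format):
#   q = generate_followup("Tenant reports water damage in the kitchen",
#                         condition="insurance claim")
#   print(q)  # a single sampled follow-up question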
# -------------------- Streamlit UI --------------------
st.title("Tenant-Landlord Query Classifier + Claim Assistant")
if "prev_input" not in st.session_state:
st.session_state.prev_input = ""
if "context" not in st.session_state:
st.session_state.context = ""
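# Session-state keys used across Streamlit reruns:
#   prev_input - text accumulated from low-confidence classification attempts
#   context    - the confirmed query plus any follow-up answers appended later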
user_input = st.text_input("Enter your query:")
if user_input:
    with st.spinner("Classifying your query..."):
        # Accumulate earlier low-confidence input so the user can elaborate incrementally
        combined_input = st.session_state.prev_input + " " + user_input if st.session_state.prev_input else user_input
        top3_classes, top3_probs, formatted, prediction, probability = predict(combined_input)
    if probability > 60:
        st.success(f"Prediction: **{prediction}** with confidence **{probability:.2f}%**")
        st.write("Top 3 predictions:")
        for cls, prob in zip(top3_classes, formatted):
            st.write(f"- {cls}: {prob}")
        st.session_state.context = combined_input
        st.session_state.prev_input = ""  # reset the accumulated input
    else:
        st.warning("Confidence is low. Please elaborate on your query.")
        st.session_state.prev_input = combined_input
# -------------------- Ask to Make a Claim --------------------
if st.session_state.context:
    make_claim = st.radio("Do you want to make a claim?", ["Yes", "No"])
    if make_claim == "No":
        st.info("Thank you! No claim will be made.")
        st.stop()
    elif make_claim == "Yes":
        st.subheader("Claim Assistant - Answer Follow-up Questions")
        if "followup_count" not in st.session_state:
            st.session_state.followup_count = 0
        if "questions" not in st.session_state:
            st.session_state.questions = []
        if "responses" not in st.session_state:
            st.session_state.responses = []
        # Generate a new follow-up question if needed
        if len(st.session_state.questions) <= st.session_state.followup_count and st.session_state.followup_count < 5:
            with st.spinner("Generating follow-up question..."):
                new_question = generate_followup(st.session_state.context)
                st.session_state.questions.append(new_question)
        # Render follow-up questions and collect responses
        for i in range(len(st.session_state.questions)):
            st.markdown(f"**Follow-up Question {i+1}:** {st.session_state.questions[i]}")
            response_key = f"response_input_{i}"
            response = st.text_input(f"Your response to question {i+1}:", key=response_key)
            if response and len(st.session_state.responses) <= i:
                st.session_state.responses.append(response)
                st.session_state.context += " " + response
                st.session_state.followup_count += 1
        if st.session_state.followup_count >= 5:
            st.success("All follow-up questions answered. Your claim has been registered.")