import streamlit as st
import torch
import pickle
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM
import asyncio
import sys

# On Windows, the default Proactor event loop can conflict with libraries
# that expect a selector loop, so switch policies up front.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


@st.cache_resource
def load_prediction_model():
    # Cached so the tokenizer, label encoder, and model load once per session,
    # not on every Streamlit rerun.
    # Hugging Face Hub models are referenced by repo id, not by full URL.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    with open('src/label_encoder.pkl', 'rb') as f:
        label_encoder = pickle.load(f)
    # Map model output indices back to human-readable class names.
    id_to_class = {idx: class_name for idx, class_name in enumerate(label_encoder.classes_)}
    model = BertForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizerv2')
    return tokenizer, model, id_to_class

tokenizer_cls, model_cls, id_to_class = load_prediction_model()
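
# For reference, id_to_class maps output indices to label strings, e.g.
# (hypothetical labels -- the real ones come from src/label_encoder.pkl):
#   {0: 'deposit_dispute', 1: 'eviction', 2: 'repairs_and_maintenance', ...}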

def preprocess_texts(texts):
    # Tokenize to fixed-length (128-token) padded tensors for the classifier.
    return tokenizer_cls(texts, padding='max_length', truncation=True, max_length=128, return_tensors='pt')


def predict(text):
    inputs = preprocess_texts(text)
    with torch.no_grad():
        outputs = model_cls(**inputs)
    # Turn logits into probabilities and keep the three most likely classes.
    probabilities = F.softmax(outputs.logits, dim=1)
    top3_probs, top3_classes = torch.topk(probabilities, k=3, dim=1)

    top3_class_names = [id_to_class[idx.item()] for idx in top3_classes[0]]
    top3_probs = top3_probs[0] * 100  # convert to percentages
    top3_probs_np = top3_probs.cpu().numpy()
    formatted_percentages = [f"{p:.4f}%" for p in top3_probs_np]

    prediction = top3_class_names[0]
    probability = top3_probs_np[0]

    return top3_class_names, top3_probs_np, formatted_percentages, prediction, probability
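
# Illustrative usage (class names are hypothetical; the real labels come from
# the pickled label encoder):
#
#   top3, probs, fmt, pred, conf = predict("My landlord won't return my deposit")
#   # top3 -> ['deposit_dispute', 'repairs_and_maintenance', 'eviction']
#   # pred -> 'deposit_dispute', conf -> 81.2345 (a percentage)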


@st.cache_resource
def load_followup_model():
    # Hugging Face Hub models are referenced by repo id, not by full URL.
    model_path = "Divyanshu04/Insurance_claim_followup_model"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    model.eval()
    return tokenizer, model

tokenizer_seq, model_seq = load_followup_model()

def generate_followup(context, condition=None, max_tokens=64):
    # Build a simple instruction-style prompt for the seq2seq model.
    prompt = f"Context: {context}"
    if condition:
        prompt += f"\nCondition: {condition}"
    prompt += "\nFollow-up question:"
    inputs = tokenizer_seq(prompt, return_tensors="pt", padding=True).to(model_seq.device)

    # Sample rather than decode greedily, so repeated calls on the same
    # context can yield different follow-up questions.
    outputs = model_seq.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.9,
        num_return_sequences=1,
    )
    return tokenizer_seq.decode(outputs[0], skip_special_tokens=True)
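
# Illustrative usage (output is sampled, so the exact question varies):
#
#   q = generate_followup("water leaking through the kitchen ceiling")
#   # q -> something like "When did you first notice the leak?"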


st.title("Tenant-Landlord Query Classifier + Claim Assistant")

# prev_input accumulates low-confidence queries so the user can add detail;
# context holds the accepted query plus any follow-up answers.
if "prev_input" not in st.session_state:
    st.session_state.prev_input = ""
if "context" not in st.session_state:
    st.session_state.context = ""

user_input = st.text_input("Enter your query:")

if user_input:
    with st.spinner("Classifying your query..."):
        # If an earlier attempt had low confidence, classify the combined text.
        combined_input = st.session_state.prev_input + " " + user_input if st.session_state.prev_input else user_input
        top3_classes, top3_probs, formatted, prediction, probability = predict(combined_input)

    # Accept the prediction only above a 60% confidence threshold.
    if probability > 60:
        st.success(f"Prediction: **{prediction}** with confidence **{probability:.2f}%**")
        st.write("Top 3 predictions:")
        for cls, prob in zip(top3_classes, formatted):
            st.write(f"- {cls}: {prob}")
        st.session_state.context = combined_input
        st.session_state.prev_input = ""
    else:
        st.warning("Confidence is low. Please add more detail to your query.")
        st.session_state.prev_input = combined_input
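
# Example flow (hypothetical values): a terse query like "deposit" may score
# below the 60% threshold and trigger the warning; a follow-up entry such as
# "landlord kept my deposit after I moved out" is then classified together
# with it via prev_input.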


if st.session_state.context:
    make_claim = st.radio("Do you want to make a claim?", ["Yes", "No"])

    if make_claim == "No":
        st.info("Thank you! No claim will be made.")
        st.stop()

    elif make_claim == "Yes":
        st.subheader("Claim Assistant - Answer Follow-up Questions")

        if "followup_count" not in st.session_state:
            st.session_state.followup_count = 0
        if "questions" not in st.session_state:
            st.session_state.questions = []
        if "responses" not in st.session_state:
            st.session_state.responses = []

        # Generate the next question only when all earlier ones are answered
        # and fewer than five have been asked.
        if len(st.session_state.questions) <= st.session_state.followup_count and st.session_state.followup_count < 5:
            with st.spinner("Generating follow-up question..."):
                new_question = generate_followup(st.session_state.context)
                st.session_state.questions.append(new_question)

        for i in range(len(st.session_state.questions)):
            st.markdown(f"**Follow-up Question {i+1}:** {st.session_state.questions[i]}")
            response_key = f"response_input_{i}"
            response = st.text_input(f"Your response to question {i+1}:", key=response_key)

            # Record each answer once, folding it into the running context so
            # later questions can build on it.
            if response and len(st.session_state.responses) <= i:
                st.session_state.responses.append(response)
                st.session_state.context += " " + response
                st.session_state.followup_count += 1

        if st.session_state.followup_count >= 5:
            st.success("All follow-up questions answered. Your claim has been registered.")
|
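
# To run this app locally (assuming the file is saved as app.py and
# src/label_encoder.pkl exists relative to the working directory):
#
#   streamlit run app.py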