import streamlit as st
import torch
import os
import pickle
import torch.nn.functional as F
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
)
import asyncio
import sys

# Streamlit on Windows can hit asyncio errors with the default Proactor loop;
# fall back to the selector event loop.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# -------------------- Part 1: Prediction --------------------
@st.cache_resource
def load_prediction_model():
    # NOTE: the tokenizer is loaded from the base checkpoint, which assumes the
    # fine-tuned classifier kept the stock bert-base-uncased vocabulary.
    # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    # Label encoder saved at training time; maps class indices to class names.
    with open('label_encoder_new.pkl', 'rb') as f:
        label_encoder = pickle.load(f)
    id_to_class = {idx: class_name for idx, class_name in enumerate(label_encoder.classes_)}

    # model = BertForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizer', num_labels=len(label_encoder.classes_))
    model = AutoModelForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizer')
    # model.load_state_dict(torch.load('Divyanshu04/Issue_categorizer', map_location=torch.device('cpu'))['model_state_dict'])
    model.eval()
    return tokenizer, model, id_to_class

tokenizer_cls, model_cls, id_to_class = load_prediction_model()

def preprocess_texts(texts):
    return tokenizer_cls(texts, padding='max_length', truncation=True,
                         max_length=128, return_tensors='pt')

def predict(text):
    """Classify the text and return the top-3 classes with probabilities (in percent)."""
    inputs = preprocess_texts(text)
    with torch.no_grad():
        outputs = model_cls(**inputs)
    probabilities = F.softmax(outputs.logits, dim=1)
    top3_probs, top3_classes = torch.topk(probabilities, k=3, dim=1)
    top3_class_names = [id_to_class[idx.item()] for idx in top3_classes[0]]
    top3_probs_np = (top3_probs[0] * 100).cpu().numpy()
    formatted_percentages = [f"{p:.4f}%" for p in top3_probs_np]
    prediction = top3_class_names[0]
    probability = top3_probs_np[0]
    return top3_class_names, top3_probs_np, formatted_percentages, prediction, probability

# -------------------- Part 2: Follow-up Generator --------------------
@st.cache_resource
def load_followup_model():
    model_path = "Divyanshu04/Insurance_claim_followup_model"  # Adjust path as needed
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    model.eval()
    return tokenizer, model

tokenizer_seq, model_seq = load_followup_model()

def generate_followup(context, condition=None, max_tokens=64):
    """Generate one follow-up question from the accumulated claim context."""
    prompt = f"Context: {context}"
    if condition:
        prompt += f"\nCondition: {condition}"
    prompt += "\nFollow-up question:"
    inputs = tokenizer_seq(prompt, return_tensors="pt", padding=True).to(model_seq.device)
    outputs = model_seq.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.9,
        num_return_sequences=1,
    )
    return tokenizer_seq.decode(outputs[0], skip_special_tokens=True)

# -------------------- Streamlit UI --------------------
st.title("Tenant-Landlord Query Classifier + Claim Assistant")

if "prev_input" not in st.session_state:
    st.session_state.prev_input = ""
if "context" not in st.session_state:
    st.session_state.context = ""

user_input = st.text_input("Enter your query:")

if user_input:
    with st.spinner("Classifying your query..."):
        # If the previous attempt was low-confidence, prepend it so the user
        # can elaborate rather than start over.
        combined_input = (st.session_state.prev_input + " " + user_input
                          if st.session_state.prev_input else user_input)
        top3_classes, top3_probs, formatted, prediction, probability = predict(combined_input)

    if probability > 60:  # Accept the prediction only above 60% confidence
        st.success(f"Prediction: **{prediction}** with confidence **{probability:.2f}%**")
        st.write("Top 3 predictions:")
        for cls, prob in zip(top3_classes, formatted):
            st.write(f"- {cls}: {prob}")
        st.session_state.context = combined_input
        st.session_state.prev_input = ""  # Reset
    else:
        st.warning("Confidence is low. Please elaborate on your query.")
        st.session_state.prev_input = combined_input

# -------------------- Ask to Make a Claim --------------------
if st.session_state.context:
    make_claim = st.radio("Do you want to make a claim?", ["Yes", "No"])

    if make_claim == "No":
        st.info("Thank you! No claim will be made.")
        st.stop()
    elif make_claim == "Yes":
        st.subheader("Claim Assistant - Answer Follow-up Questions")

        if "followup_count" not in st.session_state:
            st.session_state.followup_count = 0
        if "questions" not in st.session_state:
            st.session_state.questions = []
        if "responses" not in st.session_state:
            st.session_state.responses = []

        # Generate a new follow-up question if needed (up to 5 total)
        if (len(st.session_state.questions) <= st.session_state.followup_count
                and st.session_state.followup_count < 5):
            with st.spinner("Generating follow-up question..."):
                new_question = generate_followup(st.session_state.context)
                st.session_state.questions.append(new_question)

        # Render follow-up questions and collect responses
        for i in range(len(st.session_state.questions)):
            st.markdown(f"**Follow-up Question {i+1}:** {st.session_state.questions[i]}")
            response_key = f"response_input_{i}"
            response = st.text_input(f"Your response to question {i+1}:", key=response_key)

            # Record each answer once and fold it into the claim context
            if response and len(st.session_state.responses) <= i:
                st.session_state.responses.append(response)
                st.session_state.context += " " + response
                st.session_state.followup_count += 1

        if st.session_state.followup_count >= 5:
            st.success("All follow-up questions answered. Your claim has been registered.")
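# ------------------------------------------------------------------
# Usage sketch: the app expects 'label_encoder_new.pkl' next to this script.
# A minimal example of how such a file could be produced at training time,
# assuming a scikit-learn LabelEncoder was used (the labels below are
# hypothetical placeholders, not the model's real classes):
#
#     from sklearn.preprocessing import LabelEncoder
#     import pickle
#
#     labels = ["plumbing", "electrical", "pest_control"]  # hypothetical
#     encoder = LabelEncoder().fit(labels)
#     with open("label_encoder_new.pkl", "wb") as f:
#         pickle.dump(encoder, f)
#
# To launch the app (assuming this file is saved as app.py):
#     streamlit run app.py
# ------------------------------------------------------------------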