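"""Tenant-Landlord Query Classifier + Claim Assistant.

A Streamlit app that (1) classifies a tenant/landlord query with a
fine-tuned BERT sequence classifier and (2) walks the user through an
insurance-claim flow by generating follow-up questions with a seq2seq model.
"""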
import streamlit as st
import torch
import pickle
import torch.nn.functional as F
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
)
import asyncio
import sys

# On Windows, the default Proactor event loop can break Streamlit/torch
# interop; fall back to the selector-based policy.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# -------------------- Part 1: Prediction --------------------
@st.cache_resource  # load tokenizer/model once instead of on every Streamlit rerun
def load_prediction_model():
    # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    # The label encoder maps class indices back to human-readable category names.
    with open('label_encoder_new.pkl', 'rb') as f:
        label_encoder = pickle.load(f)
    id_to_class = {idx: class_name for idx, class_name in enumerate(label_encoder.classes_)}

    # model = BertForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizer', num_labels=len(label_encoder.classes_))
    model = AutoModelForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizer')
    # model.load_state_dict(torch.load('Divyanshu04/Issue_categorizer', map_location=torch.device('cpu'))['model_state_dict'])
    model.eval()
    return tokenizer, model, id_to_class

tokenizer_cls, model_cls, id_to_class = load_prediction_model()
def preprocess_texts(texts):
    return tokenizer_cls(texts, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
def predict(text):
    inputs = preprocess_texts(text)
    with torch.no_grad():
        outputs = model_cls(**inputs)
    # Softmax over the logits, then keep the three most likely classes.
    probabilities = F.softmax(outputs.logits, dim=1)
    top3_probs, top3_classes = torch.topk(probabilities, k=3, dim=1)
    top3_class_names = [id_to_class[idx.item()] for idx in top3_classes[0]]
    top3_probs = top3_probs[0] * 100  # convert to percentages
    top3_probs_np = top3_probs.cpu().numpy()
    formatted_percentages = [f"{p:.4f}%" for p in top3_probs_np]
    prediction = top3_class_names[0]
    probability = top3_probs_np[0]
    return top3_class_names, top3_probs_np, formatted_percentages, prediction, probability
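# Illustrative usage only -- the query and category below are hypothetical;
# the real labels come from label_encoder_new.pkl:
#   classes, probs, formatted, pred, conf = predict("My landlord won't fix the heating")
#   # pred -> e.g. "Repairs", conf -> e.g. 92.3 (a percentage)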
# -------------------- Part 2: Follow-up Generator --------------------
@st.cache_resource  # cache the seq2seq model across reruns as well
def load_followup_model():
    model_path = "Divyanshu04/Insurance_claim_followup_model"  # Adjust path as needed
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    model.eval()
    return tokenizer, model

tokenizer_seq, model_seq = load_followup_model()
def generate_followup(context, condition=None, max_tokens=64):
    # Assemble the prompt in the Context/Condition/Follow-up format the model expects.
    prompt = f"Context: {context}"
    if condition:
        prompt += f"\nCondition: {condition}"
    prompt += "\nFollow-up question:"
    inputs = tokenizer_seq(prompt, return_tensors="pt", padding=True).to(model_seq.device)
    outputs = model_seq.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,  # sample so repeated calls yield varied questions
        top_k=50,
        top_p=0.95,
        temperature=0.9,
        num_return_sequences=1,
    )
    return tokenizer_seq.decode(outputs[0], skip_special_tokens=True)
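# Design note: top-k/top-p sampling keeps the generated questions varied
# between reruns; for reproducible output one could instead pass
# do_sample=False to generate() for greedy decoding.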
# -------------------- Streamlit UI --------------------
st.title("Tenant-Landlord Query Classifier + Claim Assistant")

# Session state: prev_input accumulates low-confidence queries across reruns;
# context holds the confirmed query that seeds the claim flow.
if "prev_input" not in st.session_state:
    st.session_state.prev_input = ""
if "context" not in st.session_state:
    st.session_state.context = ""

user_input = st.text_input("Enter your query:")

if user_input:
    with st.spinner("Classifying your query..."):
        # Prepend anything the user typed in earlier low-confidence rounds.
        combined_input = (
            st.session_state.prev_input + " " + user_input
            if st.session_state.prev_input
            else user_input
        )
        top3_classes, top3_probs, formatted, prediction, probability = predict(combined_input)

    if probability > 60:  # `probability` is a percentage (0-100)
        st.success(f"Prediction: **{prediction}** with confidence **{probability:.2f}%**")
        st.write("Top 3 predictions:")
        for cls, prob in zip(top3_classes, formatted):
            st.write(f"- {cls}: {prob}")
        st.session_state.context = combined_input
        st.session_state.prev_input = ""  # reset the accumulator
    else:
        st.warning("Confidence is low. Please elaborate on your query.")
        st.session_state.prev_input = combined_input
# -------------------- Ask to Make a Claim --------------------
if st.session_state.context:
    make_claim = st.radio("Do you want to make a claim?", ["Yes", "No"])

    if make_claim == "No":
        st.info("Thank you! No claim will be made.")
        st.stop()
    elif make_claim == "Yes":
        st.subheader("Claim Assistant - Answer Follow-up Questions")

        if "followup_count" not in st.session_state:
            st.session_state.followup_count = 0
        if "questions" not in st.session_state:
            st.session_state.questions = []
        if "responses" not in st.session_state:
            st.session_state.responses = []

        # Generate a new follow-up question once the previous one is answered,
        # up to a maximum of five questions.
        if len(st.session_state.questions) <= st.session_state.followup_count and st.session_state.followup_count < 5:
            with st.spinner("Generating follow-up question..."):
                new_question = generate_followup(st.session_state.context)
                st.session_state.questions.append(new_question)

        # Render all follow-up questions so far and collect the answers.
        for i in range(len(st.session_state.questions)):
            st.markdown(f"**Follow-up Question {i+1}:** {st.session_state.questions[i]}")
            response_key = f"response_input_{i}"
            response = st.text_input(f"Your response to question {i+1}:", key=response_key)
            # Record each answer once and fold it back into the context so the
            # next generated question can build on it.
            if response and len(st.session_state.responses) <= i:
                st.session_state.responses.append(response)
                st.session_state.context += " " + response
                st.session_state.followup_count += 1

        if st.session_state.followup_count >= 5:
            st.success("All follow-up questions answered. Your claim has been registered.")