import streamlit as st
import torch
import os
import pickle
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM
import asyncio
import sys

# On Windows, the default Proactor event loop can raise "Event loop is closed"
# errors under Streamlit's script reruns; fall back to the selector policy.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


# from huggingface_hub import login

# token = os.getenv("HF_TOKEN")  # Read from environment
# if token:
#     login(token=token)
# else:
#     raise ValueError("HF_TOKEN environment variable is not set.")
# -------------------- Part 1: Prediction --------------------
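# Fine-tuned BERT classifier: maps a tenant/landlord query to one of the trained
# issue categories and reports the top-3 class probabilities.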

@st.cache_resource
def load_prediction_model():
    # from_pretrained expects a Hub repo id (or local path), not a full URL.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    with open('src/label_encoder.pkl', 'rb') as f:
        label_encoder = pickle.load(f)
    id_to_class = {idx: class_name for idx, class_name in enumerate(label_encoder.classes_)}

    model = BertForSequenceClassification.from_pretrained('Divyanshu04/Issue_categorizerv2')
    model.eval()  # inference only; ensures dropout is disabled
    return tokenizer, model, id_to_class

tokenizer_cls, model_cls, id_to_class = load_prediction_model()

def preprocess_texts(texts):
    return tokenizer_cls(texts, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
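
# preprocess_texts returns a dict of tensors (input_ids, attention_mask, ...),
# each shaped (batch, 128), ready to unpack into the classifier with **inputs.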

def predict(text):
    inputs = preprocess_texts(text)
    with torch.no_grad():
        outputs = model_cls(**inputs)
        probabilities = F.softmax(outputs.logits, dim=1)  # logits -> per-class probabilities
        top3_probs, top3_classes = torch.topk(probabilities, k=3, dim=1)  # three most likely classes

    top3_class_names = [id_to_class[idx.item()] for idx in top3_classes[0]]
    top3_probs = top3_probs[0] * 100
    top3_probs_np = top3_probs.cpu().numpy()
    formatted_percentages = [f"{p:.4f}%" for p in top3_probs_np]

    prediction = top3_class_names[0]
    probability = top3_probs_np[0]

    return top3_class_names, top3_probs_np, formatted_percentages, prediction, probability
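
# Illustrative usage (values are hypothetical; actual output depends on the model):
#   classes, probs, formatted, best, conf = predict("My landlord won't fix the heating")
#   best -> e.g. "Repairs"
#   conf -> e.g. 87.31  (a percentage)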

# -------------------- Part 2: Follow-up Generator --------------------
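# Seq2seq model that, given the accumulated claim context, generates the next
# follow-up question to ask the user.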

@st.cache_resource
def load_followup_model():
    model_path = "Divyanshu04/Insurance_claim_followup_model"  # Hub repo id (or local path); adjust as needed
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    model.eval()
    return tokenizer, model

tokenizer_seq, model_seq = load_followup_model()

def generate_followup(context, condition=None, max_tokens=64):
    prompt = f"Context: {context}"
    if condition:
        prompt += f"\nCondition: {condition}"
    prompt += "\nFollow-up question:"
    inputs = tokenizer_seq(prompt, return_tensors="pt", padding=True).to(model_seq.device)

    # Sampled decoding (top-k/top-p with temperature) yields varied questions;
    # set do_sample=False for deterministic output.
    outputs = model_seq.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.9,
        num_return_sequences=1
    )
    return tokenizer_seq.decode(outputs[0], skip_special_tokens=True)
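
# Illustrative usage (sampled decoding, so output varies per run):
#   generate_followup("Tenant reports water damage in the kitchen ceiling")
#   -> e.g. "When did you first notice the damage?"  (hypothetical output)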

# -------------------- Streamlit UI --------------------
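# Streamlit reruns this script top-to-bottom on every widget interaction, so any
# state that must survive a rerun lives in st.session_state.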

st.title("Tenant-Landlord Query Classifier + Claim Assistant")

if "prev_input" not in st.session_state:
    st.session_state.prev_input = ""

if "context" not in st.session_state:
    st.session_state.context = ""
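
# prev_input accumulates low-confidence queries across reruns; context stores the
# accepted query plus follow-up answers and drives the claim-assistant flow below.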

user_input = st.text_input("Enter your query:")

if user_input:
    with st.spinner("Classifying your query..."):
        combined_input = st.session_state.prev_input + " " + user_input if st.session_state.prev_input else user_input
        top3_classes, top3_probs, formatted, prediction, probability = predict(combined_input)

        if probability > 60:  # confidence threshold, in percent
            st.success(f"Prediction: **{prediction}** with confidence **{probability:.2f}%**")
            st.write("Top 3 predictions:")
            for cls, prob in zip(top3_classes, formatted):
                st.write(f"- {cls}: {prob}")
            st.session_state.context = combined_input
            st.session_state.prev_input = ""  # Reset
        else:
            st.warning("Confidence is low. Please provide more detail about your query.")
            st.session_state.prev_input = combined_input

# -------------------- Ask to Make a Claim --------------------

if st.session_state.context:
    make_claim = st.radio("Do you want to make a claim?", ["Yes", "No"])

    if make_claim == "No":
        st.info("Thank you! No claim will be made.")
        st.stop()

    elif make_claim == "Yes":
        st.subheader("Claim Assistant - Answer Follow-up Questions")

        if "followup_count" not in st.session_state:
            st.session_state.followup_count = 0
        if "questions" not in st.session_state:
            st.session_state.questions = []
        if "responses" not in st.session_state:
            st.session_state.responses = []

        # Generate the next follow-up only once the previous one has been answered,
        # capped at five questions.
        if len(st.session_state.questions) <= st.session_state.followup_count and st.session_state.followup_count < 5:
            with st.spinner("Generating follow-up question..."):
                new_question = generate_followup(st.session_state.context)
            st.session_state.questions.append(new_question)

        # Render follow-up questions and collect responses
        for i in range(len(st.session_state.questions)):
            st.markdown(f"**Follow-up Question {i+1}:** {st.session_state.questions[i]}")
            response_key = f"response_input_{i}"
            response = st.text_input(f"Your response to question {i+1}:", key=response_key)

            # Record each answer exactly once, even though Streamlit reruns the
            # whole script on every input submission.
            if response and len(st.session_state.responses) <= i:
                st.session_state.responses.append(response)
                st.session_state.context += " " + response
                st.session_state.followup_count += 1

        if st.session_state.followup_count >= 5:
            st.success("All follow-up questions answered. Your claim has been registered.")