### FINAL APP.PY FOR HUGGING FACE USING THE IBM GRANITE MODEL ###

import streamlit as st
import torch
import fitz  # PyMuPDF
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# For Fairness Audit
import pandas as pd
from aif360.datasets import StandardDataset
from aif360.metrics import BinaryLabelDatasetMetric

# --- Page Configuration ---
st.set_page_config(
    page_title="Sahay AI 🇮🇳",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)

# --- Caching for Performance ---
@st.cache_resource
def load_llm():
    """Loads the IBM Granite LLM, ensuring it runs on a GPU."""
    llm_model_name = "ibm-granite/granite-3.3-8b-instruct"
    
    # This check is crucial. The app will stop if no GPU is found.
    if not torch.cuda.is_available():
        raise RuntimeError("Hardware Error: This application requires a GPU to run the IBM Granite model. Please select a GPU hardware tier in your Space settings (e.g., T4 small).")

    # 4-bit quantization (bitsandbytes) keeps the 8B model within a single T4's memory.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16  # float16 compute is supported on T4-class GPUs
    )
    model = AutoModelForCausalLM.from_pretrained(
        llm_model_name,
        quantization_config=quantization_config,
        device_map="auto"  # let accelerate place the quantized model on the GPU
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.1
        # No explicit device: an accelerate-dispatched quantized model must not be moved by the pipeline.
    )
    return HuggingFacePipeline(pipeline=pipe)

@st.cache_resource
def load_and_process_pdf(pdf_path):
    """Loads and embeds the PDF using IBM's multilingual model."""
    try:
        doc = fitz.open(pdf_path)
        text = "".join(page.get_text() for page in doc)
    except Exception as e:
        st.error(f"Error reading PDF: {e}. Ensure 'PMKisanSamanNidhi.PDF' is in the main project directory.")
        return None

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.create_documents([text])
    
    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
    vector_db = FAISS.from_documents(docs, embedding_model)
    return vector_db

# --- Conversational Chain ---
def create_conversational_chain(_llm, _vector_db):
    prompt_template = """You are a polite AI assistant for the PM-KISAN scheme... (rest of prompt)"""
    QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
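    # output_key='answer' tells the memory to store only the answer, since source documents are also returned.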
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
    chain = ConversationalRetrievalChain.from_llm(
        llm=_llm, retriever=_vector_db.as_retriever(), memory=memory,
        return_source_documents=True, combine_docs_chain_kwargs={"prompt": QA_PROMPT}
    )
    return chain

# --- IBM AIF360 Fairness Audit ---
def run_fairness_audit():
    st.subheader("🤖 IBM AIF360 - Fairness Audit")
    df_display = pd.DataFrame({'gender_text': ['male', 'male', 'female', 'female']})
    df_for_aif = pd.DataFrame()
    df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
    df_for_aif['favorable_outcome'] = [1, 1, 1, 1]
    aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
                                  protected_attribute_names=['gender'], privileged_classes=[[1]])
    metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
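    # SPD = P(favorable | unprivileged) - P(favorable | privileged); values near 0 indicate parity.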
    spd = metric.statistical_parity_difference()
    st.metric(label="**Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")

# --- Main Application UI ---
if __name__ == "__main__":
    
    with st.sidebar:
        st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
        st.title("🇮🇳 Sahay AI")
        st.markdown("An AI assistant for the **PM-KISAN** scheme, built on **IBM's Granite** foundation models.")
        if st.button("Run Fairness Audit", use_container_width=True):
            st.session_state.run_audit = True

    st.header("Chat with Sahay AI 💬")

    if st.session_state.get('run_audit', False):
        run_fairness_audit()
        st.session_state.run_audit = False
    
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]

    if "qa_chain" not in st.session_state:
        with st.spinner("🚀 Waking up the IBM Granite Model... This may take several minutes on a GPU."):
            try:
                llm = load_llm()
                vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
                if vector_db is None:
                    st.stop()  # load_and_process_pdf() has already shown the PDF error
                st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
            except RuntimeError as e:
                st.error(str(e))  # Displays the "Hardware Error" message from load_llm()
                st.stop()
    
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("🧠 Thinking..."):
                result = st.session_state.qa_chain.invoke({"question": prompt})
                response = result["answer"]
                st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})