### FINAL APP.PY FOR HUGGING FACE USING THE IBM GRANITE MODEL ###

import streamlit as st
import torch
import fitz  # PyMuPDF
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# For Fairness Audit
import pandas as pd
from aif360.datasets import StandardDataset
from aif360.metrics import BinaryLabelDatasetMetric

# --- Page Configuration ---
st.set_page_config(
    page_title="Sahay AI 🇮🇳",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)

# --- Caching for Performance ---
@st.cache_resource
def load_llm():
    """Loads the IBM Granite LLM, ensuring it runs on a GPU."""
    llm_model_name = "ibm-granite/granite-3.3-8b-instruct"
    # This check is crucial. The app will stop if no GPU is found.
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Hardware Error: This application requires a GPU to run the IBM Granite model. "
            "Please select a GPU hardware tier in your Space settings (e.g., T4 small)."
        )
    model = AutoModelForCausalLM.from_pretrained(
        llm_model_name,
        torch_dtype=torch.bfloat16,
        load_in_4bit=True  # 4-bit quantization to save memory
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        temperature=0.1,
        device=0  # Force the pipeline to use the first available GPU
    )
    return HuggingFacePipeline(pipeline=pipe)
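
# NOTE (alternative sketch, not part of the original app): on recent transformers
# releases the bare `load_in_4bit=True` kwarg is deprecated in favour of an explicit
# BitsAndBytesConfig. Assuming `bitsandbytes` and `accelerate` are installed, the same
# 4-bit load can be written as:
#
#     from transformers import BitsAndBytesConfig
#     bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
#     model = AutoModelForCausalLM.from_pretrained(
#         llm_model_name,
#         quantization_config=bnb_config,
#         device_map="auto",  # let accelerate place the layers on the GPU
#     )
#
# When device_map="auto" is used, the `device=0` argument to pipeline() should
# typically be omitted, since the model is already placed by accelerate.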

@st.cache_resource
def load_and_process_pdf(pdf_path):
    """Loads and embeds the PDF using IBM's multilingual model."""
    try:
        doc = fitz.open(pdf_path)
        text = "".join(page.get_text() for page in doc)
    except Exception as e:
        st.error(f"Error reading PDF: {e}. Ensure 'PMKisanSamanNidhi.PDF' is in the main project directory.")
        return None
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.create_documents([text])
    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
    vector_db = FAISS.from_documents(docs, embedding_model)
    return vector_db
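
# Quick sanity check (hypothetical snippet, not wired into the UI): the FAISS index
# returned above can be queried directly to confirm the PDF was embedded, e.g.
#
#     db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
#     hits = db.similarity_search("Who is eligible for PM-KISAN benefits?", k=3)
#     for hit in hits:
#         print(hit.page_content[:80])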

# --- Conversational Chain ---
def create_conversational_chain(_llm, _vector_db):
    # The (truncated) prompt must still expose the {context} and {question}
    # placeholders declared in input_variables below.
    prompt_template = """You are a polite AI assistant for the PM-KISAN scheme... (rest of prompt)

Context: {context}

Question: {question}

Answer:"""
    QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
    chain = ConversationalRetrievalChain.from_llm(
        llm=_llm, retriever=_vector_db.as_retriever(), memory=memory,
        return_source_documents=True, combine_docs_chain_kwargs={"prompt": QA_PROMPT}
    )
    return chain
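
# With return_source_documents=True and output_key='answer', each call returns a dict
# holding the generated answer plus the retrieved chunks. A minimal usage sketch
# (the question text is made up for illustration):
#
#     chain = create_conversational_chain(llm, vector_db)
#     result = chain.invoke({"question": "How much does PM-KISAN pay per year?"})
#     print(result["answer"])
#     print(len(result["source_documents"]), "chunks retrieved")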

# --- IBM AIF360 Fairness Audit ---
def run_fairness_audit():
    st.subheader("🤖 IBM AIF360 - Fairness Audit")
    df_display = pd.DataFrame({'gender_text': ['male', 'male', 'female', 'female']})
    df_for_aif = pd.DataFrame()
    df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
    df_for_aif['favorable_outcome'] = [1, 1, 1, 1]
    aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
                                  protected_attribute_names=['gender'], privileged_classes=[[1]])
    metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
    spd = metric.statistical_parity_difference()
    st.metric(label="**Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")
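
# For reference: AIF360 computes SPD as
#     SPD = P(favorable outcome | unprivileged group) - P(favorable outcome | privileged group),
# so 0.0 indicates parity and negative values mean the unprivileged group receives the
# favorable outcome less often. With the toy data above (every row labelled 1 for both
# genders) the expected value is 1.0 - 1.0 = 0.0.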

# --- Main Application UI ---
if __name__ == "__main__":
    with st.sidebar:
        st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
        st.title("🇮🇳 Sahay AI")
        st.markdown("An AI assistant for the **PM-KISAN** scheme, built on **IBM's Granite** foundation models.")
        if st.button("Run Fairness Audit", use_container_width=True):
            st.session_state.run_audit = True

    st.header("Chat with Sahay AI 💬")

    if st.session_state.get('run_audit', False):
        run_fairness_audit()
        st.session_state.run_audit = False

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]

    if "qa_chain" not in st.session_state:
        with st.spinner("🚀 Waking up the IBM Granite Model... This may take several minutes on a GPU."):
            try:
                llm = load_llm()
                vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
                if vector_db is None:  # PDF could not be read; the error was already shown above
                    st.stop()
                st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
            except RuntimeError as e:
                st.error(str(e))  # Displays the "Hardware Error" message from load_llm()
                st.stop()

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("🧠 Thinking..."):
                result = st.session_state.qa_chain.invoke({"question": prompt})
                response = result["answer"]
                st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})