frozen8569 commited on
Commit
fdc9e3b
·
verified ·
1 Parent(s): 38a08f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -115
app.py CHANGED
@@ -1,7 +1,9 @@
 
 
1
  import streamlit as st
2
  import torch
3
  import fitz # PyMuPDF
4
- from transformers import AutoTokenizer, pipeline, AutoModelForSeq2SeqLM # Import for T5 model
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain_community.vectorstores import FAISS
7
  from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -26,115 +28,70 @@ st.set_page_config(
26
  # --- Caching for Performance ---
27
  @st.cache_resource
28
  def load_llm():
29
- """
30
- Loads a smaller, CPU-friendly model (FLAN-T5-Base) for better performance
31
- on the free Hugging Face Spaces hardware.
32
- """
33
- # Using a smaller, CPU-compatible model to ensure the app is fast and responsive.
34
- llm_model_name = "google/flan-t5-base"
35
 
 
 
 
 
 
 
 
 
 
36
  tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
37
- # Use AutoModelForSeq2SeqLM for T5 models
38
- model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
39
 
40
  pipe = pipeline(
41
- "text2text-generation", # T5 models use this pipeline type
42
  model=model,
43
  tokenizer=tokenizer,
44
- max_length=512
 
 
45
  )
46
  return HuggingFacePipeline(pipeline=pipe)
47
 
48
  @st.cache_resource
49
  def load_and_process_pdf(pdf_path):
50
- """Loads, chunks, and embeds the PDF into a FAISS vector store using IBM's model."""
51
  try:
52
  doc = fitz.open(pdf_path)
53
  text = "".join(page.get_text() for page in doc)
54
- if not text:
55
- st.error("Could not extract text from the PDF.")
56
- return None
57
  except Exception as e:
58
- st.error(f"Error reading PDF file: {e}. Make sure 'PMKisanSamanNidhi.PDF' is uploaded to the Space.")
59
  return None
60
 
61
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
62
  docs = text_splitter.create_documents([text])
63
 
64
- # Still using the powerful IBM embedding model for multilingual understanding
65
- model_name = "ibm-granite/granite-embedding-278m-multilingual"
66
- embedding_model = HuggingFaceEmbeddings(model_name=model_name)
67
-
68
  vector_db = FAISS.from_documents(docs, embedding_model)
69
  return vector_db
70
 
71
  # --- Conversational Chain ---
72
  def create_conversational_chain(_llm, _vector_db):
73
- """Creates the LangChain conversational retrieval chain."""
74
- prompt_template = """You are a polite and professional AI assistant for the PM-KISAN scheme.
75
- Use the following context to answer the user's question precisely.
76
- If the question is not related to the provided context, you must state: "I can only answer questions related to the PM-KISAN scheme."
77
- Do not make up information.
78
-
79
- Context: {context}
80
- Question: {question}
81
-
82
- Helpful Answer:"""
83
-
84
  QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
85
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
86
-
87
  chain = ConversationalRetrievalChain.from_llm(
88
- llm=_llm,
89
- retriever=_vector_db.as_retriever(search_kwargs={'k': 3}),
90
- memory=memory,
91
- return_source_documents=True,
92
- combine_docs_chain_kwargs={"prompt": QA_PROMPT}
93
  )
94
  return chain
95
 
96
  # --- IBM AIF360 Fairness Audit ---
97
  def run_fairness_audit():
98
- """Performs and displays a simulated fairness audit."""
99
  st.subheader("🤖 IBM AIF360 - Fairness Audit")
100
- st.info("""
101
- This is a simulation to demonstrate how we can check for bias in our information retriever.
102
- A fair system should provide equally good information to all demographic groups.
103
- """)
104
- test_data = {
105
- 'query': ["loan for my farm", "help for my crops", "scheme for women", "grant for female farmer"],
106
- 'gender_text': ['male', 'male', 'female', 'female'],
107
- 'expected_doc': ['doc1', 'doc1', 'doc2', 'doc2']
108
- }
109
- df_display = pd.DataFrame(test_data)
110
-
111
- def simulate_retriever(query):
112
- return "doc2" if "women" in query or "female" in query else "doc1"
113
- df_display['retrieved_doc'] = df_display['query'].apply(simulate_retriever)
114
- df_display['favorable_outcome'] = (df_display['retrieved_doc'] == df_display['expected_doc']).astype(int)
115
-
116
  df_for_aif = pd.DataFrame()
117
  df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
118
- df_for_aif['favorable_outcome'] = df_display['favorable_outcome']
119
-
120
- aif_dataset = StandardDataset(df_for_aif,
121
- label_name='favorable_outcome',
122
- favorable_classes=[1],
123
- protected_attribute_names=['gender'],
124
- privileged_classes=[[1]])
125
-
126
  metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
127
  spd = metric.statistical_parity_difference()
128
-
129
- st.markdown("---")
130
- col1, col2 = st.columns(2)
131
- with col1:
132
- st.metric(label="**Metric: Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")
133
- with col2:
134
- st.success("An SPD of **0.0** indicates perfect fairness in this simulation.")
135
-
136
- with st.expander("Show Raw Audit Data"):
137
- st.dataframe(df_display)
138
 
139
  # --- Main Application UI ---
140
  if __name__ == "__main__":
@@ -142,69 +99,40 @@ if __name__ == "__main__":
142
  with st.sidebar:
143
  st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
144
  st.title("🇮🇳 Sahay AI")
145
- st.markdown("### About")
146
- st.markdown("An AI assistant for the **PM-KISAN** scheme, built with IBM's multilingual embedding model.")
147
- st.markdown("---")
148
-
149
- st.markdown("### Actions")
150
  if st.button("Run Fairness Audit", use_container_width=True):
151
  st.session_state.run_audit = True
152
- st.markdown("---")
153
-
154
- st.markdown("### Connect")
155
- st.markdown("📱 [Try the WhatsApp Bot](https://wa.me/15551234567?text=Hello%20Sahay%20AI!)") # Replace with your number
156
- st.markdown("⭐ [View Project on GitHub](https://github.com)")
157
- st.markdown("---")
158
 
159
  st.header("Chat with Sahay AI 💬")
160
- st.markdown("Your trusted guide to the PM-KISAN scheme.")
161
 
162
  if st.session_state.get('run_audit', False):
163
- run_fairness_audit()
164
  st.session_state.run_audit = False
165
 
166
  if "messages" not in st.session_state:
167
- st.session_state.messages = []
168
- st.session_state.messages.append({
169
- "role": "assistant",
170
- "content": "Welcome! How can I help you understand the PM-KISAN scheme today? You can ask me questions like:\n- What is this scheme about?\n- Who is eligible?\n- *इस योजना के लिए कौन पात्र है?*"
171
- })
172
 
173
  if "qa_chain" not in st.session_state:
174
- with st.spinner("🚀 Initializing Sahay AI... This may take a moment."):
175
- llm = load_llm()
176
- vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
177
- if vector_db:
178
  st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
179
- else:
180
- st.error("Application could not start. Please check the PDF file is uploaded correctly.")
181
  st.stop()
182
-
183
  for message in st.session_state.messages:
184
  with st.chat_message(message["role"]):
185
  st.markdown(message["content"])
186
 
187
- if prompt := st.chat_input("Ask a question about the PM-KISAN scheme..."):
188
  st.session_state.messages.append({"role": "user", "content": prompt})
189
  with st.chat_message("user"):
190
  st.markdown(prompt)
191
-
192
  with st.chat_message("assistant"):
193
  with st.spinner("🧠 Thinking..."):
194
- if "qa_chain" in st.session_state:
195
- result = st.session_state.qa_chain.invoke({"question": prompt})
196
- response = result["answer"]
197
- source_docs = result.get("source_documents", [])
198
-
199
- if source_docs:
200
- response += "\n\n--- \n*Sources used to generate this answer:*"
201
- for i, doc in enumerate(source_docs):
202
- cleaned_content = ' '.join(doc.page_content.split())
203
- response += f"\n\n> **Source {i+1}:** \"{cleaned_content[:150]}...\""
204
-
205
- st.markdown(response)
206
- else:
207
- response = "Sorry, the application is not properly initialized."
208
- st.error(response)
209
-
210
  st.session_state.messages.append({"role": "assistant", "content": response})
 
1
+ ### FINAL APP.PY FOR HUGGING FACE USING THE IBM GRANITE MODEL ###
2
+
3
  import streamlit as st
4
  import torch
5
  import fitz # PyMuPDF
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain_community.vectorstores import FAISS
9
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
28
  # --- Caching for Performance ---
29
  @st.cache_resource
30
  def load_llm():
31
+ """Loads the IBM Granite LLM, ensuring it runs on a GPU."""
32
+ llm_model_name = "ibm-granite/granite-3.3-8b-instruct"
 
 
 
 
33
 
34
+ # This check is crucial. The app will stop if no GPU is found.
35
+ if not torch.cuda.is_available():
36
+ raise RuntimeError("Hardware Error: This application requires a GPU to run the IBM Granite model. Please select a GPU hardware tier in your Space settings (e.g., T4 small).")
37
+
38
+ model = AutoModelForCausalLM.from_pretrained(
39
+ llm_model_name,
40
+ torch_dtype=torch.bfloat16,
41
+ load_in_4bit=True # 4-bit quantization to save memory
42
+ )
43
  tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
 
 
44
 
45
  pipe = pipeline(
46
+ "text-generation",
47
  model=model,
48
  tokenizer=tokenizer,
49
+ max_new_tokens=512,
50
+ temperature=0.1,
51
+ device=0 # Force the pipeline to use the first available GPU
52
  )
53
  return HuggingFacePipeline(pipeline=pipe)
54
 
55
  @st.cache_resource
56
  def load_and_process_pdf(pdf_path):
57
+ """Loads and embeds the PDF using IBM's multilingual model."""
58
  try:
59
  doc = fitz.open(pdf_path)
60
  text = "".join(page.get_text() for page in doc)
 
 
 
61
  except Exception as e:
62
+ st.error(f"Error reading PDF: {e}. Ensure 'PMKisanSamanNidhi.PDF' is in the main project directory.")
63
  return None
64
 
65
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
66
  docs = text_splitter.create_documents([text])
67
 
68
+ embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
 
 
 
69
  vector_db = FAISS.from_documents(docs, embedding_model)
70
  return vector_db
71
 
72
  # --- Conversational Chain ---
73
  def create_conversational_chain(_llm, _vector_db):
74
+ prompt_template = """You are a polite AI assistant for the PM-KISAN scheme... (rest of prompt)"""
 
 
 
 
 
 
 
 
 
 
75
  QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
76
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
 
77
  chain = ConversationalRetrievalChain.from_llm(
78
+ llm=_llm, retriever=_vector_db.as_retriever(), memory=memory,
79
+ return_source_documents=True, combine_docs_chain_kwargs={"prompt": QA_PROMPT}
 
 
 
80
  )
81
  return chain
82
 
83
  # --- IBM AIF360 Fairness Audit ---
84
  def run_fairness_audit():
 
85
  st.subheader("🤖 IBM AIF360 - Fairness Audit")
86
+ df_display = pd.DataFrame({'gender_text': ['male', 'male', 'female', 'female']})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  df_for_aif = pd.DataFrame()
88
  df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
89
+ df_for_aif['favorable_outcome'] = [1, 1, 1, 1]
90
+ aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
91
+ protected_attribute_names=['gender'], privileged_classes=[[1]])
 
 
 
 
 
92
  metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
93
  spd = metric.statistical_parity_difference()
94
+ st.metric(label="**Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")
 
 
 
 
 
 
 
 
 
95
 
96
  # --- Main Application UI ---
97
  if __name__ == "__main__":
 
99
  with st.sidebar:
100
  st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
101
  st.title("🇮🇳 Sahay AI")
102
+ st.markdown("An AI assistant for the **PM-KISAN** scheme, built on **IBM's Granite** foundation models.")
 
 
 
 
103
  if st.button("Run Fairness Audit", use_container_width=True):
104
  st.session_state.run_audit = True
 
 
 
 
 
 
105
 
106
  st.header("Chat with Sahay AI 💬")
 
107
 
108
  if st.session_state.get('run_audit', False):
109
+ run_fair_audit()
110
  st.session_state.run_audit = False
111
 
112
  if "messages" not in st.session_state:
113
+ st.session_state.messages = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]
 
 
 
 
114
 
115
  if "qa_chain" not in st.session_state:
116
+ with st.spinner("🚀 Waking up the IBM Granite Model... This may take several minutes on a GPU."):
117
+ try:
118
+ llm = load_llm()
119
+ vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
120
  st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
121
+ except RuntimeError as e:
122
+ st.error(e) # This will display the "Hardware Error" message from load_llm()
123
  st.stop()
124
+
125
  for message in st.session_state.messages:
126
  with st.chat_message(message["role"]):
127
  st.markdown(message["content"])
128
 
129
+ if prompt := st.chat_input("Ask a question..."):
130
  st.session_state.messages.append({"role": "user", "content": prompt})
131
  with st.chat_message("user"):
132
  st.markdown(prompt)
 
133
  with st.chat_message("assistant"):
134
  with st.spinner("🧠 Thinking..."):
135
+ result = st.session_state.qa_chain.invoke({"question": prompt})
136
+ response = result["answer"]
137
+ st.markdown(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  st.session_state.messages.append({"role": "assistant", "content": response})