ombhojane commited on
Commit
327ce39
·
verified ·
1 Parent(s): 02bf9c1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +368 -0
app.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import json
4
+ import requests
5
+ from google import genai
6
+ from google.genai import types
7
+ from dotenv import load_dotenv
8
+
9
+ load_dotenv()
10
+
11
+ # Set page configuration
12
+ st.set_page_config(
13
+ page_title="Flaix - Financial Assistant",
14
+ page_icon="💰",
15
+ layout="centered"
16
+ )
17
+
18
+ # Initialize Gemini client
19
+ api_key = st.secrets["GOOGLE_API_KEY"]
20
+ genai.Client(api_key=api_key)
21
+
22
+ # Indian Stock Market API base configuration
23
+ INDIAN_API_KEY = st.secrets["FINANCE_KEY"]
24
+ INDIAN_API_BASE_URL = "https://stock.indianapi.in"
25
+
26
+ # Define API endpoints and their parameters
27
+ API_ENDPOINTS = {
28
+ "get_stock_details": {
29
+ "endpoint": "/stock",
30
+ "required_params": ["stock_name"],
31
+ "param_mapping": {"stock_name": "name"},
32
+ "description": "Get details for a specific stock"
33
+ },
34
+ "get_trending_stocks": {
35
+ "endpoint": "/trending",
36
+ "required_params": [],
37
+ "param_mapping": {},
38
+ "description": "Get trending stocks in the market"
39
+ },
40
+ "get_market_news": {
41
+ "endpoint": "/news",
42
+ "required_params": [],
43
+ "param_mapping": {},
44
+ "description": "Get latest stock market news"
45
+ },
46
+ "get_mutual_funds": {
47
+ "endpoint": "/mutual_funds",
48
+ "required_params": [],
49
+ "param_mapping": {},
50
+ "description": "Get mutual funds data"
51
+ },
52
+ "get_ipo_data": {
53
+ "endpoint": "/ipo",
54
+ "required_params": [],
55
+ "param_mapping": {},
56
+ "description": "Get IPO data"
57
+ },
58
+ "get_bse_most_active": {
59
+ "endpoint": "/BSE_most_active",
60
+ "required_params": [],
61
+ "param_mapping": {},
62
+ "description": "Get BSE most active stocks"
63
+ },
64
+ "get_nse_most_active": {
65
+ "endpoint": "/NSE_most_active",
66
+ "required_params": [],
67
+ "param_mapping": {},
68
+ "description": "Get NSE most active stocks"
69
+ },
70
+ "get_historical_data": {
71
+ "endpoint": "/historical_data",
72
+ "required_params": ["stock_name"],
73
+ "optional_params": ["period"],
74
+ "default_values": {"period": "1m", "filter": "default"},
75
+ "param_mapping": {},
76
+ "description": "Get historical data for a stock"
77
+ }
78
+ }
79
+
80
+ # Unified API call function
81
+ def call_indian_api(endpoint, params=None):
82
+ """
83
+ Generic function to call the Indian Stock Market API
84
+
85
+ Args:
86
+ endpoint: API endpoint suffix (e.g., '/stock', '/trending')
87
+ params: Optional parameters for the API call
88
+
89
+ Returns:
90
+ JSON response from the API
91
+ """
92
+ url = f"{INDIAN_API_BASE_URL}{endpoint}"
93
+ headers = {"X-Api-Key": INDIAN_API_KEY}
94
+
95
+ try:
96
+ response = requests.get(url, headers=headers, params=params)
97
+ return response.json()
98
+ except Exception as e:
99
+ return {"error": str(e)}
100
+
101
+ # Function to call API by name
102
+ def call_api_by_name(api_name, **kwargs):
103
+ """
104
+ Call an API by its name from the API_ENDPOINTS dictionary
105
+
106
+ Args:
107
+ api_name: Name of the API to call (key in API_ENDPOINTS)
108
+ **kwargs: Parameters to pass to the API
109
+
110
+ Returns:
111
+ JSON response from the API
112
+ """
113
+ if api_name not in API_ENDPOINTS:
114
+ return {"error": f"Unknown API: {api_name}"}
115
+
116
+ api_info = API_ENDPOINTS[api_name]
117
+ endpoint = api_info["endpoint"]
118
+
119
+ # Check required parameters
120
+ for param in api_info.get("required_params", []):
121
+ if param not in kwargs:
122
+ return {"error": f"Missing required parameter: {param}"}
123
+
124
+ # Apply parameter mapping
125
+ mapped_params = {}
126
+ for param, value in kwargs.items():
127
+ mapped_name = api_info.get("param_mapping", {}).get(param, param)
128
+ mapped_params[mapped_name] = value
129
+
130
+ # Apply default values
131
+ for param, value in api_info.get("default_values", {}).items():
132
+ if param not in mapped_params:
133
+ mapped_params[param] = value
134
+
135
+ return call_indian_api(endpoint, mapped_params)
136
+
137
+ # Improved orchestrator function
138
+ def orchestrator(query):
139
+ """
140
+ Determines if the query requires market data and which API to call
141
+ Returns: (needs_api, api_function, params)
142
+ """
143
+ # Create a more precise prompt for the orchestrator
144
+ orchestrator_prompt = """
145
+ You are an orchestrator for a financial assistant specialized in Indian markets. Your job is to analyze user queries and determine if they need real-time market data.
146
+
147
+ IMPORTANT: Be very precise in your analysis. Only return TRUE for "needs_api" when the query EXPLICITLY asks for current market data, stock prices, or listings.
148
+
149
+ Examples where needs_api should be TRUE:
150
+ - "Show me the most active stocks on NSE today" → get_nse_most_active
151
+ - "What is the current price of Reliance?" → get_stock_details with stock_name="Reliance"
152
+ - "Tell me about trending stocks" → get_trending_stocks
153
+ - "What are the latest IPOs?" → get_ipo_data
154
+
155
+ Examples where needs_api should be FALSE:
156
+ - "What is compound interest?"
157
+ - "How should I start investing?"
158
+ - "What are the tax benefits of PPF?"
159
+ - "Explain mutual funds to me"
160
+
161
+ Available API functions:
162
+ - get_stock_details(stock_name): Get details for a specific stock
163
+ - get_trending_stocks(): Get trending stocks in the market
164
+ - get_market_news(): Get latest stock market news
165
+ - get_mutual_funds(): Get mutual funds data
166
+ - get_ipo_data(): Get IPO data
167
+ - get_bse_most_active(): Get BSE most active stocks
168
+ - get_nse_most_active(): Get NSE most active stocks
169
+ - get_historical_data(stock_name, period="1m"): Get historical data for a stock
170
+
171
+ User query: """ + query + """
172
+
173
+ Respond in JSON format with the following structure:
174
+ {
175
+ "needs_api": true/false,
176
+ "function": "function_name_if_needed",
177
+ "params": {
178
+ "param1": "value1",
179
+ "param2": "value2"
180
+ }
181
+ }
182
+ """
183
+
184
+ # Call Gemini API for orchestration decision
185
+ client = get_gemini_client()
186
+
187
+ # Create content for the orchestrator
188
+ contents = [
189
+ types.Content(
190
+ role="user",
191
+ parts=[
192
+ types.Part.from_text(text=orchestrator_prompt)
193
+ ],
194
+ ),
195
+ ]
196
+
197
+ # Configure generation parameters
198
+ generate_content_config = types.GenerateContentConfig(
199
+ temperature=0.2,
200
+ top_p=0.95,
201
+ top_k=40,
202
+ max_output_tokens=500,
203
+ response_mime_type="text/plain",
204
+ )
205
+
206
+ # Generate content
207
+ response = client.models.generate_content(
208
+ model="gemini-1.5-flash",
209
+ contents=contents,
210
+ config=generate_content_config,
211
+ )
212
+
213
+ # Parse the response
214
+ try:
215
+ decision_text = response.text
216
+ # Extract JSON from the response (it might be wrapped in markdown code blocks)
217
+ if "```json" in decision_text:
218
+ json_str = decision_text.split("```json")[1].split("```")[0].strip()
219
+ elif "```" in decision_text:
220
+ json_str = decision_text.split("```")[1].strip()
221
+ else:
222
+ json_str = decision_text.strip()
223
+
224
+ decision = json.loads(json_str)
225
+ return decision
226
+ except Exception as e:
227
+ print(f"Error parsing orchestrator response: {e}")
228
+ return {"needs_api": False}
229
+
230
+ # Language setting
231
+
232
+ # Financial assistant system prompt
233
+ SYSTEM_PROMPT = f"""You are Flaix, a helpful and knowledgeable financial assistant designed specifically for Indian users. Your purpose is to improve financial literacy and provide guidance on investments in the Indian market.
234
+
235
+ Key responsibilities:
236
+ 1. Explain financial concepts in simple, easy-to-understand language
237
+ 2. Provide information about different investment options available in India (stocks, mutual funds, bonds, PPF, FDs, etc.)
238
+ 3. Help users understand investment risks and returns
239
+ 4. Explain tax implications of different investments in the Indian context
240
+ 5. Guide users on how to start investing based on their goals and risk tolerance
241
+ 6. Answer questions about market trends and financial news in India
242
+ """
243
+
244
+ # Initialize session state for chat history
245
+ if "messages" not in st.session_state:
246
+ st.session_state.messages = [
247
+ {"role": "user", "content": SYSTEM_PROMPT},
248
+ {"role": "model", "content": "Hello! I am Flaix, your financial assistant. You can ask me about investments, financial planning, or any other financial topic."}
249
+ ]
250
+
251
+ # App title and description
252
+ st.title("Flaix - Your Financial Assistant")
253
+ st.markdown("Ask any questions about investing, financial planning, or the Indian financial market.")
254
+
255
+ # Display chat messages
256
+ for message in st.session_state.messages:
257
+ if message["role"] == "user" and message["content"] != SYSTEM_PROMPT:
258
+ with st.chat_message("user"):
259
+ st.write(message["content"])
260
+ elif message["role"] == "model":
261
+ with st.chat_message("assistant"):
262
+ st.write(message["content"])
263
+
264
+ # Chat input
265
+ if prompt := st.chat_input("Ask me anything about finance or investing..."):
266
+ # Add user message to chat history
267
+ st.session_state.messages.append({"role": "user", "content": prompt})
268
+
269
+ # Display user message
270
+ with st.chat_message("user"):
271
+ st.write(prompt)
272
+
273
+ # Display assistant response
274
+ with st.chat_message("assistant"):
275
+ message_placeholder = st.empty()
276
+ full_response = ""
277
+
278
+ try:
279
+ # First, use the orchestrator to determine if we need to call an API
280
+ decision = orchestrator(prompt)
281
+
282
+ # If we need to call an API, do so and add the result to the context
283
+ api_context = ""
284
+ if decision.get("needs_api", False):
285
+ function_name = decision.get("function", "")
286
+ params = decision.get("params", {})
287
+
288
+ message_placeholder.write("Fetching real-time market data...")
289
+
290
+ if function_name in API_ENDPOINTS:
291
+ api_result = call_api_by_name(function_name, **params)
292
+ api_context = f"\nHere is the real-time market data from the Indian Stock Market API:\n{json.dumps(api_result, indent=2)}\n\nPlease use this data to provide an informative response to the user's query."
293
+
294
+ # Get Gemini client
295
+ client = get_gemini_client()
296
+
297
+ # Prepare the user query with API context if available
298
+ user_query = prompt
299
+ if api_context:
300
+ user_query = f"{prompt}\n\n[SYSTEM NOTE: {api_context}]"
301
+
302
+ # Prepare the system message
303
+ system_message = SYSTEM_PROMPT
304
+ if len(st.session_state.messages) > 2: # If we have conversation history
305
+ # Extract previous conversation for context
306
+ conversation_history = ""
307
+ for i in range(1, min(5, len(st.session_state.messages) - 1)): # Get up to 5 previous exchanges
308
+ if st.session_state.messages[i]["role"] == "user" and st.session_state.messages[i]["content"] != SYSTEM_PROMPT:
309
+ conversation_history += f"User: {st.session_state.messages[i]['content']}\n"
310
+ elif st.session_state.messages[i]["role"] == "model":
311
+ conversation_history += f"Assistant: {st.session_state.messages[i]['content']}\n"
312
+
313
+ system_message += f"\n\nPrevious conversation:\n{conversation_history}"
314
+
315
+ # Create content for the LLM
316
+ contents = [
317
+ types.Content(
318
+ role="user",
319
+ parts=[
320
+ types.Part.from_text(text=system_message)
321
+ ],
322
+ ),
323
+ types.Content(
324
+ role="model",
325
+ parts=[
326
+ types.Part.from_text(text="I understand my role as Flaix, a financial assistant for Indian users. I'll provide helpful information about investing and financial planning in simple language.")
327
+ ],
328
+ ),
329
+ types.Content(
330
+ role="user",
331
+ parts=[
332
+ types.Part.from_text(text=user_query)
333
+ ],
334
+ ),
335
+ ]
336
+
337
+ # Configure generation parameters
338
+ generate_content_config = types.GenerateContentConfig(
339
+ temperature=0.7,
340
+ top_p=0.95,
341
+ top_k=40,
342
+ max_output_tokens=8192,
343
+ response_mime_type="text/plain",
344
+ )
345
+
346
+ # Stream the response
347
+ response_stream = client.models.generate_content_stream(
348
+ model="gemini-1.5-flash",
349
+ contents=contents,
350
+ config=generate_content_config,
351
+ )
352
+
353
+ # Process streaming response
354
+ for chunk in response_stream:
355
+ if hasattr(chunk, 'text'):
356
+ full_response += chunk.text
357
+ message_placeholder.write(full_response + "▌")
358
+
359
+ # Final update without cursor
360
+ message_placeholder.write(full_response)
361
+
362
+ except Exception as e:
363
+ st.error(f"Error: {str(e)}")
364
+ full_response = "I'm sorry, I encountered an error. Please try again later."
365
+ message_placeholder.write(full_response)
366
+
367
+ # Add assistant response to chat history
368
+ st.session_state.messages.append({"role": "model", "content": full_response})