AnilNiraula committed on
Commit cdcdcbf · verified · 1 Parent(s): 30ff85f

Update app.py

Files changed (1)
  1. app.py +116 -72
app.py CHANGED
@@ -1,90 +1,134 @@
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
  import gradio as gr

- # Define device
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- # Response cache
- response_cache = {
-     "Hi, pretend you are a financial advisor. Now tell me how can I start investing in stock market?": (
-         "As a financial advisor, here’s a guide to start investing in the stock market:\n"
-         "1. **Learn**: Use Investopedia or “The Intelligent Investor” by Benjamin Graham.\n"
-         "2. **Goals**: Set objectives (e.g., retirement) and assess risk tolerance.\n"
-         "3. **Brokerage**: Choose Fidelity (low fees), Vanguard (index funds like VTI), or Robinhood (commission-free).\n"
-         "4. **Investments**: Start with ETFs (e.g., VOO for S&P 500) or mutual funds.\n"
-         "5. **Strategy**: Use dollar-cost averaging with $100-$500 monthly.\n"
-         "6. **Risks**: Diversify and monitor.\n"
-         "Consult a certified financial planner."
-     ),
-     "do you have a list of companies you recommend?": (
-         "I cannot recommend specific companies without current market data. Instead, consider ETFs like VOO (S&P 500) or QQQ (tech-focused) for broad exposure. "
-         "For stocks, research sectors like technology (e.g., Apple, Microsoft) or consumer goods (e.g., Procter & Gamble) using Yahoo Finance or Morningstar. "
-         "Consult a certified financial planner."
-     ),
-     "can you provide me a list of companies you recommend?": (
-         "I cannot provide specific company recommendations without up-to-date market analysis. For safer investments, consider ETFs like VOO (S&P 500) or QQQ (tech-focused). "
-         "If interested in stocks, explore stable companies in technology (e.g., Apple, Microsoft) or healthcare (e.g., Johnson & Johnson) using Yahoo Finance. "
-         "Always consult a financial planner for tailored advice."
-     ),
-     "You have a list of companies you recommend?": (
-         "I cannot recommend specific companies without current market data. Instead, consider ETFs like VOO (S&P 500) or QQQ (tech-focused) for broad exposure. "
-         "For stocks, research sectors like technology (e.g., Apple, Microsoft) or consumer goods (e.g., Procter & Gamble) using Yahoo Finance or Morningstar. "
-         "Consult a certified financial planner."
-     )
- }
 
- # Load model with optimizations
- model_name = "distilgpt2"
  try:
-     tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
-     tokenizer.pad_token = tokenizer.eos_token  # Ensure pad token is set
-     model = AutoModelForCausalLM.from_pretrained(
-         model_name,
-         device_map="auto",
-         torch_dtype=torch.float16,
-         low_cpu_mem_usage=True  # Optimize memory usage
-     ).to(device)
-     model.eval()  # Set model to evaluation mode for faster inference
  except Exception as e:
-     print(f"Error loading distilgpt2: {e}")
      exit()
 
- # Define chat function
- def chat_with_model(message, history=None):  # Ignore history
      try:
-         if not isinstance(message, str):
-             return "Error: User input must be a string"
-         # Normalize message for cache lookup (case-insensitive, strip whitespace)
-         message = message.strip().lower()
-         for cached_message, response in response_cache.items():
-             if cached_message.lower() == message:
-                 return response
-         # Simplified prompt
-         full_prompt = (
-             "Financial advisor: Answer directly about stock market investments. "
-             "No specific company picks without data; suggest ETFs or general advice. "
-             f"User: {message}\nAssistant:"
          )
-         inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=256).to(device)
-         with torch.no_grad():  # Disable gradient computation for faster inference
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=50,  # Increased slightly for better responses
-                 do_sample=False,  # Greedy decoding for speed
-                 num_beams=1,  # Disable beam search for faster generation
-                 pad_token_id=tokenizer.eos_token_id
-             )
-         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
      except Exception as e:
          return f"Error generating response: {str(e)}"

  # Create Gradio interface
  interface = gr.ChatInterface(
-     fn=chat_with_model,
-     title="Financial Advisor Chatbot (DistilGPT-2)",
-     description="Ask for advice on starting to invest in the stock market! Powered by DistilGPT-2. Provides single, direct answers without conversation history.",
      examples=[
          "Hi, pretend you are a financial advisor. Now tell me how can I start investing in stock market?",
          "You have a list of companies you recommend?"
 
+ import os
  import gradio as gr
+ import hashlib
+ import asyncio
+ import pickle
+ import time
+ from openai import AsyncOpenAI
+ from functools import lru_cache

+ # Persistent cache configuration
+ CACHE_FILE = "response_cache.pkl"

+ def load_cache():
+     try:
+         with open(CACHE_FILE, "rb") as f:
+             return pickle.load(f)
+     except:
+         return {
+             hashlib.md5("Hi, pretend you are a financial advisor. Now tell me how can I start investing in stock market?".lower().encode()).hexdigest(): (
+                 "As a financial advisor, here’s a guide to start investing in the stock market:\n"
+                 "1. **Learn**: Use Investopedia or “The Intelligent Investor” by Benjamin Graham.\n"
+                 "2. **Goals**: Set objectives (e.g., retirement) and assess risk tolerance.\n"
+                 "3. **Brokerage**: Choose Fidelity (low fees), Vanguard (index funds like VTI), or Robinhood (commission-free).\n"
+                 "4. **Investments**: Start with ETFs (e.g., VOO for S&P 500) or mutual funds.\n"
+                 "5. **Strategy**: Use dollar-cost averaging with $100-$500 monthly.\n"
+                 "6. **Risks**: Diversify and monitor.\n"
+                 "Consult a certified financial planner."
+             ),
+             hashlib.md5("do you have a list of companies you recommend?".lower().encode()).hexdigest(): (
+                 "I cannot recommend specific companies without current market data. Instead, consider ETFs like VOO (S&P 500) or QQQ (tech-focused) for broad exposure. "
+                 "For stocks, research sectors like technology (e.g., Apple, Microsoft) or consumer goods (e.g., Procter & Gamble) using Yahoo Finance or Morningstar. "
+                 "Consult a certified financial planner."
+             ),
+             hashlib.md5("can you provide me a list of companies you recommend?".lower().encode()).hexdigest(): (
+                 "I cannot provide specific company recommendations without up-to-date market analysis. For safer investments, consider ETFs like VOO (S&P 500) or QQQ (tech-focused). "
+                 "If interested in stocks, explore stable companies in technology (e.g., Apple, Microsoft) or healthcare (e.g., Johnson & Johnson) using Yahoo Finance. "
+                 "Always consult a financial planner for tailored advice."
+             ),
+             hashlib.md5("You have a list of companies you recommend?".lower().encode()).hexdigest(): (
+                 "I cannot recommend specific companies without current market data. Instead, consider ETFs like VOO (S&P 500) or QQQ (tech-focused) for broad exposure. "
+                 "For stocks, research sectors like technology (e.g., Apple, Microsoft) or consumer goods (e.g., Procter & Gamble) using Yahoo Finance or Morningstar. "
+                 "Consult a certified financial planner."
+             )
+         }
+
+ def save_cache(cache):
+     with open(CACHE_FILE, "wb") as f:
+         pickle.dump(cache, f)
+
+ # Initialize response cache
+ response_cache = load_cache()
 
+ # Initialize Grok 3 API async client
  try:
+     client = AsyncOpenAI(
+         api_key=os.getenv("XAI_API_KEY"),  # Ensure API key is set in environment variables
+         base_url="https://api.x.ai/v1"
+     )
  except Exception as e:
+     print(f"Error initializing Grok 3 API client: {e}")
      exit()

+ # Cache API responses with increased size
+ @lru_cache(maxsize=500)
+ async def fetch_grok_response(prompt: str, model: str) -> str:
+     start_time = time.time()
      try:
+         stream = await client.chat.completions.create(
+             model=model,
+             messages=[
+                 {"role": "system", "content": (
+                     "You are a financial advisor. Provide concise stock market investment advice. "
+                     "Avoid specific company recommendations without data; suggest ETFs or strategies. "
+                     "Advise consulting a certified financial planner."
+                 )},
+                 {"role": "user", "content": prompt}
+             ],
+             temperature=0.2,
+             max_tokens=150,
+             top_p=1.0,
+             stream=True
          )
+         response = ""
+         async for chunk in stream:
+             if chunk.choices[0].delta.content:
+                 response += chunk.choices[0].delta.content
+         print(f"API call took {time.time() - start_time:.2f} seconds")
+         return response.strip()
      except Exception as e:
+         print(f"API call took {time.time() - start_time:.2f} seconds")
          return f"Error generating response: {str(e)}"
 
+ # Define async chat function
+ async def chat_with_model(message, history=None):  # Ignore history
+     try:
+         if not isinstance(message, str):
+             return "Error: User input must be a string"
+
+         # Normalize message and check cache
+         message_normalized = message.strip().lower()
+         message_hash = hashlib.md5(message_normalized.encode()).hexdigest()
+         if message_hash in response_cache:
+             return response_cache[message_hash]
+
+         # Use smaller model for faster responses
+         model = "grok-3-mini"
+         response = await fetch_grok_response(message, model)
+
+         # Update and save cache for new responses
+         if not response.startswith("Error"):
+             response_cache[message_hash] = response
+             save_cache(response_cache)
+
+         return response
+
+     except Exception as e:
+         return f"Error in chat processing: {str(e)}"
+
+ # Wrapper for Gradio compatibility
+ def chat_with_model_sync(message, history=None):
+     return asyncio.run(chat_with_model(message, history))
+
  # Create Gradio interface
  interface = gr.ChatInterface(
+     fn=chat_with_model_sync,
+     title="Financial Advisor Chatbot (Grok 3 Mini)",
+     description=(
+         "Ask for advice on starting to invest in the stock market! Powered by xAI's Grok 3 Mini API for faster responses. "
+         "Provides single, direct answers without conversation history. "
+         "Monitor API rate limits in the xAI Developer Console (console.x.ai)."
+     ),
      examples=[
          "Hi, pretend you are a financial advisor. Now tell me how can I start investing in stock market?",
          "You have a list of companies you recommend?"