AnilNiraula committed on
Commit
5c97638
·
verified ·
1 Parent(s): 862e37a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -60
app.py CHANGED
@@ -1,112 +1,125 @@
1
  import logging
2
  import os
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
  import gradio as gr
 
 
6
 
7
  # Set up logging
8
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
9
  logger = logging.getLogger(__name__)
10
 
11
- # Define device (force CPU for Spaces compatibility)
12
  device = torch.device("cpu")
13
  logger.info(f"Using device: {device}")
14
 
15
- # Response cache with step-by-step advice
16
  response_cache = {
17
  "hi": "Hello! I'm your financial advisor. How can I help with investing?",
18
  "hello": "Hello! I'm your financial advisor. How can I help with investing?",
19
  "hey": "Hi there! Ready to discuss investment goals?",
20
  "hi, give me step-by-step investing advice": (
21
  "Here’s a step-by-step guide to start investing:\n"
22
- "1. **Open a Brokerage Account**: If you’re 18 or older, sign up with a platform like Fidelity, Vanguard, or Robinhood.\n"
23
- "2. **Deposit Funds**: Add an initial amount you can afford, such as $100, after securing an emergency fund.\n"
24
- "3. **Research and Buy**: Choose a stock, ETF (e.g., VOO for S&P 500), or index fund based on research from Yahoo Finance or Morningstar.\n"
25
- "4. **Monitor Investments**: Check your portfolio regularly and enable dividend reinvesting for compounding returns.\n"
26
- "5. **Use Dollar-Cost Averaging**: Invest a fixed amount (e.g., $100 monthly) consistently to reduce market timing risks.\n"
27
- "6. **Diversify**: Spread investments across sectors to manage risk.\n"
28
- "Consult a certified financial planner for personalized advice."
29
  ),
30
  "hi, pretend you are a financial advisor. now tell me how can i start investing in stock market?": (
31
- "Here’s a guide to start investing in the stock market:\n"
32
- "1. **Learn**: Use Investopedia or 'The Intelligent Investor' by Benjamin Graham.\n"
33
- "2. **Goals**: Set objectives (e.g., retirement) and assess risk tolerance.\n"
34
- "3. **Brokerage**: Choose Fidelity, Vanguard, or Robinhood.\n"
35
- "4. **Investments**: Start with ETFs (e.g., VOO) or mutual funds.\n"
36
- "5. **Strategy**: Use dollar-cost averaging ($100-$500 monthly).\n"
37
- "6. **Risks**: Diversify and monitor.\n"
38
- "Consult a certified financial planner."
39
  ),
40
  "do you have a list of companies you recommend?": (
41
- "I cannot recommend specific companies without current data. Consider ETFs like VOO (S&P 500) or QQQ (tech). "
42
- "Research sectors like technology (e.g., Apple) or healthcare (e.g., Johnson & Johnson) on Yahoo Finance. "
43
  "Consult a financial planner."
44
  ),
45
  "how do i start investing in stocks?": (
46
- "Educate yourself with Investopedia or 'The Intelligent Investor.' Set goals and assess risk tolerance. "
47
- "Open a brokerage account with Fidelity or Vanguard and start with ETFs (e.g., VOO). Consult a financial planner."
48
  ),
49
  "what's the difference between stocks and bonds?": (
50
- "Stocks offer ownership in a company with growth potential but higher risk. Bonds are loans to companies/governments, "
51
- "offering steady interest with lower risk. Diversify with both for balance."
52
  ),
53
  "how much should i invest?": (
54
- "Invest what you can afford after expenses and an emergency fund (3-6 months’ savings). Start with $100-$500 monthly "
55
- "in ETFs like VOO using dollar-cost averaging. Consult a financial planner."
56
  ),
57
  "what is dollar-cost averaging?": (
58
- "Dollar-cost averaging is investing a fixed amount regularly (e.g., $100 monthly) in assets like ETFs, "
59
  "reducing risk by spreading purchases over time."
60
  ),
61
  "give me few investing idea": (
62
- "Here are some investing ideas:\n"
63
- "1. Open a brokerage account if you are 18 or older (e.g., Fidelity, Vanguard).\n"
64
- "2. Deposit an initial amount you can afford (e.g., $100).\n"
65
- "3. Buy a researched stock, ETF (e.g., VOO), or index fund.\n"
66
- "4. Check your investments regularly and enable dividend reinvesting if desired.\n"
67
- "5. Use dollar-cost averaging to buy the same investment regularly (e.g., monthly).\n"
68
- "Consult a financial planner for personalized advice."
 
 
 
 
 
 
 
 
 
69
  )
70
  }
71
 
72
  # Load model and tokenizer
73
- model_name = "facebook/opt-350m"
74
  try:
75
  logger.info(f"Loading tokenizer for {model_name}")
76
  tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
77
  logger.info(f"Loading model {model_name}")
78
  model = AutoModelForCausalLM.from_pretrained(
79
  model_name,
80
- torch_dtype=torch.float16
 
 
81
  ).to(device)
82
  except Exception as e:
83
  logger.error(f"Error loading model/tokenizer: {e}")
84
  raise
85
 
86
- # Pre-tokenize prompt prefix with few-shot example
87
  prompt_prefix = (
88
- "You are a financial advisor. Provide concise, actionable advice in a numbered list for step-by-step or idea prompts. "
89
- "Avoid repetition and vague statements. Use varied, specific steps.\n\n"
90
- "Example:\n"
91
- "Q: Give me step-by-step investing advice\n"
92
- "A: 1. Open a brokerage account with Fidelity or Vanguard if 18 or older.\n"
93
- "2. Deposit an affordable amount, like $100, after building an emergency fund.\n"
94
- "3. Research and buy an ETF like VOO using Yahoo Finance data.\n"
95
- "4. Check investments monthly and enable dividend reinvesting.\n"
96
- "5. Invest regularly with dollar-cost averaging to reduce risk.\n\n"
97
  "Q: "
98
  )
99
  prefix_tokens = tokenizer(prompt_prefix, return_tensors="pt", truncation=True, max_length=512).to(device)
100
 
 
 
 
 
 
101
  # Define chat function
102
  def chat_with_model(message, history=None):
103
  try:
104
  logger.info(f"Processing message: {message}")
105
- # Normalize input for cache
106
  cache_key = message.lower().strip()
107
- if cache_key in response_cache:
108
- logger.info("Cache hit")
109
- return response_cache[cache_key]
 
 
110
 
111
  # Skip model for short prompts
112
  if len(message.strip()) <= 5:
@@ -117,20 +130,21 @@ def chat_with_model(message, history=None):
117
  full_prompt = prompt_prefix + message + "\nA:"
118
  inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
119
 
120
- # Generate response
121
- with torch.no_grad():
122
  outputs = model.generate(
123
  **inputs,
124
- max_new_tokens=100,
125
- min_length=20,
126
  do_sample=True,
127
  temperature=0.7,
128
  top_p=0.9,
129
- no_repeat_ngram_size=2,
130
  pad_token_id=tokenizer.eos_token_id
131
  )
132
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
133
  logger.info("Generated response")
 
134
  return response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
135
  except Exception as e:
136
  logger.error(f"Error generating response: {e}")
@@ -140,15 +154,15 @@ def chat_with_model(message, history=None):
140
  logger.info("Initializing Gradio interface")
141
  interface = gr.ChatInterface(
142
  fn=chat_with_model,
143
- title="Financial Advisor Chatbot (OPT-350m)",
144
- description="Ask about investing! Powered by Meta AI's OPT-350m. Fast, detailed answers.",
145
  examples=[
146
  "Hi",
147
  "Hi, give me step-by-step investing advice",
148
  "Give me few investing idea",
 
149
  "Do you have a list of companies you recommend?",
150
- "What's the difference between stocks and bonds?",
151
- "How much should I invest?"
152
  ]
153
  )
154
 
@@ -156,7 +170,7 @@ interface = gr.ChatInterface(
156
  if __name__ == "__main__" and not os.getenv("HF_SPACE"):
157
  logger.info("Launching Gradio interface locally")
158
  try:
159
- interface.launch(share=False, debug=True)
160
  except Exception as e:
161
  logger.error(f"Error launching interface: {e}")
162
  raise
 
1
  import logging
2
  import os
 
3
  import torch
4
  import gradio as gr
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ import difflib
7
 
# Set up logging: INFO level, each record prefixed with timestamp and level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger used by every log call in this script.
logger = logging.getLogger(__name__)

# Define device (force CPU for Spaces free tier).
# The device is hard-coded; no GPU detection is attempted.
device = torch.device("cpu")
logger.info(f"Using device: {device}")
15
 
# Response cache with expanded entries.
# Maps a normalized user message to a canned reply that is returned without
# calling the model. Keys are written in lowercase because chat_with_model
# looks messages up after applying `message.lower().strip()`; a new key that
# is not lowercase can only be reached through the fuzzy-match fallback.
response_cache = {
    # Greetings
    "hi": "Hello! I'm your financial advisor. How can I help with investing?",
    "hello": "Hello! I'm your financial advisor. How can I help with investing?",
    "hey": "Hi there! Ready to discuss investment goals?",
    # Step-by-step / getting-started prompts
    "hi, give me step-by-step investing advice": (
        "Here’s a step-by-step guide to start investing:\n"
        "1. Open a brokerage account (e.g., Fidelity, Vanguard) if 18 or older.\n"
        "2. Deposit an affordable amount, like $100, after an emergency fund.\n"
        "3. Research and buy an ETF (e.g., VOO) using Yahoo Finance.\n"
        "4. Monitor monthly and enable dividend reinvesting.\n"
        "5. Use dollar-cost averaging ($100 monthly) to reduce risk.\n"
        "6. Diversify across sectors.\n"
        "Consult a financial planner."
    ),
    "hi, pretend you are a financial advisor. now tell me how can i start investing in stock market?": (
        "Here’s a guide to start investing:\n"
        "1. Learn from Investopedia or 'The Intelligent Investor.'\n"
        "2. Set goals (e.g., retirement) and assess risk.\n"
        "3. Choose a brokerage (Fidelity, Vanguard).\n"
        "4. Start with ETFs (e.g., VOO) or mutual funds.\n"
        "5. Use dollar-cost averaging ($100-$500 monthly).\n"
        "6. Diversify and monitor.\n"
        "Consult a financial planner."
    ),
    # Single-topic questions
    "do you have a list of companies you recommend?": (
        "I can’t recommend specific companies without data. Try ETFs like VOO (S&P 500) or QQQ (tech). "
        "Research technology (e.g., Apple) or healthcare (e.g., Johnson & Johnson) on Yahoo Finance. "
        "Consult a financial planner."
    ),
    "how do i start investing in stocks?": (
        "Learn from Investopedia. Set goals and assess risk. Open a brokerage account (Fidelity, Vanguard) "
        "and start with ETFs (e.g., VOO). Consult a financial planner."
    ),
    "what's the difference between stocks and bonds?": (
        "Stocks are company ownership with high risk and growth potential. Bonds are loans to companies/governments "
        "with lower risk and steady interest. Diversify for balance."
    ),
    "how much should i invest?": (
        "Invest what you can afford after expenses and an emergency fund. Start with $100-$500 monthly "
        "in ETFs (e.g., VOO) using dollar-cost averaging. Consult a financial planner."
    ),
    "what is dollar-cost averaging?": (
        "Dollar-cost averaging is investing a fixed amount regularly (e.g., $100 monthly) in ETFs, "
        "reducing risk by spreading purchases over time."
    ),
    # Idea / tip lists. The wording of the next key mirrors the
    # "Give me few investing idea" example offered in the chat UI, so it must
    # stay in sync with that example text.
    "give me few investing idea": (
        "Here are investing ideas:\n"
        "1. Open a brokerage account (e.g., Fidelity) if 18 or older.\n"
        "2. Deposit $100 or what you can afford.\n"
        "3. Buy a researched ETF (e.g., VOO) or index fund.\n"
        "4. Check regularly and enable dividend reinvesting.\n"
        "5. Use dollar-cost averaging (e.g., monthly buys).\n"
        "Consult a financial planner."
    ),
    "give me investing tips": (
        "Here are investing tips:\n"
        "1. Educate yourself with Investopedia or books.\n"
        "2. Open a brokerage account (e.g., Vanguard).\n"
        "3. Start small with ETFs like VOO.\n"
        "4. Invest regularly using dollar-cost averaging.\n"
        "5. Diversify to manage risk.\n"
        "Consult a financial planner."
    )
}
81
 
# Load model and tokenizer.
# The script pins `device` to CPU, so the weights are loaded as plain float32:
#   * bitsandbytes 4-bit quantization (`load_in_4bit=True`) targets CUDA GPUs,
#     and transformers refuses `.to(...)` on a 4-bit quantized model, so the
#     previous configuration could not load on a CPU-only host;
#   * float16 on CPU is slow and not implemented for every op.
# distilgpt2 is small enough that full precision fits comfortably in memory.
model_name = "distilgpt2"  # Smaller model for CPU
try:
    logger.info(f"Loading tokenizer for {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
    logger.info(f"Loading model {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    ).to(device)
except Exception as e:
    # Loading is a hard requirement: log the cause, then let the error
    # propagate so the process fails fast instead of serving without a model.
    logger.error(f"Error loading model/tokenizer: {e}")
    raise
97
 
# Prompt prefix prepended to every user message sent to the model.
# It contains a short instruction, one worked Q/A example, and ends with an
# open "Q: " so that chat_with_model can append the message followed by "\nA:".
prompt_prefix = (
    "You are a financial advisor. Provide concise, numbered list advice for investing prompts. "
    "Avoid repetition and vague statements.\n\n"
    "Example: Q: Give investing tips\nA: 1. Open a brokerage account.\n2. Start with ETFs like VOO.\n3. Use dollar-cost averaging.\n\n"
    "Q: "
)
# Pre-tokenize prompt prefix (truncated to at most 512 tokens) and move the
# tensors to `device`.
# NOTE(review): no visible line reads `prefix_tokens` — chat_with_model
# re-tokenizes the whole prompt on each call. Confirm against the elided part
# of the file before treating this as dead code.
prefix_tokens = tokenizer(prompt_prefix, return_tensors="pt", truncation=True, max_length=512).to(device)
106
 
# Fuzzy matching for cache
def get_closest_cache_key(message, cache_keys, threshold=0.9):
    """Return the cache key most similar to *message*, or None.

    Similarity is scored by ``difflib.get_close_matches``; only the single
    best candidate is requested, and it is returned only if its score is at
    least *threshold*.

    Args:
        message: Text to look up (the caller passes the normalized message).
        cache_keys: Candidate keys to compare against.
        threshold: Minimum similarity in [0, 1] for a key to count as a match.

    Returns:
        The best-matching key, or None when no key reaches the threshold.
    """
    best_first = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold)
    # At most one element comes back; fall back to None on an empty result.
    return next(iter(best_first), None)
111
+
112
  # Define chat function
113
  def chat_with_model(message, history=None):
114
  try:
115
  logger.info(f"Processing message: {message}")
116
+ # Normalize and check cache
117
  cache_key = message.lower().strip()
118
+ cache_keys = list(response_cache.keys())
119
+ closest_key = cache_key if cache_key in response_cache else get_closest_cache_key(cache_key, cache_keys)
120
+ if closest_key:
121
+ logger.info(f"Cache hit for: {closest_key}")
122
+ return response_cache[closest_key]
123
 
124
  # Skip model for short prompts
125
  if len(message.strip()) <= 5:
 
130
  full_prompt = prompt_prefix + message + "\nA:"
131
  inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
132
 
133
+ # Generate response with mixed precision
134
+ with torch.cpu.amp.autocast(), torch.no_grad():
135
  outputs = model.generate(
136
  **inputs,
137
+ max_new_tokens=80,
138
+ min_length=15,
139
  do_sample=True,
140
  temperature=0.7,
141
  top_p=0.9,
142
+ repetition_penalty=1.2,
143
  pad_token_id=tokenizer.eos_token_id
144
  )
145
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
146
  logger.info("Generated response")
147
+ torch.cuda.empty_cache() # Clear memory
148
  return response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
149
  except Exception as e:
150
  logger.error(f"Error generating response: {e}")
 
# Build the Gradio chat UI around chat_with_model.
logger.info("Initializing Gradio interface")
interface = gr.ChatInterface(
    fn=chat_with_model,
    title="Financial Advisor Chatbot (DistilGPT2)",
    description="Ask about investing! Fast, detailed answers on CPU.",
    # Each example, once lowercased, is an exact key of response_cache, so
    # these prompts are answered from the cache rather than by the model.
    examples=[
        "Hi",
        "Hi, give me step-by-step investing advice",
        "Give me few investing idea",
        "Give me investing tips",
        "Do you have a list of companies you recommend?",
        "What's the difference between stocks and bonds?"
    ]
)
 
 
# Start the server only when run as a script and HF_SPACE is unset.
if __name__ == "__main__" and not os.getenv("HF_SPACE"):
    logger.info("Launching Gradio interface locally")
    try:
        # `launch()` has no `queue` keyword (queuing is configured through
        # `interface.queue()`), so passing `queue=False` raised TypeError
        # before the server could start.
        interface.launch(share=False, debug=True)
    except Exception as e:
        # Log the failure, then re-raise so the process exits with an error.
        logger.error(f"Error launching interface: {e}")
        raise