AnilNiraula committed
Commit ae72622 (verified) · Parent: d50fc60

Update app.py

Files changed (1)
  1. app.py +25 -16
app.py CHANGED
@@ -76,36 +76,45 @@ response_cache = {
         "4. Invest regularly using dollar-cost averaging.\n"
         "5. Diversify to manage risk.\n"
         "Consult a financial planner."
+    ),
+    "how to start investing": (
+        "Here’s how to start investing:\n"
+        "1. Study basics on Investopedia.\n"
+        "2. Open a brokerage account (e.g., Fidelity).\n"
+        "3. Deposit $100 or more after securing savings.\n"
+        "4. Buy an ETF like VOO after research.\n"
+        "5. Invest monthly with dollar-cost averaging.\n"
+        "Consult a financial planner."
     )
 }
 
 # Load model and tokenizer
-model_name = "distilgpt2"  # Smaller model for CPU
+model_name = "distilgpt2"
 try:
     logger.info(f"Loading tokenizer for {model_name}")
     tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
     logger.info(f"Loading model {model_name}")
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True,
-        load_in_4bit=True  # 4-bit quantization
-    ).to(device)
+    with torch.no_grad():
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True
+        ).to(device)
 except Exception as e:
     logger.error(f"Error loading model/tokenizer: {e}")
-    raise
+    raise RuntimeError(f"Failed to load model: {str(e)}")
 
 # Pre-tokenize prompt prefix
 prompt_prefix = (
-    "You are a financial advisor. Provide concise, numbered list advice for investing prompts. "
-    "Avoid repetition and vague statements.\n\n"
-    "Example: Q: Give investing tips\nA: 1. Open a brokerage account.\n2. Start with ETFs like VOO.\n3. Use dollar-cost averaging.\n\n"
+    "You are a financial advisor. Provide numbered list advice for investing prompts. "
+    "Avoid repetition.\n\n"
+    "Example: Q: Give investing tips\nA: 1. Open a brokerage.\n2. Buy ETFs like VOO.\n3. Use dollar-cost averaging.\n\n"
     "Q: "
 )
 prefix_tokens = tokenizer(prompt_prefix, return_tensors="pt", truncation=True, max_length=512).to(device)
 
 # Fuzzy matching for cache
-def get_closest_cache_key(message, cache_keys, threshold=0.9):
+def get_closest_cache_key(message, cache_keys, threshold=0.85):
     matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold)
     return matches[0] if matches else None
 
@@ -130,11 +139,11 @@ def chat_with_model(message, history=None):
     full_prompt = prompt_prefix + message + "\nA:"
     inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
 
-    # Generate response with mixed precision
+    # Generate response
     with torch.cpu.amp.autocast(), torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=80,
+            max_new_tokens=60,
             min_length=15,
             do_sample=True,
             temperature=0.7,
@@ -161,8 +170,8 @@ interface = gr.ChatInterface(
         "Hi, give me step-by-step investing advice",
         "Give me few investing idea",
         "Give me investing tips",
-        "Do you have a list of companies you recommend?",
-        "What's the difference between stocks and bonds?"
+        "How to start investing",
+        "Do you have a list of companies you recommend?"
     ]
 )
 
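Note on the model-loading change: dropping load_in_4bit=True is sensible for a CPU-only Space, since 4-bit loading goes through the GPU-oriented bitsandbytes backend; the new torch.no_grad() wrapper around from_pretrained is harmless but should have no real effect at load time. If smaller weights are still wanted on CPU, one alternative (a sketch, not what this commit does) is PyTorch dynamic int8 quantization:

# Sketch: CPU-side int8 quantization as an alternative to load_in_4bit.
# quantize_dynamic rewrites Linear layers to int8, which typically shrinks
# memory and speeds up CPU inference; the float16 override is left out,
# since half precision is often slower than float32 on CPU.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("distilgpt2", low_cpu_mem_usage=True)
model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
model.eval()
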
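The fuzzy-cache cutoff drop from 0.9 to 0.85 widens which paraphrases hit a canned answer. A quick illustration with difflib (the message is a hypothetical user phrasing, not taken from the app):

import difflib

keys = ["how to start investing", "give me investing tips"]
msg = "how to start investing today"

# SequenceMatcher scores this pair at 0.88: it clears the new 0.85
# cutoff but would have been rejected by the old 0.9 one.
print(difflib.get_close_matches(msg, keys, n=1, cutoff=0.85))  # ['how to start investing']
print(difflib.get_close_matches(msg, keys, n=1, cutoff=0.9))   # []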
 
 
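One thing the diff leaves in place: prefix_tokens is computed once up front, but chat_with_model still re-tokenizes full_prompt from scratch on every call, so the pre-tokenization is not actually saving work yet. A sketch of reusing it by concatenating token ids (an assumed optimization, not part of this commit; GPT-2's tokenizer adds no special tokens by default, so plain concatenation is safe):

# Sketch: build generate() inputs from the cached prefix tokens
# instead of re-tokenizing prompt_prefix + message each call.
import torch

suffix = tokenizer(message + "\nA:", return_tensors="pt").to(device)
input_ids = torch.cat([prefix_tokens["input_ids"], suffix["input_ids"]], dim=1)
attention_mask = torch.ones_like(input_ids)
outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=60)
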
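On the generation change: cutting max_new_tokens from 80 to 60 trims worst-case latency roughly in proportion, since autoregressive decoding pays a per-token cost. (torch.cpu.amp.autocast() still works, though recent PyTorch spells it torch.amp.autocast("cpu").) A pattern worth pairing with the shorter budget, assuming the inputs/outputs names from the hunk above, is decoding only the newly generated tokens so the echoed prompt never leaks into the reply:

# Sketch: strip the prompt tokens before decoding the model output.
prompt_len = inputs["input_ids"].shape[1]
new_tokens = outputs[0][prompt_len:]
reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()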
 
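Finally, the new example prompt "How to start investing" lines up with the new cache key "how to start investing", which only pays off if the handler checks the cache before generating. A sketch of that cache-first path, assuming the names in this file (the .strip().lower() normalization is an assumption, since difflib comparisons are case-sensitive):

# Sketch: serve canned answers from the fuzzy cache before touching the model.
def respond(message):
    key = get_closest_cache_key(message.strip().lower(), list(response_cache.keys()))
    if key is not None:
        return response_cache[key]   # e.g. the new "how to start investing" entry
    return chat_with_model(message)  # fall back to distilgpt2 generation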