AnilNiraula committed
Commit 30ff85f · verified · 1 Parent(s): 475e80c

Update app.py

Files changed (1)
  1. app.py +27 -20
app.py CHANGED
@@ -34,17 +34,20 @@ response_cache = {
     )
 }
 
-# Load model
+# Load model with optimizations
 model_name = "distilgpt2"
 try:
     tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
+    tokenizer.pad_token = tokenizer.eos_token  # Ensure pad token is set
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         device_map="auto",
-        torch_dtype=torch.float16
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True  # Optimize memory usage
     ).to(device)
+    model.eval()  # Set model to evaluation mode for faster inference
 except Exception as e:
-    print(f"Error loading OPT-350m: {e}")
+    print(f"Error loading distilgpt2: {e}")
     exit()
 
 # Define chat function
@@ -52,22 +55,26 @@ def chat_with_model(message, history=None):  # Ignore history
     try:
         if not isinstance(message, str):
             return "Error: User input must be a string"
-        if message in response_cache:
-            return response_cache[message]
+        # Normalize message for cache lookup (case-insensitive, strip whitespace)
+        message = message.strip().lower()
+        for cached_message, response in response_cache.items():
+            if cached_message.lower() == message:
+                return response
+        # Simplified prompt
         full_prompt = (
-            "You are a financial advisor with expertise in stock market investments. "
-            "Provide accurate, detailed, and actionable advice in a single response. "
-            "Do not rely on prior conversation context. "
-            "If you cannot provide specific recommendations (e.g., individual companies), "
-            "explain why and offer general guidance or alternative suggestions.\n"
-            "User: {message}\nAssistant:"
-        ).format(message=message)
-        inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=30,
-            do_sample=False,  # Enables greedy decoding
-            pad_token_id=tokenizer.eos_token_id)
+            "Financial advisor: Answer directly about stock market investments. "
+            "No specific company picks without data; suggest ETFs or general advice. "
+            f"User: {message}\nAssistant:"
+        )
+        inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=256).to(device)
+        with torch.no_grad():  # Disable gradient computation for faster inference
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=50,  # Increased slightly for better responses
+                do_sample=False,  # Greedy decoding for speed
+                num_beams=1,  # Disable beam search for faster generation
+                pad_token_id=tokenizer.eos_token_id
+            )
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         return response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
     except Exception as e:
@@ -76,8 +83,8 @@ def chat_with_model(message, history=None):  # Ignore history
 # Create Gradio interface
 interface = gr.ChatInterface(
     fn=chat_with_model,
-    title="Financial Advisor Chatbot (OPT-350m)",
-    description="Ask for advice on starting to invest in the stock market! Powered by Meta AI's OPT-350m. Provides single, direct answers without conversation history.",
+    title="Financial Advisor Chatbot (DistilGPT-2)",
+    description="Ask for advice on starting to invest in the stock market! Powered by DistilGPT-2. Provides single, direct answers without conversation history.",
     examples=[
         "Hi, pretend you are a financial advisor. Now tell me how can I start investing in stock market?",
         "You have a list of companies you recommend?"