Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -34,17 +34,20 @@ response_cache = {
|
|
34 |
)
|
35 |
}
|
36 |
|
37 |
-
# Load model
|
38 |
model_name = "distilgpt2"
|
39 |
try:
|
40 |
tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
|
|
|
41 |
model = AutoModelForCausalLM.from_pretrained(
|
42 |
model_name,
|
43 |
device_map="auto",
|
44 |
-
torch_dtype=torch.float16
|
|
|
45 |
).to(device)
|
|
|
46 |
except Exception as e:
|
47 |
-
print(f"Error loading
|
48 |
exit()
|
49 |
|
50 |
# Define chat function
|
@@ -52,22 +55,26 @@ def chat_with_model(message, history=None): # Ignore history
|
|
52 |
try:
|
53 |
if not isinstance(message, str):
|
54 |
return "Error: User input must be a string"
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
57 |
full_prompt = (
|
58 |
-
"
|
59 |
-
"
|
60 |
-
"
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
72 |
return response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
|
73 |
except Exception as e:
|
@@ -76,8 +83,8 @@ def chat_with_model(message, history=None): # Ignore history
|
|
76 |
# Create Gradio interface
|
77 |
interface = gr.ChatInterface(
|
78 |
fn=chat_with_model,
|
79 |
-
title="Financial Advisor Chatbot (
|
80 |
-
description="Ask for advice on starting to invest in the stock market! Powered by
|
81 |
examples=[
|
82 |
"Hi, pretend you are a financial advisor. Now tell me how can I start investing in stock market?",
|
83 |
"You have a list of companies you recommend?"
|
|
|
34 |
)
|
35 |
}
|
36 |
|
37 |
+
# Load model with optimizations
|
38 |
model_name = "distilgpt2"
|
39 |
try:
|
40 |
tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
|
41 |
+
tokenizer.pad_token = tokenizer.eos_token # Ensure pad token is set
|
42 |
model = AutoModelForCausalLM.from_pretrained(
|
43 |
model_name,
|
44 |
device_map="auto",
|
45 |
+
torch_dtype=torch.float16,
|
46 |
+
low_cpu_mem_usage=True # Optimize memory usage
|
47 |
).to(device)
|
48 |
+
model.eval() # Set model to evaluation mode for faster inference
|
49 |
except Exception as e:
|
50 |
+
print(f"Error loading distilgpt2: {e}")
|
51 |
exit()
|
52 |
|
53 |
# Define chat function
|
|
|
55 |
try:
|
56 |
if not isinstance(message, str):
|
57 |
return "Error: User input must be a string"
|
58 |
+
# Normalize message for cache lookup (case-insensitive, strip whitespace)
|
59 |
+
message = message.strip().lower()
|
60 |
+
for cached_message, response in response_cache.items():
|
61 |
+
if cached_message.lower() == message:
|
62 |
+
return response
|
63 |
+
# Simplified prompt
|
64 |
full_prompt = (
|
65 |
+
"Financial advisor: Answer directly about stock market investments. "
|
66 |
+
"No specific company picks without data; suggest ETFs or general advice. "
|
67 |
+
f"User: {message}\nAssistant:"
|
68 |
+
)
|
69 |
+
inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=256).to(device)
|
70 |
+
with torch.no_grad(): # Disable gradient computation for faster inference
|
71 |
+
outputs = model.generate(
|
72 |
+
**inputs,
|
73 |
+
max_new_tokens=50, # Increased slightly for better responses
|
74 |
+
do_sample=False, # Greedy decoding for speed
|
75 |
+
num_beams=1, # Disable beam search for faster generation
|
76 |
+
pad_token_id=tokenizer.eos_token_id
|
77 |
+
)
|
78 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
79 |
return response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
|
80 |
except Exception as e:
|
|
|
83 |
# Create Gradio interface
|
84 |
interface = gr.ChatInterface(
|
85 |
fn=chat_with_model,
|
86 |
+
title="Financial Advisor Chatbot (DistilGPT-2)",
|
87 |
+
description="Ask for advice on starting to invest in the stock market! Powered by DistilGPT-2. Provides single, direct answers without conversation history.",
|
88 |
examples=[
|
89 |
"Hi, pretend you are a financial advisor. Now tell me how can I start investing in stock market?",
|
90 |
"You have a list of companies you recommend?"
|