Update app.py
app.py
CHANGED
@@ -76,36 +76,45 @@ response_cache = {
         "4. Invest regularly using dollar-cost averaging.\n"
         "5. Diversify to manage risk.\n"
         "Consult a financial planner."
+    ),
+    "how to start investing": (
+        "Here’s how to start investing:\n"
+        "1. Study basics on Investopedia.\n"
+        "2. Open a brokerage account (e.g., Fidelity).\n"
+        "3. Deposit $100 or more after securing savings.\n"
+        "4. Buy an ETF like VOO after research.\n"
+        "5. Invest monthly with dollar-cost averaging.\n"
+        "Consult a financial planner."
     )
 }

 # Load model and tokenizer
-model_name = "distilgpt2"
+model_name = "distilgpt2"
 try:
     logger.info(f"Loading tokenizer for {model_name}")
     tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
     logger.info(f"Loading model {model_name}")
-
-
-
-
-
-
+    with torch.no_grad():
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True
+        ).to(device)
 except Exception as e:
     logger.error(f"Error loading model/tokenizer: {e}")
-    raise
+    raise RuntimeError(f"Failed to load model: {str(e)}")

 # Pre-tokenize prompt prefix
 prompt_prefix = (
-    "You are a financial advisor. Provide
-    "Avoid repetition
-    "Example: Q: Give investing tips\nA: 1. Open a brokerage
+    "You are a financial advisor. Provide numbered list advice for investing prompts. "
+    "Avoid repetition.\n\n"
+    "Example: Q: Give investing tips\nA: 1. Open a brokerage.\n2. Buy ETFs like VOO.\n3. Use dollar-cost averaging.\n\n"
     "Q: "
 )
 prefix_tokens = tokenizer(prompt_prefix, return_tensors="pt", truncation=True, max_length=512).to(device)

 # Fuzzy matching for cache
-def get_closest_cache_key(message, cache_keys, threshold=0.
+def get_closest_cache_key(message, cache_keys, threshold=0.85):
     matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold)
     return matches[0] if matches else None

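Review note on the new loading block: `torch_dtype=torch.float16` is a gamble on CPU (which this Space appears to use, given the `torch.cpu.amp.autocast()` call below), since several CPU kernels have no half-precision implementation, and the `with torch.no_grad():` wrapper adds nothing during `from_pretrained`. A minimal sketch of a device-conditional dtype, assuming `device` is a `torch.device` defined earlier in app.py:

import torch
from transformers import AutoModelForCausalLM

# Assumes `device` is a torch.device set earlier in app.py.
dtype = torch.float16 if device.type == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    "distilgpt2",
    torch_dtype=dtype,          # half precision only where it is well supported
    low_cpu_mem_usage=True,     # stream weights in to reduce peak RAM
).to(device)
model.eval()  # inference only: disables dropout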
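For reference, `difflib.get_close_matches` ranks candidates by `SequenceMatcher` ratio, so the new 0.85 cutoff absorbs small typos but not rephrasings; anything below it falls through to the model. A quick illustration with hypothetical cache keys:

import difflib

keys = ["give me investing tips", "how to start investing"]

print(difflib.get_close_matches("give me investing tipps", keys, n=1, cutoff=0.85))
# ['give me investing tips'] -- a one-letter typo keeps the ratio near 0.98
print(difflib.get_close_matches("recommend stocks", keys, n=1, cutoff=0.85))
# [] -- a rephrasing falls below the cutoff, so the model handles it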
@@ -130,11 +139,11 @@ def chat_with_model(message, history=None):
     full_prompt = prompt_prefix + message + "\nA:"
     inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

-    # Generate response
+    # Generate response
     with torch.cpu.amp.autocast(), torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=
+            max_new_tokens=60,
             min_length=15,
             do_sample=True,
             temperature=0.7,
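Two review notes on this hunk: the handler re-tokenizes `prompt_prefix + message` on every call even though `prefix_tokens` was pre-computed above, and `torch.cpu.amp.autocast()` is deprecated in recent PyTorch in favor of `torch.amp.autocast("cpu")`, which autocasts to bfloat16 rather than float16 on CPU. A sketch of both fixes together, assuming `prefix_tokens`, `tokenizer`, `model`, and `device` from app.py; `generate_reply` is a hypothetical helper, not the committed code:

import torch

def generate_reply(message):
    # Reuse the pre-tokenized prefix; only the user message is encoded per call.
    suffix = tokenizer(message + "\nA:", return_tensors="pt").to(device)
    inputs = {
        "input_ids": torch.cat(
            [prefix_tokens["input_ids"], suffix["input_ids"]], dim=1),
        "attention_mask": torch.cat(
            [prefix_tokens["attention_mask"], suffix["attention_mask"]], dim=1),
    }
    with torch.amp.autocast("cpu", dtype=torch.bfloat16), torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=60,
            min_length=15,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)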
@@ -161,8 +170,8 @@ interface = gr.ChatInterface(
         "Hi, give me step-by-step investing advice",
         "Give me few investing idea",
         "Give me investing tips",
-        "
-        "
+        "How to start investing",
+        "Do you have a list of companies you recommend?"
     ]
 )

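For context, the list above is the `examples=` argument of the `gr.ChatInterface(` call named in the hunk header. A hypothetical minimal wiring; the interface's other kwargs (title, description, and so on) are not visible in this diff:

import gradio as gr

interface = gr.ChatInterface(
    fn=chat_with_model,  # the handler shown in the previous hunk
    examples=[
        "Hi, give me step-by-step investing advice",
        "Give me few investing idea",
        "Give me investing tips",
        "How to start investing",
        "Do you have a list of companies you recommend?",
    ],
)
interface.launch()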