FinChat / app.py
AnilNiraula's picture
Update app.py
5c97638 verified
raw
history blame
7.78 kB
import logging
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import difflib
# --- Logging -----------------------------------------------------------------
# Root logger at INFO with timestamped "time - LEVEL - message" lines.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Compute device ----------------------------------------------------------
# Always the CPU: the Spaces free tier this app targets has no GPU.
device = torch.device("cpu")
logger.info(f"Using device: {device}")
# Response cache with expanded entries.
# Canned replies served without running the model. chat_with_model looks up
# the user's message after .lower().strip(), first as an exact key and then
# fuzzily via get_closest_cache_key, so every key here must be lower-case.
response_cache = {
    # Greetings
    "hi": "Hello! I'm your financial advisor. How can I help with investing?",
    "hello": "Hello! I'm your financial advisor. How can I help with investing?",
    "hey": "Hi there! Ready to discuss investment goals?",
    # Step-by-step / getting-started guides
    "hi, give me step-by-step investing advice": (
        "Here’s a step-by-step guide to start investing:\n"
        "1. Open a brokerage account (e.g., Fidelity, Vanguard) if 18 or older.\n"
        "2. Deposit an affordable amount, like $100, after an emergency fund.\n"
        "3. Research and buy an ETF (e.g., VOO) using Yahoo Finance.\n"
        "4. Monitor monthly and enable dividend reinvesting.\n"
        "5. Use dollar-cost averaging ($100 monthly) to reduce risk.\n"
        "6. Diversify across sectors.\n"
        "Consult a financial planner."
    ),
    "hi, pretend you are a financial advisor. now tell me how can i start investing in stock market?": (
        "Here’s a guide to start investing:\n"
        "1. Learn from Investopedia or 'The Intelligent Investor.'\n"
        "2. Set goals (e.g., retirement) and assess risk.\n"
        "3. Choose a brokerage (Fidelity, Vanguard).\n"
        "4. Start with ETFs (e.g., VOO) or mutual funds.\n"
        "5. Use dollar-cost averaging ($100-$500 monthly).\n"
        "6. Diversify and monitor.\n"
        "Consult a financial planner."
    ),
    # Single-topic questions
    "do you have a list of companies you recommend?": (
        "I can’t recommend specific companies without data. Try ETFs like VOO (S&P 500) or QQQ (tech). "
        "Research technology (e.g., Apple) or healthcare (e.g., Johnson & Johnson) on Yahoo Finance. "
        "Consult a financial planner."
    ),
    "how do i start investing in stocks?": (
        "Learn from Investopedia. Set goals and assess risk. Open a brokerage account (Fidelity, Vanguard) "
        "and start with ETFs (e.g., VOO). Consult a financial planner."
    ),
    "what's the difference between stocks and bonds?": (
        "Stocks are company ownership with high risk and growth potential. Bonds are loans to companies/governments "
        "with lower risk and steady interest. Diversify for balance."
    ),
    "how much should i invest?": (
        "Invest what you can afford after expenses and an emergency fund. Start with $100-$500 monthly "
        "in ETFs (e.g., VOO) using dollar-cost averaging. Consult a financial planner."
    ),
    "what is dollar-cost averaging?": (
        "Dollar-cost averaging is investing a fixed amount regularly (e.g., $100 monthly) in ETFs, "
        "reducing risk by spreading purchases over time."
    ),
    # Idea / tip lists. The key text mirrors the exact example prompt shown in
    # the UI, so it is matched exactly rather than corrected for grammar.
    "give me few investing idea": (
        "Here are investing ideas:\n"
        "1. Open a brokerage account (e.g., Fidelity) if 18 or older.\n"
        "2. Deposit $100 or what you can afford.\n"
        "3. Buy a researched ETF (e.g., VOO) or index fund.\n"
        "4. Check regularly and enable dividend reinvesting.\n"
        "5. Use dollar-cost averaging (e.g., monthly buys).\n"
        "Consult a financial planner."
    ),
    "give me investing tips": (
        "Here are investing tips:\n"
        "1. Educate yourself with Investopedia or books.\n"
        "2. Open a brokerage account (e.g., Vanguard).\n"
        "3. Start small with ETFs like VOO.\n"
        "4. Invest regularly using dollar-cost averaging.\n"
        "5. Diversify to manage risk.\n"
        "Consult a financial planner."
    )
}
# Load model and tokenizer.
# distilgpt2 (~82M parameters) fits comfortably in full float32 on a CPU.
# The earlier settings could not work on this CPU-only deployment:
#   * load_in_4bit relies on bitsandbytes, which requires a CUDA GPU;
#   * .to(device) is rejected on bitsandbytes-quantized models;
#   * many CPU kernels have no float16 implementation.
# low_cpu_mem_usage was dropped too: it requires the `accelerate` package and
# buys nothing for a model this small.
model_name = "distilgpt2"  # Smaller model for CPU
try:
    logger.info(f"Loading tokenizer for {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
    logger.info(f"Loading model {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
    ).to(device)
except Exception as e:
    # Log with traceback, then re-raise: the app is useless without a model.
    logger.exception(f"Error loading model/tokenizer: {e}")
    raise
# Few-shot instruction header placed before every user question. The caller
# appends "<question>\nA:" so the model continues with an answer.
prompt_prefix = "".join(
    (
        "You are a financial advisor. Provide concise, numbered list advice for investing prompts. ",
        "Avoid repetition and vague statements.\n\n",
        "Example: Q: Give investing tips\n",
        "A: 1. Open a brokerage account.\n2. Start with ETFs like VOO.\n3. Use dollar-cost averaging.\n\n",
        "Q: ",
    )
)

# Pre-tokenized prefix, capped at 512 tokens and moved to the compute device.
# NOTE(review): nothing in the code visible here reads prefix_tokens — confirm
# it is still needed before relying on it.
prefix_tokens = tokenizer(
    prompt_prefix,
    return_tensors="pt",
    truncation=True,
    max_length=512,
).to(device)
# Fuzzy matching for cache
def get_closest_cache_key(message, cache_keys, threshold=0.9):
    """Return the entry of *cache_keys* most similar to *message*, or None.

    Similarity is measured by ``difflib.get_close_matches``; a candidate must
    score at least *threshold* (between 0 and 1) to be returned.
    """
    candidates = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold)
    return next(iter(candidates), None)
# Define chat function
def chat_with_model(message, history=None):
    """Produce the assistant's reply to one user message.

    Resolution order:
      1. Exact, then fuzzy (``get_closest_cache_key``), hit in
         ``response_cache`` on the lower-cased, stripped message.
      2. A fixed greeting for prompts of 5 characters or fewer.
      3. Sampled generation from ``model`` continuing ``prompt_prefix``.

    Args:
        message: Latest user message; None is treated as an empty string.
        history: Prior turns passed by gr.ChatInterface; currently unused.

    Returns:
        The reply text. Any exception is logged and returned as a string
        starting with "Error: " so the chat UI shows it instead of failing.
    """
    try:
        logger.info(f"Processing message: {message}")
        message = message or ""

        # Normalize and check cache: exact key first, then fuzzy match.
        cache_key = message.lower().strip()
        if cache_key in response_cache:
            closest_key = cache_key
        else:
            closest_key = get_closest_cache_key(cache_key, list(response_cache))
        if closest_key:
            logger.info(f"Cache hit for: {closest_key}")
            return response_cache[closest_key]

        # Skip model for short prompts
        if len(message.strip()) <= 5:
            logger.info("Short prompt, returning default response")
            return "Hello! I'm your financial advisor. Ask about investing!"

        # Construct prompt. When it is too long, drop tokens from the FRONT:
        # the tokenizer's default right-truncation would cut off the trailing
        # "\nA:" cue that tells the model to start answering.
        max_prompt_tokens = 512
        full_prompt = prompt_prefix + message + "\nA:"
        encoded = tokenizer(full_prompt, return_tensors="pt")
        input_ids = encoded["input_ids"][:, -max_prompt_tokens:].to(device)
        attention_mask = encoded["attention_mask"][:, -max_prompt_tokens:].to(device)
        prompt_len = input_ids.shape[1]

        with torch.inference_mode():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=80,
                # min_new_tokens counts generated tokens only. The previous
                # min_length=15 included the (much longer) prompt, so it had
                # no effect.
                min_new_tokens=15,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,
            )
        logger.info("Generated response")

        # Decode only the newly generated tokens so the prompt, including the
        # instruction header, is never echoed back. Then cut where the model
        # starts inventing the next few-shot "Q:" turn.
        answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        answer = answer.split("\nQ:", 1)[0].strip()
        return answer or "I'm not sure how to answer that. Could you rephrase your investing question?"
    except Exception as e:
        logger.exception(f"Error generating response: {e}")
        return f"Error: {str(e)}"
# --- Gradio UI ---------------------------------------------------------------
# Chat UI backed by chat_with_model. After lower-casing, each example prompt
# below is an exact key in response_cache, so clicking one returns a canned
# reply without running the model.
logger.info("Initializing Gradio interface")
interface = gr.ChatInterface(
    chat_with_model,
    examples=[
        "Hi",
        "Hi, give me step-by-step investing advice",
        "Give me few investing idea",
        "Give me investing tips",
        "Do you have a list of companies you recommend?",
        "What's the difference between stocks and bonds?",
    ],
    description="Ask about investing! Fast, detailed answers on CPU.",
    title="Financial Advisor Chatbot (DistilGPT2)",
)
# Launch the interface when this file is run as a script. Hugging Face Spaces
# (Gradio SDK) also executes app.py directly and expects it to call launch(),
# so the app must launch there too. The old `HF_SPACE` gate would have kept
# it from ever starting if that variable were set; Spaces itself sets SPACE_ID.
if __name__ == "__main__":
    running_on_spaces = bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
    logger.info(
        "Launching Gradio interface %s",
        "on Hugging Face Spaces" if running_on_spaces else "locally",
    )
    try:
        # No `queue=` keyword: launch() does not accept one and raises
        # TypeError. Queueing is configured via interface.queue() if needed.
        interface.launch(share=False, debug=True)
    except Exception as e:
        logger.exception(f"Error launching interface: {e}")
        raise
else:
    logger.info("Module imported; interface defined but not launched")