import logging import os import time # Added for timing logs import torch import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer import difflib import json # Set up logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) # Define device (force CPU for Spaces free tier) device = torch.device("cpu") logger.info(f"Using device: {device}") # Expanded response cache with new entries response_cache = { "hi": "Hello! I'm FinChat, your financial advisor. How can I help with investing today?", "hello": "Hello! I'm FinChat, your financial advisor. How can I help with investing today?", "hey": "Hi there! Ready to discuss investment goals with FinChat?", "how can i start investing with $100 a month?": ( "Here’s a step-by-step guide to start investing with $100 a month:\n" "1. **Open a brokerage account** with a platform like Fidelity or Robinhood. They offer low fees and no minimums.\n" "2. **Deposit your $100 monthly**. You can set up automatic transfers from your bank.\n" "3. **Choose a low-cost ETF** like VOO, which tracks the S&P 500 for broad market exposure.\n" "4. **Set up automatic investments** to buy shares regularly, reducing the impact of market fluctuations.\n" "5. **Track your progress** every few months to stay on top of your investments.\n" "Consult a financial planner for personalized advice." ), "where can i open a brokerage account?": ( "You can open a brokerage account with platforms like Vanguard, Fidelity, Charles Schwab, or Robinhood. " "They are beginner-friendly and offer low fees. Choose one that fits your needs and sign up online." ), "start investing with 100 dollars a month": ( "Here’s how to start investing with $100 a month:\n" "1. **Open a brokerage account** with a platform like Fidelity or Robinhood.\n" "2. **Deposit $100 monthly** via automatic transfers.\n" "3. **Invest in a low-cost ETF** like VOO for diversification.\n" "4. **Use dollar-cost averaging** to invest regularly.\n" "5. **Monitor your investments** quarterly.\n" "Consult a financial planner for tailored advice." ), "best places to open a brokerage account": ( "The best places to open a brokerage account include Vanguard, Fidelity, Charles Schwab, and Robinhood. " "They offer low fees, no minimums, and user-friendly platforms for beginners." ), "what is dollar-cost averaging?": ( "Dollar-cost averaging is investing a fixed amount regularly (e.g., $100 monthly) in ETFs, " "reducing risk by spreading purchases over time." ), "how much should i invest?": ( "Invest what you can afford after expenses and an emergency fund. Start with $100-$500 monthly " "in ETFs like VOO using dollar-cost averaging. Consult a financial planner." ), } # Load persistent cache cache_file = "cache.json" try: if os.path.exists(cache_file): with open(cache_file, 'r') as f: response_cache.update(json.load(f)) logger.info("Loaded persistent cache from cache.json") except Exception as e: logger.warning(f"Failed to load cache.json: {e}") # Load model and tokenizer model_name = "distilgpt2" try: logger.info(f"Loading tokenizer for {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False) logger.info(f"Loading model {model_name}") with torch.inference_mode(): model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True ).to(device) except Exception as e: logger.error(f"Error loading model/tokenizer: {e}") raise RuntimeError(f"Failed to load model: {str(e)}") # Updated prompt prefix with better instructions and examples prompt_prefix = ( "You are FinChat, a financial advisor. Always provide clear, step-by-step answers to the user's exact question. " "Avoid vague or unrelated topics. Use a numbered list format where appropriate and explain each step.\n\n" "Example 1:\n" "Q: How can I start investing with $100 a month?\n" "A: Here’s a step-by point-by-step guide:\n" "1. Open a brokerage account with a platform like Fidelity or Robinhood. They offer low fees and no minimums.\n" "2. Deposit your $100 monthly. You can set up automatic transfers.\n" "3. Choose a low-cost ETF like VOO, which tracks the S&P 500.\n" "4. Set up automatic investments to buy shares regularly.\n" "5. Track your progress every few months.\n\n" "Example 2:\n" "Q: Where can I open a brokerage account?\n" "A: You can open an account with platforms like Vanguard, Fidelity, Charles Schwab, or Robinhood. " "They are beginner-friendly and have low fees.\n\n" "Q: " ) # Fuzzy matching for cache def get_closest_cache_key(message, cache_keys, threshold=0.7): matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold) return matches[0] if matches else None # Define chat function with optimized generation parameters def chat_with_model(user_input, history=None): try: start_time = time.time() # Start timing logger.info(f"Processing user input: {user_input}") cache_key = user_input.lower().strip() cache_keys = list(response_cache.keys()) closest_key = cache_key if cache_key in response_cache else get_closest_cache_key(cache_key, cache_keys) if closest_key: logger.info(f"Cache hit for: {closest_key}") response = response_cache[closest_key] logger.info(f"Chatbot response: {response}") history = history or [] history.append({"role": "user", "content": user_input}) history.append({"role": "assistant", "content": response}) end_time = time.time() logger.info(f"Response time: {end_time - start_time:.2f} seconds") return response, history if len(user_input.strip()) <= 5: logger.info("Short prompt, returning default response") response = "Hello! I'm FinChat, your financial advisor. Ask about investing!" logger.info(f"Chatbot response: {response}") history = history or [] history.append({"role": "user", "content": user_input}) history.append({"role": "assistant", "content": response}) end_time = time.time() logger.info(f"Response time: {end_time - start_time:.2f} seconds") return response, history full_prompt = prompt_prefix + user_input + "\nA:" inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512).to(device) with torch.inference_mode(): gen_start_time = time.time() # Start generation timing outputs = model.generate( **inputs, max_new_tokens=75, # Reduced for faster generation min_length=20, do_sample=False, # Use greedy decoding for speed repetition_penalty=1.2, pad_token_id=tokenizer.eos_token_id ) gen_end_time = time.time() logger.info(f"Generation time: {gen_end_time - gen_start_time:.2f} seconds") response = tokenizer.decode(outputs[0], skip_special_tokens=True) response = response[len(full_prompt):].strip() if response.startswith(full_prompt) else response logger.info(f"Chatbot response: {response}") response_cache[cache_key] = response logger.info("Cache miss, added to in-memory cache") history = history or [] history.append({"role": "user", "content": user_input}) history.append({"role": "assistant", "content": response}) torch.cuda.empty_cache() end_time = time.time() logger.info(f"Total response time: {end_time - start_time:.2f} seconds") return response, history except Exception as e: logger.error(f"Error generating response: {e}") response = f"Error: {str(e)}" logger.info(f"Chatbot response: {response}") history = history or [] history.append({"role": "user", "content": user_input}) history.append({"role": "assistant", "content": response}) return response, history # Create Gradio interface with gr.Blocks( title="FinChat: An LLM based on distilgpt2 model", css=".feedback {display: flex; gap: 10px; justify-content: center; margin-top: 10px;}" ) as interface: gr.Markdown( """ # FinChat: An LLM based on distilgpt2 model FinChat provides financial advice using the lightweight distilgpt2 model, optimized for fast, detailed responses. Ask about investing strategies, ETFs, stocks, or budgeting to get started! """ ) chatbot = gr.Chatbot(type="messages") msg = gr.Textbox(label="Your message") submit = gr.Button("Send") clear = gr.Button("Clear") def submit_message(user_input, history): response, updated_history = chat_with_model(user_input, history) return "", updated_history # Clear input, update chatbot submit.click( fn=submit_message, inputs=[msg, chatbot], outputs=[msg, chatbot] ) clear.click( fn=lambda: ("", []), # Clear input and chatbot outputs=[msg, chatbot] ) # Launch interface (conditional for Spaces) if __name__ == "__main__" and not os.getenv("HF_SPACE"): logger.info("Launching Gradio interface locally") try: interface.launch(share=False, debug=True) except Exception as e: logger.error(f"Error launching interface: {e}") raise else: logger.info("Running in Hugging Face Spaces, interface defined but not launched")