import os
import sys
import subprocess
import re
import multiprocessing
import tempfile
from collections.abc import Iterator
import gradio as gr
from huggingface_hub import hf_hub_download, login
# Install llama-cpp-python if not present
try:
    from llama_cpp import Llama
except ModuleNotFoundError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-cpp-python"])
    from llama_cpp import Llama
# Install yfinance if not present (for CAGR calculations)
try:
    import yfinance as yf
except ModuleNotFoundError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "yfinance"])
    import yfinance as yf
# Import pandas for handling DataFrame column structures
import pandas as pd
# Additional imports for visualization and file handling
try:
    import matplotlib.pyplot as plt
    from PIL import Image
    import io
except ModuleNotFoundError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib", "pillow"])
    import matplotlib.pyplot as plt
    from PIL import Image
    import io
# Additional imports for PEFT fine-tuning
try:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from trl import SFTTrainer
    from datasets import load_dataset
    import accelerate  # Ensures accelerator compatibility
except ModuleNotFoundError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "transformers", "peft", "trl", "datasets", "accelerate", "bitsandbytes"])
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from trl import SFTTrainer
    from datasets import load_dataset
MAX_MAX_NEW_TOKENS = 1024 # Further increased max cap to allow for complete responses
DEFAULT_MAX_NEW_TOKENS = 256 # Reduced default for faster responses
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "512")) # Reduced for speed
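# MAX_INPUT_TOKEN_LENGTH can be overridden per deployment, e.g. `MAX_INPUT_TOKEN_LENGTH=1024 python app.py`;
# if you raise it, keep it at or below the n_ctx value passed to Llama() further down, or raise n_ctx to match.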
DESCRIPTION = """\
# FinChat: Investing Q&A (CPU-Only, Ultra-Fast Optimization)
This application delivers an interactive chat interface powered by a highly efficient, small AI model adapted for addressing investing and finance inquiries through specialized prompt engineering. It ensures rapid, reasoned responses to user queries. Duplicate this Space for customization or queue-free deployment.
🔎 Model details are available at the [xai-org/grok-2](https://huggingface.co/xai-org/grok-2) repository; pre-trained and instruction-tuned for multimodal capabilities, further adapted here for finance with PEFT fine-tuning on financial Q&A data.
Running on CPU 🥶. Inference is heavily optimized to answer simple queries in under 10 seconds, with output limited to 256 tokens by default. For longer, more complete responses, increase 'Max New Tokens' in Advanced Settings. Brief delays may occur in free-tier environments due to shared resources, but typical generation speeds reach 20-40 tokens per second. CAGR calculations for stocks are computed from historical price data rather than generated by the model.
"""
LICENSE = """\
---
This application employs the [xai-org/grok-2](https://huggingface.co/xai-org/grok-2) model, governed by xAI's Terms of Use. Refer to the [model card](https://huggingface.co/xai-org/grok-2) and [Grok documentation](https://x.ai/docs/terms) for details.
"""
# Define paths
base_model_id = "xai-org/grok-2"
fine_tuned_model_path = "fine_tuned_grok.gguf"
quantized_model_path = "grok-2-finetuned.Q4_K_M.gguf"
lora_adapter_path = "lora_adapter" # Temporary for PEFT
# Hugging Face login (replace with your token or set as env var)
hf_token = os.getenv("HF_TOKEN") # Set this in your environment (e.g., Hugging Face Space secrets)
if hf_token:
    login(hf_token)
else:
    print("Warning: HF_TOKEN environment variable not set. Fine-tuning may fail due to lack of authentication for gated models. Please set HF_TOKEN in your Space secrets after accepting terms at https://huggingface.co/xai-org/grok-2.")
# One-time fine-tuning process if the fine-tuned GGUF does not exist
if not os.path.exists(quantized_model_path):
print("Attempting one-time PEFT fine-tuning on Finance-Alpaca dataset... This may take significant time on CPU.")
    try:
        # Load tokenizer and base model (bfloat16 for CPU memory efficiency; note: fine-tuning on CPU is slow)
        tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,  # bfloat16 halves memory vs. float32; assumes sufficient RAM for the model
            device_map="cpu"  # Enforce CPU
        )
        # Prepare dataset: Use Finance-Alpaca for financial Q&A examples (subset for efficiency)
        dataset = load_dataset("gbharti/finance-alpaca", split="train[0:500]")
        # Format function for Alpaca-style prompts adapted to Grok-2
        def formatting_func(example):
            text = f"user\n{example['instruction']}\n{example['input']}\nmodel\n{example['output']}\n"
            return {"text": text}
        dataset = dataset.map(formatting_func)
        # PEFT LoRA configuration (efficient for fine-tuning)
        lora_config = LoraConfig(
            r=8,  # LoRA rank
            lora_alpha=16,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # Key attention modules for Grok-2 (assuming similar to Gemma)
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = prepare_model_for_kbit_training(model)
        model = get_peft_model(model, lora_config)
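        # With r=8 on the four attention projections, only the small low-rank adapter matrices are trained
        # (typically well under 1% of total parameters); the base weights stay frozen.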
        # Training arguments (small batch size and a single epoch for CPU feasibility)
        training_args = TrainingArguments(
            output_dir=lora_adapter_path,
            num_train_epochs=1,  # Limited for speed
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            learning_rate=2e-4,
            fp16=False,  # fp16 is GPU-only; the base model is already loaded in bfloat16
            save_steps=100,
            logging_steps=10,
            optim="adamw_torch",
            report_to="none"  # No external logging
        )
        # SFT Trainer for supervised fine-tuning
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            dataset_text_field="text",
            max_seq_length=512,
            args=training_args
        )
        trainer.train()
        # Merge LoRA adapter with base model
        model = model.merge_and_unload()
        model.save_pretrained("merged_model")
        tokenizer.save_pretrained("merged_model")
        # Convert merged HF model to GGUF (requires cloning llama.cpp)
        subprocess.check_call(["git", "clone", "https://github.com/ggerganov/llama.cpp"])
        os.chdir("llama.cpp")
        subprocess.check_call(["make"])
        subprocess.check_call([sys.executable, "convert_hf_to_gguf.py", "--outfile", "../" + fine_tuned_model_path, "--outtype", "f16", "../merged_model"])
        # Quantize to Q4_K_M
        subprocess.check_call(["./quantize", "../" + fine_tuned_model_path, "../" + quantized_model_path, "Q4_K_M"])
        os.chdir("..")
        print("Fine-tuning and conversion complete. Using fine-tuned model.")
    except Exception as e:
        print(f"Error during fine-tuning: {str(e)}")
        print("""
To resolve gated model access issues:
1. Log in to your Hugging Face account.
2. Visit https://huggingface.co/xai-org/grok-2 and accept the license terms.
3. Go to your Hugging Face settings > Access Tokens and create a new read token.
4. In this Hugging Face Space, go to Settings > Secrets and add a new secret named 'HF_TOKEN' with the token value.
5. Restart the Space.
Falling back to the original non-fine-tuned model.
""")
# Download or use fine-tuned GGUF model
try:
    model_path = quantized_model_path if os.path.exists(quantized_model_path) else hf_hub_download(repo_id="xai-org/grok-2-GGUF", filename="grok-2.Q4_K_M.gguf")
except Exception as e:
    print(f"Error downloading GGUF model: {str(e)}. Falling back to an alternative GGUF model.")
    # Fallback to Gemma GGUF as alternative
    model_path = hf_hub_download(repo_id="mradermacher/gemma-3n-E4B-it-GGUF", filename="gemma-3n-E4B-it.Q4_K_M.gguf")
# Load the model with optimizations and a chat format for Grok-2 (assumed similar to Gemma; adjust if needed)
llm = Llama(
    model_path=model_path,
    n_ctx=512,  # Further reduced for faster prompt evaluation
    n_batch=1024,  # Increased batch size for faster generation
    n_threads=multiprocessing.cpu_count(),  # Use all available CPU threads for faster processing
    n_gpu_layers=0,  # Enforce CPU-only execution
    chat_format="gemma"  # Use built-in chat format (assuming compatibility; change if Grok-2 has a specific format)
)
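# With chat_format="gemma", llama-cpp-python wraps each turn roughly as
#   <start_of_turn>user\n...<end_of_turn>\n<start_of_turn>model\n...
# (a sketch of the built-in Gemma template; see llama_cpp.llama_chat_format for the exact behavior).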
DEFAULT_SYSTEM_PROMPT = """You are FinChat, a knowledgeable AI assistant specializing in investing and finance. Provide accurate, helpful, reasoned, detailed, and comprehensive answers to investing questions. Always base responses on reliable information and advise users to consult professionals for personalized advice.
Always respond exclusively in English. Do not use any other language in your responses.
Limit responses to under 100 words and use bullet points extensively for clarity.
Example:
User: average return for TSLA between 2010 and 2020
Assistant: The compound annual growth rate (CAGR) for TSLA stock from 2010 to 2020 is approximately 63.01%. This represents the average annual return over that period, accounting for compounding. Note that past performance is not indicative of future results, and I recommend consulting a financial advisor for personalized advice."""
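# Note: generate() below injects this system prompt as the first "user" turn, since the Gemma-style
# chat template used by chat_format="gemma" has no dedicated system role.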
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    lower_message = message.lower().strip()
    if lower_message in ["hi", "hello"]:
        response = "I'm FinChat, your financial advisor. Ask me anything finance-related!"
        yield response
        return
    # Dynamically adjust the token budget based on query length (heuristic: ~3x query words, capped at MAX_MAX_NEW_TOKENS)
    estimated_tokens = len(message.split()) * 3  # Multiplier allows for more complete responses
    dynamic_max_tokens = min(max(estimated_tokens, max_new_tokens), MAX_MAX_NEW_TOKENS)  # Honor the user's Max New Tokens setting as the floor
    # Check for CAGR/average return queries, supporting multiple tickers
    match = re.match(r'(?:average return|cagr) for ([\w\s,]+(?:and [\w\s,]+)?) between (\d{4}) and (\d{4})', lower_message)
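    # Example queries this branch handles (illustrative; note the message is lowercased first):
    #   "average return for aapl between 2010 and 2020"
    #   "cagr for aapl, msft and tsla between 2015 and 2024"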
    if match:
        tickers_str, start_year, end_year = match.groups()
        # Parse tickers: split by comma or "and", strip whitespace
        tickers = [t.strip().upper() for t in re.split(r',|\band\b', tickers_str) if t.strip()]
        responses = []
        if int(end_year) <= int(start_year):
            yield "The specified time period is invalid (end year must be after start year)."
            return
        for ticker in tickers:
            try:
                data = yf.download(ticker, start=f"{start_year}-01-01", end=f"{end_year}-12-31")
                if not data.empty:
                    # Handle multi-index columns from recent yfinance versions
                    if isinstance(data.columns, pd.MultiIndex):
                        data = data.droplevel('Ticker', axis=1)
                    initial = data['Close'].iloc[0]
                    final = data['Close'].iloc[-1]
                    start_date = data.index[0]
                    end_date = data.index[-1]
                    days = (end_date - start_date).days
                    years = days / 365.25
                    if years > 0:
                        cagr = ((final / initial) ** (1 / years) - 1) * 100
                        responses.append(f"{ticker}: approximately {cagr:.2f}%")
                    else:
                        responses.append(f"{ticker}: Invalid period (no elapsed time).")
                else:
                    responses.append(f"{ticker}: No historical data available between {start_year} and {end_year}.")
            except Exception as e:
                responses.append(f"{ticker}: Error calculating CAGR - {str(e)}")
        full_response = f"The compound annual growth rates (CAGR) for the requested stocks from {start_year} to {end_year} are:\n"
        full_response += "\n".join(responses)
        full_response += "\nThese represent the average annual returns over that period, accounting for compounding. Note that past performance is not indicative of future results, and I recommend consulting a financial advisor for personalized advice."
        yield full_response
        return
    # Build conversation messages
    conversation = []
    if system_prompt:
        conversation.append({"role": "user", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})  # OpenAI-style role expected by create_chat_completion
    conversation.append({"role": "user", "content": message})
    # Approximate token length check (join contents)
    prompt_text = "\n".join(d["content"] for d in conversation)
    input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
    if len(input_tokens) > MAX_INPUT_TOKEN_LENGTH:
        yield "Error: Input too long. Please shorten your query."
        return
    # Use create_chat_completion for clean output without special tokens
    response = ""
    stream = llm.create_chat_completion(
        messages=conversation,
        max_tokens=dynamic_max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repeat_penalty=repetition_penalty,
        stream=True
    )
    for chunk in stream:
        delta = chunk["choices"][0]["delta"]
        if "content" in delta and delta["content"] is not None:
            response += delta["content"]
            yield response
        if chunk["choices"][0]["finish_reason"] is not None:
            break
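# Example (sketch): each value yielded by generate() is the cumulative response so far, which is what
# the Gradio bot() callback below re-assigns to the last chat turn on every streamed update:
#   for partial in generate("What is CAGR?", chat_history=[]):
#       ...  # partial grows as tokens stream in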
def process_portfolio(ticker1, shares1, cost1, price1, ticker2, shares2, cost2, price2, ticker3, shares3, cost3, price3, growth_rate):
    portfolio = {}
    if ticker1:
        value1 = shares1 * price1
        portfolio[ticker1.upper()] = {'shares': shares1, 'cost': cost1, 'price': price1, 'value': value1}
    if ticker2:
        value2 = shares2 * price2
        portfolio[ticker2.upper()] = {'shares': shares2, 'cost': cost2, 'price': price2, 'value': value2}
    if ticker3:
        value3 = shares3 * price3
        portfolio[ticker3.upper()] = {'shares': shares3, 'cost': cost3, 'price': price3, 'value': value3}
    if not portfolio:
        return "No portfolio data provided.", None
    total_value_now = sum(v['value'] for v in portfolio.values())
    allocations = {k: v['value'] / total_value_now for k, v in portfolio.items()} if total_value_now > 0 else {}
    # Generate pie chart for allocation
    fig_alloc, ax_alloc = plt.subplots()
    ax_alloc.pie(allocations.values(), labels=allocations.keys(), autopct='%1.1f%%')
    ax_alloc.set_title('Portfolio Allocation')
    buf_alloc = io.BytesIO()
    fig_alloc.savefig(buf_alloc, format='png')
    buf_alloc.seek(0)
    chart_alloc = Image.open(buf_alloc)
    plt.close(fig_alloc)  # Release the figure so repeated calls don't accumulate open matplotlib figures
    # Projections: compound growth, value * (1 + rate/100) ** years
    def project_value(value, years, rate):
        return value * (1 + rate / 100) ** years
    total_value_1yr = sum(project_value(v['value'], 1, growth_rate) for v in portfolio.values())
    total_value_5yr = sum(project_value(v['value'], 5, growth_rate) for v in portfolio.values())
    total_value_10yr = sum(project_value(v['value'], 10, growth_rate) for v in portfolio.values())
    data_str = "User portfolio:\n" + "\n".join(
        f"- {k}: {v['shares']} shares, avg cost {v['cost']}, current price {v['price']}, value {v['value']:.2f}"
        for k, v in portfolio.items()
    ) + f"\nTotal value now: {total_value_now:.2f}\nProjected (at {growth_rate}% annual growth):\n- 1 year: {total_value_1yr:.2f}\n- 5 years: {total_value_5yr:.2f}\n- 10 years: {total_value_10yr:.2f}"
    return data_str, chart_alloc
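# Example (sketch): process_portfolio("AAPL", 10, 150.0, 200.0, "", 0, 0, 0, "", 0, 0, 0, 10)
# returns a text summary ("User portfolio: ... Total value now: 2000.00 ... 10 years: 5187.48", since
# 2000 * 1.10 ** 10 ≈ 5187.48) plus a PIL image of the allocation pie chart.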
# Gradio interface setup
with gr.Blocks(css="""#chatbot {height: 500px;}""") as demo:
    gr.Markdown(DESCRIPTION)
    chatbot = gr.Chatbot(label="FinChat", elem_id="chatbot")  # elem_id matches the #chatbot CSS rule above
    msg = gr.Textbox(label="Ask a finance question", placeholder="e.g., 'What is CAGR?' or 'Average return for AAPL between 2010 and 2020'")
    with gr.Row():
        # Default numeric fields to 0 so process_portfolio never multiplies None values
        with gr.Column():
            ticker1 = gr.Textbox(label="Ticker 1")
            shares1 = gr.Number(label="Shares 1", value=0)
            cost1 = gr.Number(label="Avg Cost/Share 1", value=0)
            price1 = gr.Number(label="Current Price 1", value=0)
        with gr.Column():
            ticker2 = gr.Textbox(label="Ticker 2")
            shares2 = gr.Number(label="Shares 2", value=0)
            cost2 = gr.Number(label="Avg Cost/Share 2", value=0)
            price2 = gr.Number(label="Current Price 2", value=0)
        with gr.Column():
            ticker3 = gr.Textbox(label="Ticker 3")
            shares3 = gr.Number(label="Shares 3", value=0)
            cost3 = gr.Number(label="Avg Cost/Share 3", value=0)
            price3 = gr.Number(label="Current Price 3", value=0)
    growth_rate = gr.Slider(minimum=5, maximum=50, step=5, value=10, label="Annual Growth Rate (%)")
    with gr.Row():
        submit = gr.Button("Submit")
        clear = gr.Button("Clear")
with gr.Accordion("Advanced Settings", open=False):
system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6)
temperature = gr.Slider(label="Temperature", value=0.6, minimum=0.0, maximum=1.0, step=0.05)
top_p = gr.Slider(label="Top P", value=0.9, minimum=0.0, maximum=1.0, step=0.05)
top_k = gr.Slider(label="Top K", value=50, minimum=1, maximum=100, step=1)
repetition_penalty = gr.Slider(label="Repetition Penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05)
max_new_tokens = gr.Slider(label="Max New Tokens", value=DEFAULT_MAX_NEW_TOKENS, minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1)
gr.Markdown(LICENSE)
    def user(message, history):
        return "", history + [[message, None]]
    def bot(history, sys_prompt, temp, tp, tk, rp, mnt, ticker1, shares1, cost1, price1, ticker2, shares2, cost2, price2, ticker3, shares3, cost3, price3, growth_rate):
        message = history[-1][0]
        portfolio_data, chart_alloc = process_portfolio(ticker1, shares1, cost1, price1, ticker2, shares2, cost2, price2, ticker3, shares3, cost3, price3, growth_rate)
        if chart_alloc is not None:
            # Only append portfolio context to the prompt when portfolio data was actually provided
            message += "\n" + portfolio_data
        history[-1][1] = ""
        for new_text in generate(message, history[:-1], sys_prompt, mnt, temp, tp, tk, rp):
            history[-1][1] = new_text
            yield history
        if chart_alloc is not None:
            # The tuple-format Chatbot renders images from file paths, so save the chart and append it as a new turn
            chart_path = os.path.join(tempfile.gettempdir(), "portfolio_allocation.png")
            chart_alloc.save(chart_path)
            history.append([None, (chart_path,)])
            yield history
    submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, system_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, ticker1, shares1, cost1, price1, ticker2, shares2, cost2, price2, ticker3, shares3, cost3, price3, growth_rate], chatbot
    )
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, system_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, ticker1, shares1, cost1, price1, ticker2, shares2, cost2, price2, ticker3, shares3, cost3, price3, growth_rate], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
demo.queue(max_size=128).launch()