add support for caching results
Browse files
app.py
CHANGED
|
@@ -2,6 +2,90 @@ import numpy as np
|
|
| 2 |
import matplotlib.pyplot as plt
|
| 3 |
import gradio as gr
|
| 4 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
def run_simulation(
|
| 7 |
initial_investment: float,
|
|
@@ -21,6 +105,20 @@ def run_simulation(
|
|
| 21 |
num_swr_intervals: int,
|
| 22 |
progress=gr.Progress()
|
| 23 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
swr_test_step = (max_swr_test - min_swr_test) / num_swr_intervals if num_swr_intervals > 0 else 0.1 # Calculate step
|
| 25 |
progress(0, desc="Starting simulation...")
|
| 26 |
start_time = time.time()
|
|
@@ -174,6 +272,9 @@ def run_simulation(
|
|
| 174 |
ax2.set_xlabel("Year")
|
| 175 |
ax2.set_ylabel("Portfolio Value ($)")
|
| 176 |
|
|
|
|
|
|
|
|
|
|
| 177 |
return "Simulation Complete!", results_text, fig1, fig2
|
| 178 |
|
| 179 |
# Gradio Interface
|
|
@@ -294,7 +395,7 @@ with gr.Blocks() as demo:
|
|
| 294 |
max_swr_test = gr.Slider(minimum=0.5, maximum=10.0, value=5.0, step=0.1, label="Max SWR to Test (%)", interactive=True)
|
| 295 |
num_swr_intervals = gr.Slider(minimum=5, maximum=100, value=15, step=1, label="Number of SWR Intervals", interactive=True)
|
| 296 |
|
| 297 |
-
run_button = gr.Button("Run Simulation")
|
| 298 |
|
| 299 |
with gr.Column():
|
| 300 |
gr.Markdown("### Simulation Results")
|
|
@@ -331,4 +432,58 @@ with gr.Blocks() as demo:
|
|
| 331 |
],
|
| 332 |
outputs=[status_output, results_output, swr_plot_output, paths_plot_output]
|
| 333 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
demo.launch()
|
|
|
|
| 2 |
import matplotlib.pyplot as plt
|
| 3 |
import gradio as gr
|
| 4 |
import time
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import hashlib
|
| 8 |
+
import pathlib
|
| 9 |
+
from PIL import Image # Added for loading images from cache
|
| 10 |
+
|
| 11 |
+
# --- Caching Setup ---
|
| 12 |
+
CACHE_DIR = pathlib.Path("cache")
|
| 13 |
+
CACHE_DIR.mkdir(exist_ok=True)
|
| 14 |
+
|
| 15 |
+
def _generate_cache_key(*args, **kwargs):
|
| 16 |
+
"""Generates a unique hash for the given arguments."""
|
| 17 |
+
# Convert all arguments to a consistent string representation
|
| 18 |
+
# Exclude the 'progress' object from hashing as it's not part of the configuration
|
| 19 |
+
hash_input = []
|
| 20 |
+
for arg in args:
|
| 21 |
+
if not isinstance(arg, gr.Progress):
|
| 22 |
+
hash_input.append(str(arg))
|
| 23 |
+
for k, v in kwargs.items():
|
| 24 |
+
if k != 'progress':
|
| 25 |
+
hash_input.append(f"{k}={v}")
|
| 26 |
+
|
| 27 |
+
# Use a stable JSON representation for dictionary arguments if any
|
| 28 |
+
# For this specific function, all args are simple types, so direct string conversion is fine.
|
| 29 |
+
|
| 30 |
+
return hashlib.md5("".join(hash_input).encode('utf-8')).hexdigest()
|
| 31 |
+
|
| 32 |
+
def _save_to_cache(key, results_text, fig1, fig2):
|
| 33 |
+
"""Saves simulation results and plots to the cache."""
|
| 34 |
+
cache_path = CACHE_DIR / key
|
| 35 |
+
cache_path.mkdir(exist_ok=True)
|
| 36 |
+
|
| 37 |
+
# Save plots as PNGs
|
| 38 |
+
fig1_path = cache_path / "fig1.png"
|
| 39 |
+
fig2_path = cache_path / "fig2.png"
|
| 40 |
+
fig1.savefig(fig1_path)
|
| 41 |
+
fig2.savefig(fig2_path)
|
| 42 |
+
plt.close(fig1) # Close figures to free memory
|
| 43 |
+
plt.close(fig2)
|
| 44 |
+
|
| 45 |
+
# Save metadata (results_text and plot paths)
|
| 46 |
+
metadata = {
|
| 47 |
+
"results_text": results_text,
|
| 48 |
+
"fig1_path": str(fig1_path),
|
| 49 |
+
"fig2_path": str(fig2_path)
|
| 50 |
+
}
|
| 51 |
+
with open(cache_path / "metadata.json", "w") as f:
|
| 52 |
+
json.dump(metadata, f)
|
| 53 |
+
|
| 54 |
+
def _load_from_cache(key):
|
| 55 |
+
"""Loads simulation results from the cache."""
|
| 56 |
+
cache_path = CACHE_DIR / key
|
| 57 |
+
metadata_path = cache_path / "metadata.json"
|
| 58 |
+
|
| 59 |
+
if not metadata_path.exists():
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
with open(metadata_path, "r") as f:
|
| 63 |
+
metadata = json.load(f)
|
| 64 |
+
|
| 65 |
+
# Load images from paths
|
| 66 |
+
try:
|
| 67 |
+
fig1_img = Image.open(metadata["fig1_path"])
|
| 68 |
+
fig2_img = Image.open(metadata["fig2_path"])
|
| 69 |
+
except FileNotFoundError:
|
| 70 |
+
print(f"Cached image files not found for key: {key}. Deleting cache entry.")
|
| 71 |
+
# Clean up incomplete cache entry
|
| 72 |
+
import shutil
|
| 73 |
+
shutil.rmtree(cache_path)
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
# Convert PIL Image objects back to matplotlib figures for Gradio's gr.Plot
|
| 77 |
+
# This is a workaround as gr.Plot expects matplotlib figures, not PIL Images directly.
|
| 78 |
+
fig1_recreated = plt.figure()
|
| 79 |
+
ax1_recreated = fig1_recreated.add_subplot(111)
|
| 80 |
+
ax1_recreated.imshow(fig1_img)
|
| 81 |
+
ax1_recreated.axis('off') # Hide axes for image display
|
| 82 |
+
|
| 83 |
+
fig2_recreated = plt.figure()
|
| 84 |
+
ax2_recreated = fig2_recreated.add_subplot(111)
|
| 85 |
+
ax2_recreated.imshow(fig2_img)
|
| 86 |
+
ax2_recreated.axis('off') # Hide axes for image display
|
| 87 |
+
|
| 88 |
+
return "Loaded from cache!", metadata["results_text"], fig1_recreated, fig2_recreated
|
| 89 |
|
| 90 |
def run_simulation(
|
| 91 |
initial_investment: float,
|
|
|
|
| 105 |
num_swr_intervals: int,
|
| 106 |
progress=gr.Progress()
|
| 107 |
):
|
| 108 |
+
# Generate cache key from all input parameters except 'progress'
|
| 109 |
+
cache_key = _generate_cache_key(
|
| 110 |
+
initial_investment, num_years, num_simulations, target_success_rate,
|
| 111 |
+
stock_mean_return, stock_std_dev, bond_mean_return, bond_std_dev,
|
| 112 |
+
stock_allocation, correlation_stock_bond, mean_inflation,
|
| 113 |
+
std_dev_inflation, min_swr_test, max_swr_test, num_swr_intervals
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Check if results are in cache
|
| 117 |
+
cached_results = _load_from_cache(cache_key)
|
| 118 |
+
if cached_results:
|
| 119 |
+
progress(1, desc="Loading from cache...")
|
| 120 |
+
return cached_results
|
| 121 |
+
|
| 122 |
swr_test_step = (max_swr_test - min_swr_test) / num_swr_intervals if num_swr_intervals > 0 else 0.1 # Calculate step
|
| 123 |
progress(0, desc="Starting simulation...")
|
| 124 |
start_time = time.time()
|
|
|
|
| 272 |
ax2.set_xlabel("Year")
|
| 273 |
ax2.set_ylabel("Portfolio Value ($)")
|
| 274 |
|
| 275 |
+
# Save results to cache before returning
|
| 276 |
+
_save_to_cache(cache_key, results_text, fig1, fig2)
|
| 277 |
+
|
| 278 |
return "Simulation Complete!", results_text, fig1, fig2
|
| 279 |
|
| 280 |
# Gradio Interface
|
|
|
|
| 395 |
max_swr_test = gr.Slider(minimum=0.5, maximum=10.0, value=5.0, step=0.1, label="Max SWR to Test (%)", interactive=True)
|
| 396 |
num_swr_intervals = gr.Slider(minimum=5, maximum=100, value=15, step=1, label="Number of SWR Intervals", interactive=True)
|
| 397 |
|
| 398 |
+
run_button = gr.Button("Run Simulation", variant="primary")
|
| 399 |
|
| 400 |
with gr.Column():
|
| 401 |
gr.Markdown("### Simulation Results")
|
|
|
|
| 432 |
],
|
| 433 |
outputs=[status_output, results_output, swr_plot_output, paths_plot_output]
|
| 434 |
)
|
| 435 |
+
|
| 436 |
+
# Define default parameters for initial load
|
| 437 |
+
DEFAULT_PARAMS = {
|
| 438 |
+
"initial_investment": 1_000_000.0,
|
| 439 |
+
"num_years": 30,
|
| 440 |
+
"num_simulations": 5000,
|
| 441 |
+
"target_success_rate": 95,
|
| 442 |
+
"stock_mean_return": 9.0,
|
| 443 |
+
"stock_std_dev": 15.0,
|
| 444 |
+
"bond_mean_return": 4.0,
|
| 445 |
+
"bond_std_dev": 5.0,
|
| 446 |
+
"stock_allocation": 60,
|
| 447 |
+
"correlation_stock_bond": -0.2,
|
| 448 |
+
"mean_inflation": 2.5,
|
| 449 |
+
"std_dev_inflation": 1.5,
|
| 450 |
+
"min_swr_test": 2.5,
|
| 451 |
+
"max_swr_test": 5.0,
|
| 452 |
+
"num_swr_intervals": 15
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
def load_default_results():
|
| 456 |
+
"""Loads cached results for default parameters if available."""
|
| 457 |
+
# Ensure the order of parameters matches _generate_cache_key
|
| 458 |
+
default_key = _generate_cache_key(
|
| 459 |
+
DEFAULT_PARAMS["initial_investment"],
|
| 460 |
+
DEFAULT_PARAMS["num_years"],
|
| 461 |
+
DEFAULT_PARAMS["num_simulations"],
|
| 462 |
+
DEFAULT_PARAMS["target_success_rate"],
|
| 463 |
+
DEFAULT_PARAMS["stock_mean_return"],
|
| 464 |
+
DEFAULT_PARAMS["stock_std_dev"],
|
| 465 |
+
DEFAULT_PARAMS["bond_mean_return"],
|
| 466 |
+
DEFAULT_PARAMS["bond_std_dev"],
|
| 467 |
+
DEFAULT_PARAMS["stock_allocation"],
|
| 468 |
+
DEFAULT_PARAMS["correlation_stock_bond"],
|
| 469 |
+
DEFAULT_PARAMS["mean_inflation"],
|
| 470 |
+
DEFAULT_PARAMS["std_dev_inflation"],
|
| 471 |
+
DEFAULT_PARAMS["min_swr_test"],
|
| 472 |
+
DEFAULT_PARAMS["max_swr_test"],
|
| 473 |
+
DEFAULT_PARAMS["num_swr_intervals"]
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
cached = _load_from_cache(default_key)
|
| 477 |
+
if cached:
|
| 478 |
+
return cached
|
| 479 |
+
else:
|
| 480 |
+
# Return empty/placeholder values if no cache hit
|
| 481 |
+
return "No default cache found. Run simulation.", "", None, None
|
| 482 |
+
|
| 483 |
+
# Load default results on app startup
|
| 484 |
+
demo.load(
|
| 485 |
+
fn=load_default_results,
|
| 486 |
+
outputs=[status_output, results_output, swr_plot_output, paths_plot_output]
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
demo.launch()
|