liaoch commited on
Commit
e05250c
·
1 Parent(s): def30e8

add support for caching results

Browse files
Files changed (1) hide show
  1. app.py +156 -1
app.py CHANGED
@@ -2,6 +2,90 @@ import numpy as np
2
  import matplotlib.pyplot as plt
3
  import gradio as gr
4
  import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def run_simulation(
7
  initial_investment: float,
@@ -21,6 +105,20 @@ def run_simulation(
21
  num_swr_intervals: int,
22
  progress=gr.Progress()
23
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  swr_test_step = (max_swr_test - min_swr_test) / num_swr_intervals if num_swr_intervals > 0 else 0.1 # Calculate step
25
  progress(0, desc="Starting simulation...")
26
  start_time = time.time()
@@ -174,6 +272,9 @@ def run_simulation(
174
  ax2.set_xlabel("Year")
175
  ax2.set_ylabel("Portfolio Value ($)")
176
 
 
 
 
177
  return "Simulation Complete!", results_text, fig1, fig2
178
 
179
  # Gradio Interface
@@ -294,7 +395,7 @@ with gr.Blocks() as demo:
294
  max_swr_test = gr.Slider(minimum=0.5, maximum=10.0, value=5.0, step=0.1, label="Max SWR to Test (%)", interactive=True)
295
  num_swr_intervals = gr.Slider(minimum=5, maximum=100, value=15, step=1, label="Number of SWR Intervals", interactive=True)
296
 
297
- run_button = gr.Button("Run Simulation")
298
 
299
  with gr.Column():
300
  gr.Markdown("### Simulation Results")
@@ -331,4 +432,58 @@ with gr.Blocks() as demo:
331
  ],
332
  outputs=[status_output, results_output, swr_plot_output, paths_plot_output]
333
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  demo.launch()
 
2
  import matplotlib.pyplot as plt
3
  import gradio as gr
4
  import time
5
+ import os
6
+ import json
7
+ import hashlib
8
+ import pathlib
9
+ from PIL import Image # Added for loading images from cache
10
+
11
+ # --- Caching Setup ---
12
+ CACHE_DIR = pathlib.Path("cache")
13
+ CACHE_DIR.mkdir(exist_ok=True)
14
+
15
+ def _generate_cache_key(*args, **kwargs):
16
+ """Generates a unique hash for the given arguments."""
17
+ # Convert all arguments to a consistent string representation
18
+ # Exclude the 'progress' object from hashing as it's not part of the configuration
19
+ hash_input = []
20
+ for arg in args:
21
+ if not isinstance(arg, gr.Progress):
22
+ hash_input.append(str(arg))
23
+ for k, v in kwargs.items():
24
+ if k != 'progress':
25
+ hash_input.append(f"{k}={v}")
26
+
27
+ # Use a stable JSON representation for dictionary arguments if any
28
+ # For this specific function, all args are simple types, so direct string conversion is fine.
29
+
30
+ return hashlib.md5("".join(hash_input).encode('utf-8')).hexdigest()
31
+
32
+ def _save_to_cache(key, results_text, fig1, fig2):
33
+ """Saves simulation results and plots to the cache."""
34
+ cache_path = CACHE_DIR / key
35
+ cache_path.mkdir(exist_ok=True)
36
+
37
+ # Save plots as PNGs
38
+ fig1_path = cache_path / "fig1.png"
39
+ fig2_path = cache_path / "fig2.png"
40
+ fig1.savefig(fig1_path)
41
+ fig2.savefig(fig2_path)
42
+ plt.close(fig1) # Close figures to free memory
43
+ plt.close(fig2)
44
+
45
+ # Save metadata (results_text and plot paths)
46
+ metadata = {
47
+ "results_text": results_text,
48
+ "fig1_path": str(fig1_path),
49
+ "fig2_path": str(fig2_path)
50
+ }
51
+ with open(cache_path / "metadata.json", "w") as f:
52
+ json.dump(metadata, f)
53
+
54
+ def _load_from_cache(key):
55
+ """Loads simulation results from the cache."""
56
+ cache_path = CACHE_DIR / key
57
+ metadata_path = cache_path / "metadata.json"
58
+
59
+ if not metadata_path.exists():
60
+ return None
61
+
62
+ with open(metadata_path, "r") as f:
63
+ metadata = json.load(f)
64
+
65
+ # Load images from paths
66
+ try:
67
+ fig1_img = Image.open(metadata["fig1_path"])
68
+ fig2_img = Image.open(metadata["fig2_path"])
69
+ except FileNotFoundError:
70
+ print(f"Cached image files not found for key: {key}. Deleting cache entry.")
71
+ # Clean up incomplete cache entry
72
+ import shutil
73
+ shutil.rmtree(cache_path)
74
+ return None
75
+
76
+ # Convert PIL Image objects back to matplotlib figures for Gradio's gr.Plot
77
+ # This is a workaround as gr.Plot expects matplotlib figures, not PIL Images directly.
78
+ fig1_recreated = plt.figure()
79
+ ax1_recreated = fig1_recreated.add_subplot(111)
80
+ ax1_recreated.imshow(fig1_img)
81
+ ax1_recreated.axis('off') # Hide axes for image display
82
+
83
+ fig2_recreated = plt.figure()
84
+ ax2_recreated = fig2_recreated.add_subplot(111)
85
+ ax2_recreated.imshow(fig2_img)
86
+ ax2_recreated.axis('off') # Hide axes for image display
87
+
88
+ return "Loaded from cache!", metadata["results_text"], fig1_recreated, fig2_recreated
89
 
90
  def run_simulation(
91
  initial_investment: float,
 
105
  num_swr_intervals: int,
106
  progress=gr.Progress()
107
  ):
108
+ # Generate cache key from all input parameters except 'progress'
109
+ cache_key = _generate_cache_key(
110
+ initial_investment, num_years, num_simulations, target_success_rate,
111
+ stock_mean_return, stock_std_dev, bond_mean_return, bond_std_dev,
112
+ stock_allocation, correlation_stock_bond, mean_inflation,
113
+ std_dev_inflation, min_swr_test, max_swr_test, num_swr_intervals
114
+ )
115
+
116
+ # Check if results are in cache
117
+ cached_results = _load_from_cache(cache_key)
118
+ if cached_results:
119
+ progress(1, desc="Loading from cache...")
120
+ return cached_results
121
+
122
  swr_test_step = (max_swr_test - min_swr_test) / num_swr_intervals if num_swr_intervals > 0 else 0.1 # Calculate step
123
  progress(0, desc="Starting simulation...")
124
  start_time = time.time()
 
272
  ax2.set_xlabel("Year")
273
  ax2.set_ylabel("Portfolio Value ($)")
274
 
275
+ # Save results to cache before returning
276
+ _save_to_cache(cache_key, results_text, fig1, fig2)
277
+
278
  return "Simulation Complete!", results_text, fig1, fig2
279
 
280
  # Gradio Interface
 
395
  max_swr_test = gr.Slider(minimum=0.5, maximum=10.0, value=5.0, step=0.1, label="Max SWR to Test (%)", interactive=True)
396
  num_swr_intervals = gr.Slider(minimum=5, maximum=100, value=15, step=1, label="Number of SWR Intervals", interactive=True)
397
 
398
+ run_button = gr.Button("Run Simulation", variant="primary")
399
 
400
  with gr.Column():
401
  gr.Markdown("### Simulation Results")
 
432
  ],
433
  outputs=[status_output, results_output, swr_plot_output, paths_plot_output]
434
  )
435
+
436
+ # Define default parameters for initial load
437
+ DEFAULT_PARAMS = {
438
+ "initial_investment": 1_000_000.0,
439
+ "num_years": 30,
440
+ "num_simulations": 5000,
441
+ "target_success_rate": 95,
442
+ "stock_mean_return": 9.0,
443
+ "stock_std_dev": 15.0,
444
+ "bond_mean_return": 4.0,
445
+ "bond_std_dev": 5.0,
446
+ "stock_allocation": 60,
447
+ "correlation_stock_bond": -0.2,
448
+ "mean_inflation": 2.5,
449
+ "std_dev_inflation": 1.5,
450
+ "min_swr_test": 2.5,
451
+ "max_swr_test": 5.0,
452
+ "num_swr_intervals": 15
453
+ }
454
+
455
+ def load_default_results():
456
+ """Loads cached results for default parameters if available."""
457
+ # Ensure the order of parameters matches _generate_cache_key
458
+ default_key = _generate_cache_key(
459
+ DEFAULT_PARAMS["initial_investment"],
460
+ DEFAULT_PARAMS["num_years"],
461
+ DEFAULT_PARAMS["num_simulations"],
462
+ DEFAULT_PARAMS["target_success_rate"],
463
+ DEFAULT_PARAMS["stock_mean_return"],
464
+ DEFAULT_PARAMS["stock_std_dev"],
465
+ DEFAULT_PARAMS["bond_mean_return"],
466
+ DEFAULT_PARAMS["bond_std_dev"],
467
+ DEFAULT_PARAMS["stock_allocation"],
468
+ DEFAULT_PARAMS["correlation_stock_bond"],
469
+ DEFAULT_PARAMS["mean_inflation"],
470
+ DEFAULT_PARAMS["std_dev_inflation"],
471
+ DEFAULT_PARAMS["min_swr_test"],
472
+ DEFAULT_PARAMS["max_swr_test"],
473
+ DEFAULT_PARAMS["num_swr_intervals"]
474
+ )
475
+
476
+ cached = _load_from_cache(default_key)
477
+ if cached:
478
+ return cached
479
+ else:
480
+ # Return empty/placeholder values if no cache hit
481
+ return "No default cache found. Run simulation.", "", None, None
482
+
483
+ # Load default results on app startup
484
+ demo.load(
485
+ fn=load_default_results,
486
+ outputs=[status_output, results_output, swr_plot_output, paths_plot_output]
487
+ )
488
+
489
  demo.launch()