qgallouedec HF staff commited on
Commit
f8eee5a
·
verified ·
1 Parent(s): db805e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -18
app.py CHANGED
@@ -1,47 +1,64 @@
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
 
 
 
4
 
5
  def plot_forecast(num_param, batch_size, precision, seq_len):
6
  # Convert number (input as B)
7
  num_param = float(num_param) * 1e9
8
-
9
  # Convert precision to bytes
10
  precision = {"float32": 4, "float16": 2, "bfloat16": 2}[precision]
11
 
12
  # Model Parameters: N×precision
13
- y1 = num_param * precision / (1024**3)
14
-
15
- # Activations: B×Sequence Length×K×precision
16
- K = 4.6894e-04 * num_param + 1.8494e06
17
- y2 = batch_size * seq_len * K * precision / (1024**3)
18
 
19
  # Optimizer States: 2×N×precision
20
- y3 = 2 * num_param * precision / (1024**3)
 
 
 
 
 
21
 
22
  # Gradients: N×precision
23
- y4 = num_param * 4 / (1024**3)
 
 
 
 
 
 
24
 
25
  fig = plt.figure(figsize=(4, 4))
26
  ax = fig.add_subplot(111)
27
 
28
  # Create stacked bars
29
- ax.bar(0, y1, color="r")
30
- ax.bar(0, y2, bottom=y1, color="b")
31
- ax.bar(0, y3, bottom=y1 + y2, color="g")
32
- ax.bar(0, y4, bottom=y1 + y2 + y3, color="y")
 
 
33
 
34
  # Add text labels inside the bars
35
- ax.text(0, y1 / 2, "Model Parameters", ha="center", va="center", color="white", fontweight="bold")
36
- ax.text(0, y1 + y2 / 2,"Activations", ha="center", va="center", color="white", fontweight="bold")
37
- ax.text(0, y1 + y2 + y3 / 2, "Optimizer States", ha="center", va="center", color="white", fontweight="bold")
38
- ax.text(0, y1 + y2 + y3 + y4 / 2, "Gradients", ha="center", va="center", color="white", fontweight="bold")
 
39
 
40
- # remove x axis
 
 
 
41
  ax.xaxis.set_visible(False)
42
 
43
  # Set GB as the unit for the y-axis
44
  ax.set_ylabel("Memory (GB)")
 
 
45
  fig.tight_layout()
46
  return fig
47
 
@@ -52,7 +69,7 @@ demo = gr.Interface(
52
  gr.Number(7, label="Number of parameters (B)"),
53
  gr.Radio([1, 2, 4, 8, 16, 32, 64, 128], value=8, label="Batch size"),
54
  gr.Radio(["float32", "float16", "bfloat16"], value="float32", label="Precision"),
55
- gr.Slider(1, 1024, label="Sequence Length", step=1, value=128),
56
  ],
57
  gr.Plot(label="forecast", format="png"),
58
  )
 
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
 
4
+ import matplotlib.pyplot as plt
5
+
6
 
7
  def plot_forecast(num_param, batch_size, precision, seq_len):
8
  # Convert number (input as B)
9
  num_param = float(num_param) * 1e9
10
+
11
  # Convert precision to bytes
12
  precision = {"float32": 4, "float16": 2, "bfloat16": 2}[precision]
13
 
14
  # Model Parameters: N×precision
15
+ y1 = num_param * precision / (1000**3)
 
 
 
 
16
 
17
  # Optimizer States: 2×N×precision
18
+ y2 = 2 * num_param * precision / (1000**3)
19
+
20
+ # Activations: B×Sequence Length×K×precision
21
+ K = 4.6894e-4 * num_param + 1.8494e6
22
+ print(K)
23
+ y3 = batch_size * seq_len * K * precision / (1000**3)
24
 
25
  # Gradients: N×precision
26
+ y4 = num_param * precision / (1000**3)
27
+
28
+ # Optimizer intermediates: N×precision
29
+ y5 = num_param * precision / (1000**3)
30
+
31
+ # Calculate total memory
32
+ total_memory = y1 + y2 + max(y3, y4 + y5)
33
 
34
  fig = plt.figure(figsize=(4, 4))
35
  ax = fig.add_subplot(111)
36
 
37
  # Create stacked bars
38
+ bar_width = 0.5
39
+ ax.bar(0, y1, width=bar_width, color="r")
40
+ ax.bar(0, y2, bottom=y1, width=bar_width, color="b")
41
+ ax.bar(-bar_width / 4, y3, bottom=y1 + y2, width=bar_width / 2, color="g")
42
+ ax.bar(bar_width / 4, y4, bottom=y1 + y2, width=bar_width / 2, color="y")
43
+ ax.bar(bar_width / 4, y5, bottom=y1 + y2 + y4, width=bar_width / 2, color="c")
44
 
45
  # Add text labels inside the bars
46
+ ax.text(0, y1 / 2, f"Model Parameters ({y1:.1f} GB)", ha="center", va="center", color="white", fontweight="bold")
47
+ ax.text(0, y1 + y2 / 2, f"Optimizer States ({y2:.1f} GB)", ha="center", va="center", color="white", fontweight="bold")
48
+ ax.text(-bar_width / 4, y1 + y2 + y3 / 2, f"Activations\n({y3:.1f} GB)", ha="center", va="center", color="white", fontweight="bold")
49
+ ax.text(bar_width / 4, y1 + y2 + y4 / 2, f"Gradients\n({y4:.1f} GB)", ha="center", va="center", color="white", fontweight="bold")
50
+ ax.text(bar_width / 4, y1 + y2 + y4 + y5 / 2, f"Optimizer\nintermediates\n({y5:.1f} GB)", ha="center", va="center", color="white", fontweight="bold")
51
 
52
+ # Or as title
53
+ ax.set_title(f"Total Memory: {total_memory:.1f} GB", fontweight="bold")
54
+
55
+ # Remove x-axis
56
  ax.xaxis.set_visible(False)
57
 
58
  # Set GB as the unit for the y-axis
59
  ax.set_ylabel("Memory (GB)")
60
+
61
+ # Adjust layout
62
  fig.tight_layout()
63
  return fig
64
 
 
69
  gr.Number(7, label="Number of parameters (B)"),
70
  gr.Radio([1, 2, 4, 8, 16, 32, 64, 128], value=8, label="Batch size"),
71
  gr.Radio(["float32", "float16", "bfloat16"], value="float32", label="Precision"),
72
+ gr.Slider(1, 1000, label="Sequence Length", step=1, value=128),
73
  ],
74
  gr.Plot(label="forecast", format="png"),
75
  )