feat(muon): add tuned-abc-values & bfloat16 communication
Browse files- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +158 -50
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +158 -50
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} +2 -2
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +158 -50
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} +2 -2
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +158 -50
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +158 -50
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +158 -50
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +158 -50
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +158 -50
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +158 -50
- torch-ext/optimizer/muon.py +158 -50
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_ee6ed44_dirty
|
| 3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1787376
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:55f17ad6ecdd22d84ea5b776a317fa9fbb6b81f622fa8fc80b78e0ef80bd4ea6
|
| 3 |
size 1787376
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
|
|
| 15 |
@torch.no_grad()
|
| 16 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 17 |
"""
|
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 25 |
"""
|
| 26 |
assert len(G.shape) == 2
|
| 27 |
-
|
| 28 |
X = G # no manual typecast
|
|
|
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
# Ensure spectral norm is at most 1
|
| 32 |
X = X / (X.norm() + 1e-7)
|
| 33 |
-
X = X.bfloat16()
|
| 34 |
# Perform the NS iterations
|
| 35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
A = X @ X.T
|
| 37 |
# B = (
|
| 38 |
# b * A + c * A @ A
|
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 43 |
|
| 44 |
if G.size(0) > G.size(1):
|
| 45 |
X = X.T
|
| 46 |
-
return X
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 65 |
Gather the gradients to worker_rank.
|
| 66 |
If none_grad is True, free p.grad after the gather.
|
| 67 |
"""
|
| 68 |
-
g = p.grad
|
| 69 |
-
|
| 70 |
-
if rank == state.worker_rank:
|
| 71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
| 72 |
-
gather_list = [
|
| 73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 74 |
-
]
|
| 75 |
-
else:
|
| 76 |
-
gather_list = None
|
| 77 |
-
|
| 78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
torch.distributed.gather(
|
| 80 |
g.to_local(),
|
| 81 |
dst=state.worker_rank,
|
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 92 |
else:
|
| 93 |
state.gathered_grad = None
|
| 94 |
state.gather_event = None
|
|
|
|
| 95 |
if none_grad:
|
| 96 |
# We can safely free p.grad without calling record_stream:
|
| 97 |
# p.grad.to_local().record_stream(comm_stream)
|
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
| 108 |
"""
|
| 109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 110 |
"""
|
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 115 |
compute_stream.wait_event(state.gather_event)
|
| 116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 117 |
state.computed_u = u
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
@torch.no_grad()
|
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if rank == state.worker_rank:
|
| 133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 134 |
-
if state.compute_event is None:
|
| 135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
| 136 |
-
comm_stream.wait_event(state.compute_event)
|
| 137 |
-
|
| 138 |
# Clear the gathered gradient to free memory
|
| 139 |
state.gathered_grad = None
|
| 140 |
|
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 144 |
else:
|
| 145 |
scatter_list = None
|
| 146 |
|
| 147 |
-
u_received = torch.empty_like(p.to_local())
|
| 148 |
torch.distributed.scatter(
|
| 149 |
-
|
| 150 |
scatter_list=scatter_list,
|
| 151 |
src=state.worker_rank,
|
| 152 |
group=state.process_group,
|
| 153 |
)
|
| 154 |
-
u_dtensor = DTensor.from_local(
|
| 155 |
-
u_received,
|
| 156 |
-
placements=p.placements,
|
| 157 |
-
device_mesh=p.device_mesh,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
state.scattered_u = u_dtensor
|
| 161 |
state.scatter_event = torch.cuda.Event()
|
| 162 |
state.scatter_event.record()
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
| 172 |
if state.scatter_event is None:
|
| 173 |
raise RuntimeError("Scatter event must be set before update")
|
| 174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if rank == state.worker_rank:
|
| 176 |
# Free computed_u
|
| 177 |
state.computed_u = None
|
| 178 |
|
| 179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def default_is_muon(name, x):
|
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 375 |
else:
|
| 376 |
g = buf
|
| 377 |
|
| 378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
|
| 379 |
|
| 380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 433 |
def enqueue_computes(start_idx, chunk_size):
|
| 434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 435 |
state = param_to_state[id(p)]
|
| 436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
| 437 |
self.compute_stream)
|
| 438 |
|
| 439 |
def enqueue_scatters(start_idx, chunk_size):
|
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
| 466 |
# Wait the last update_param to finish
|
| 467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
def step(self, closure=None):
|
| 470 |
"""Perform a single optimization step.
|
| 471 |
|
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 542 |
# AdamW backup #
|
| 543 |
############################
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
lr = group["lr"]
|
| 546 |
beta1, beta2 = group["adamw_betas"]
|
| 547 |
eps = group["adamw_eps"]
|
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
| 552 |
if g is None:
|
| 553 |
continue
|
| 554 |
state = self.state[p]
|
|
|
|
|
|
|
| 555 |
if "step" not in state:
|
| 556 |
-
state["step"] =
|
|
|
|
|
|
|
| 557 |
state["moment1"] = torch.zeros_like(g)
|
| 558 |
state["moment2"] = torch.zeros_like(g)
|
| 559 |
-
state["
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
return loss
|
|
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Union, cast
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.distributed as dist
|
|
|
|
| 13 |
|
| 14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
| 16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
|
|
|
| 27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
+
assert G.dtype == torch.bfloat16
|
| 31 |
X = G # no manual typecast
|
| 32 |
+
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
+
for a, b, c in [
|
| 39 |
+
(4.0848, -6.8946, 2.9270),
|
| 40 |
+
(3.9505, -6.3029, 2.6377),
|
| 41 |
+
(3.7418, -5.5913, 2.3037),
|
| 42 |
+
(2.8769, -3.1427, 1.2046),
|
| 43 |
+
(2.8366, -3.0525, 1.2012),
|
| 44 |
+
]:
|
| 45 |
A = X @ X.T
|
| 46 |
# B = (
|
| 47 |
# b * A + c * A @ A
|
|
|
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
| 55 |
+
return X
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 74 |
Gather the gradients to worker_rank.
|
| 75 |
If none_grad is True, free p.grad after the gather.
|
| 76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.cuda.stream(comm_stream):
|
| 78 |
+
g = p.grad
|
| 79 |
+
|
| 80 |
+
if rank == state.worker_rank:
|
| 81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 82 |
+
gather_list = [
|
| 83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
| 84 |
+
for _ in range(num_ranks)
|
| 85 |
+
]
|
| 86 |
+
else:
|
| 87 |
+
gather_list = None
|
| 88 |
+
|
| 89 |
+
g = g.to(torch.bfloat16)
|
| 90 |
torch.distributed.gather(
|
| 91 |
g.to_local(),
|
| 92 |
dst=state.worker_rank,
|
|
|
|
| 103 |
else:
|
| 104 |
state.gathered_grad = None
|
| 105 |
state.gather_event = None
|
| 106 |
+
gather_list = None
|
| 107 |
if none_grad:
|
| 108 |
# We can safely free p.grad without calling record_stream:
|
| 109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
@torch.no_grad()
|
| 119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
| 120 |
"""
|
| 121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 122 |
"""
|
|
|
|
| 127 |
compute_stream.wait_event(state.gather_event)
|
| 128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 129 |
state.computed_u = u
|
| 130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 131 |
+
dtype=torch.bfloat16)
|
| 132 |
+
state.compute_event = torch.cuda.Event()
|
| 133 |
+
state.compute_event.record()
|
| 134 |
+
u = None
|
| 135 |
|
| 136 |
|
| 137 |
@torch.no_grad()
|
|
|
|
| 141 |
"""
|
| 142 |
|
| 143 |
with torch.cuda.stream(comm_stream):
|
| 144 |
+
if state.compute_event is None:
|
| 145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
| 146 |
+
comm_stream.wait_event(state.compute_event)
|
| 147 |
+
|
| 148 |
if rank == state.worker_rank:
|
| 149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Clear the gathered gradient to free memory
|
| 151 |
state.gathered_grad = None
|
| 152 |
|
|
|
|
| 156 |
else:
|
| 157 |
scatter_list = None
|
| 158 |
|
|
|
|
| 159 |
torch.distributed.scatter(
|
| 160 |
+
state.scattered_u,
|
| 161 |
scatter_list=scatter_list,
|
| 162 |
src=state.worker_rank,
|
| 163 |
group=state.process_group,
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
state.scatter_event = torch.cuda.Event()
|
| 166 |
state.scatter_event.record()
|
| 167 |
+
scatter_list = None
|
| 168 |
|
| 169 |
|
| 170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 177 |
if state.scatter_event is None:
|
| 178 |
raise RuntimeError("Scatter event must be set before update")
|
| 179 |
compute_stream.wait_event(state.scatter_event)
|
| 180 |
+
u_dtensor = DTensor.from_local(
|
| 181 |
+
state.scattered_u,
|
| 182 |
+
placements=p.placements,
|
| 183 |
+
device_mesh=p.device_mesh,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
state.scattered_u = u_dtensor
|
| 187 |
+
|
| 188 |
if rank == state.worker_rank:
|
| 189 |
# Free computed_u
|
| 190 |
state.computed_u = None
|
| 191 |
|
| 192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
| 193 |
+
state.scattered_u = None
|
| 194 |
+
u_dtensor = None
|
| 195 |
|
| 196 |
|
| 197 |
def default_is_muon(name, x):
|
|
|
|
| 390 |
else:
|
| 391 |
g = buf
|
| 392 |
|
| 393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
| 394 |
+
steps=group["ns_steps"])
|
| 395 |
|
| 396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 449 |
def enqueue_computes(start_idx, chunk_size):
|
| 450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 451 |
state = param_to_state[id(p)]
|
| 452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 453 |
self.compute_stream)
|
| 454 |
|
| 455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
|
| 482 |
# Wait the last update_param to finish
|
| 483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 484 |
|
| 485 |
+
@staticmethod
|
| 486 |
+
def _fused_adamw(
|
| 487 |
+
params: list[torch.Tensor],
|
| 488 |
+
grads: list[torch.Tensor],
|
| 489 |
+
exp_avgs: list[torch.Tensor],
|
| 490 |
+
exp_avg_sqs: list[torch.Tensor],
|
| 491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
| 492 |
+
state_steps: list[torch.Tensor],
|
| 493 |
+
amsgrad: bool,
|
| 494 |
+
beta1: float,
|
| 495 |
+
beta2: float,
|
| 496 |
+
lr: Union[float, torch.Tensor],
|
| 497 |
+
weight_decay: float,
|
| 498 |
+
eps: float,
|
| 499 |
+
maximize: bool,
|
| 500 |
+
) -> None:
|
| 501 |
+
if not params:
|
| 502 |
+
return
|
| 503 |
+
|
| 504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
| 505 |
+
# treating it as a scalar.
|
| 506 |
+
lr_dict: Optional[DeviceDict] = ({
|
| 507 |
+
lr.device: lr
|
| 508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
| 509 |
+
None)
|
| 510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
| 511 |
+
[
|
| 512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
| 513 |
+
state_steps
|
| 514 |
+
] # type: ignore[list-item]
|
| 515 |
+
)
|
| 516 |
+
for (device, _), (
|
| 517 |
+
(
|
| 518 |
+
device_params_,
|
| 519 |
+
device_grads_,
|
| 520 |
+
device_exp_avgs_,
|
| 521 |
+
device_exp_avg_sqs_,
|
| 522 |
+
device_max_exp_avg_sqs,
|
| 523 |
+
device_state_steps_,
|
| 524 |
+
),
|
| 525 |
+
_,
|
| 526 |
+
) in grouped_tensors.items():
|
| 527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
| 528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
| 529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
| 530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
| 531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
| 532 |
+
|
| 533 |
+
if lr_dict is not None and device not in lr_dict:
|
| 534 |
+
lr_dict[device] = lr.to(
|
| 535 |
+
device=device,
|
| 536 |
+
non_blocking=True) # type: ignore[union-attr]
|
| 537 |
+
lr = lr_dict[device]
|
| 538 |
+
torch._foreach_add_(device_state_steps, 1)
|
| 539 |
+
func = torch._fused_adamw_
|
| 540 |
+
func(
|
| 541 |
+
device_params,
|
| 542 |
+
device_grads,
|
| 543 |
+
device_exp_avgs,
|
| 544 |
+
device_exp_avg_sqs,
|
| 545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
| 546 |
+
device_state_steps,
|
| 547 |
+
amsgrad=amsgrad,
|
| 548 |
+
lr=lr, # type: ignore[arg-type]
|
| 549 |
+
beta1=beta1,
|
| 550 |
+
beta2=beta2,
|
| 551 |
+
weight_decay=weight_decay,
|
| 552 |
+
eps=eps,
|
| 553 |
+
maximize=maximize,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
def step(self, closure=None):
|
| 557 |
"""Perform a single optimization step.
|
| 558 |
|
|
|
|
| 629 |
# AdamW backup #
|
| 630 |
############################
|
| 631 |
|
| 632 |
+
params_with_grads = []
|
| 633 |
+
grads = []
|
| 634 |
+
moment1 = []
|
| 635 |
+
moment2 = []
|
| 636 |
+
max_exp_avg_sqs = []
|
| 637 |
+
state_steps = []
|
| 638 |
lr = group["lr"]
|
| 639 |
beta1, beta2 = group["adamw_betas"]
|
| 640 |
eps = group["adamw_eps"]
|
|
|
|
| 645 |
if g is None:
|
| 646 |
continue
|
| 647 |
state = self.state[p]
|
| 648 |
+
params_with_grads.append(p)
|
| 649 |
+
grads.append(g)
|
| 650 |
if "step" not in state:
|
| 651 |
+
state["step"] = (torch.zeros((),
|
| 652 |
+
dtype=torch.float32,
|
| 653 |
+
device=p.device))
|
| 654 |
state["moment1"] = torch.zeros_like(g)
|
| 655 |
state["moment2"] = torch.zeros_like(g)
|
| 656 |
+
moment1.append(state["moment1"])
|
| 657 |
+
moment2.append(state["moment2"])
|
| 658 |
+
if not isinstance(state["step"], torch.Tensor):
|
| 659 |
+
step_tensor = torch.tensor(state["step"],
|
| 660 |
+
dtype=torch.float32,
|
| 661 |
+
device=p.device)
|
| 662 |
+
else:
|
| 663 |
+
step_tensor = state["step"]
|
| 664 |
+
state_steps.append(step_tensor)
|
| 665 |
+
|
| 666 |
+
self._fused_adamw(
|
| 667 |
+
params_with_grads,
|
| 668 |
+
grads,
|
| 669 |
+
moment1,
|
| 670 |
+
moment2,
|
| 671 |
+
max_exp_avg_sqs,
|
| 672 |
+
state_steps,
|
| 673 |
+
amsgrad=False,
|
| 674 |
+
beta1=beta1,
|
| 675 |
+
beta2=beta2,
|
| 676 |
+
lr=lr,
|
| 677 |
+
weight_decay=weight_decay,
|
| 678 |
+
eps=eps,
|
| 679 |
+
maximize=False,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
return loss
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_ee6ed44_dirty
|
| 3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1824264
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f37c80a535a081e997c1973902a010c48b33ca40085a7f267a5278e56cff26f3
|
| 3 |
size 1824264
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
|
|
| 15 |
@torch.no_grad()
|
| 16 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 17 |
"""
|
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 25 |
"""
|
| 26 |
assert len(G.shape) == 2
|
| 27 |
-
|
| 28 |
X = G # no manual typecast
|
|
|
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
# Ensure spectral norm is at most 1
|
| 32 |
X = X / (X.norm() + 1e-7)
|
| 33 |
-
X = X.bfloat16()
|
| 34 |
# Perform the NS iterations
|
| 35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
A = X @ X.T
|
| 37 |
# B = (
|
| 38 |
# b * A + c * A @ A
|
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 43 |
|
| 44 |
if G.size(0) > G.size(1):
|
| 45 |
X = X.T
|
| 46 |
-
return X
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 65 |
Gather the gradients to worker_rank.
|
| 66 |
If none_grad is True, free p.grad after the gather.
|
| 67 |
"""
|
| 68 |
-
g = p.grad
|
| 69 |
-
|
| 70 |
-
if rank == state.worker_rank:
|
| 71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
| 72 |
-
gather_list = [
|
| 73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 74 |
-
]
|
| 75 |
-
else:
|
| 76 |
-
gather_list = None
|
| 77 |
-
|
| 78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
torch.distributed.gather(
|
| 80 |
g.to_local(),
|
| 81 |
dst=state.worker_rank,
|
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 92 |
else:
|
| 93 |
state.gathered_grad = None
|
| 94 |
state.gather_event = None
|
|
|
|
| 95 |
if none_grad:
|
| 96 |
# We can safely free p.grad without calling record_stream:
|
| 97 |
# p.grad.to_local().record_stream(comm_stream)
|
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
| 108 |
"""
|
| 109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 110 |
"""
|
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 115 |
compute_stream.wait_event(state.gather_event)
|
| 116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 117 |
state.computed_u = u
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
@torch.no_grad()
|
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if rank == state.worker_rank:
|
| 133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 134 |
-
if state.compute_event is None:
|
| 135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
| 136 |
-
comm_stream.wait_event(state.compute_event)
|
| 137 |
-
|
| 138 |
# Clear the gathered gradient to free memory
|
| 139 |
state.gathered_grad = None
|
| 140 |
|
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 144 |
else:
|
| 145 |
scatter_list = None
|
| 146 |
|
| 147 |
-
u_received = torch.empty_like(p.to_local())
|
| 148 |
torch.distributed.scatter(
|
| 149 |
-
|
| 150 |
scatter_list=scatter_list,
|
| 151 |
src=state.worker_rank,
|
| 152 |
group=state.process_group,
|
| 153 |
)
|
| 154 |
-
u_dtensor = DTensor.from_local(
|
| 155 |
-
u_received,
|
| 156 |
-
placements=p.placements,
|
| 157 |
-
device_mesh=p.device_mesh,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
state.scattered_u = u_dtensor
|
| 161 |
state.scatter_event = torch.cuda.Event()
|
| 162 |
state.scatter_event.record()
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
| 172 |
if state.scatter_event is None:
|
| 173 |
raise RuntimeError("Scatter event must be set before update")
|
| 174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if rank == state.worker_rank:
|
| 176 |
# Free computed_u
|
| 177 |
state.computed_u = None
|
| 178 |
|
| 179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def default_is_muon(name, x):
|
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 375 |
else:
|
| 376 |
g = buf
|
| 377 |
|
| 378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
|
| 379 |
|
| 380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 433 |
def enqueue_computes(start_idx, chunk_size):
|
| 434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 435 |
state = param_to_state[id(p)]
|
| 436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
| 437 |
self.compute_stream)
|
| 438 |
|
| 439 |
def enqueue_scatters(start_idx, chunk_size):
|
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
| 466 |
# Wait the last update_param to finish
|
| 467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
def step(self, closure=None):
|
| 470 |
"""Perform a single optimization step.
|
| 471 |
|
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 542 |
# AdamW backup #
|
| 543 |
############################
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
lr = group["lr"]
|
| 546 |
beta1, beta2 = group["adamw_betas"]
|
| 547 |
eps = group["adamw_eps"]
|
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
| 552 |
if g is None:
|
| 553 |
continue
|
| 554 |
state = self.state[p]
|
|
|
|
|
|
|
| 555 |
if "step" not in state:
|
| 556 |
-
state["step"] =
|
|
|
|
|
|
|
| 557 |
state["moment1"] = torch.zeros_like(g)
|
| 558 |
state["moment2"] = torch.zeros_like(g)
|
| 559 |
-
state["
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
return loss
|
|
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Union, cast
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.distributed as dist
|
|
|
|
| 13 |
|
| 14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
| 16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
|
|
|
| 27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
+
assert G.dtype == torch.bfloat16
|
| 31 |
X = G # no manual typecast
|
| 32 |
+
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
+
for a, b, c in [
|
| 39 |
+
(4.0848, -6.8946, 2.9270),
|
| 40 |
+
(3.9505, -6.3029, 2.6377),
|
| 41 |
+
(3.7418, -5.5913, 2.3037),
|
| 42 |
+
(2.8769, -3.1427, 1.2046),
|
| 43 |
+
(2.8366, -3.0525, 1.2012),
|
| 44 |
+
]:
|
| 45 |
A = X @ X.T
|
| 46 |
# B = (
|
| 47 |
# b * A + c * A @ A
|
|
|
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
| 55 |
+
return X
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 74 |
Gather the gradients to worker_rank.
|
| 75 |
If none_grad is True, free p.grad after the gather.
|
| 76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.cuda.stream(comm_stream):
|
| 78 |
+
g = p.grad
|
| 79 |
+
|
| 80 |
+
if rank == state.worker_rank:
|
| 81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 82 |
+
gather_list = [
|
| 83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
| 84 |
+
for _ in range(num_ranks)
|
| 85 |
+
]
|
| 86 |
+
else:
|
| 87 |
+
gather_list = None
|
| 88 |
+
|
| 89 |
+
g = g.to(torch.bfloat16)
|
| 90 |
torch.distributed.gather(
|
| 91 |
g.to_local(),
|
| 92 |
dst=state.worker_rank,
|
|
|
|
| 103 |
else:
|
| 104 |
state.gathered_grad = None
|
| 105 |
state.gather_event = None
|
| 106 |
+
gather_list = None
|
| 107 |
if none_grad:
|
| 108 |
# We can safely free p.grad without calling record_stream:
|
| 109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
@torch.no_grad()
|
| 119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
| 120 |
"""
|
| 121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 122 |
"""
|
|
|
|
| 127 |
compute_stream.wait_event(state.gather_event)
|
| 128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 129 |
state.computed_u = u
|
| 130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 131 |
+
dtype=torch.bfloat16)
|
| 132 |
+
state.compute_event = torch.cuda.Event()
|
| 133 |
+
state.compute_event.record()
|
| 134 |
+
u = None
|
| 135 |
|
| 136 |
|
| 137 |
@torch.no_grad()
|
|
|
|
| 141 |
"""
|
| 142 |
|
| 143 |
with torch.cuda.stream(comm_stream):
|
| 144 |
+
if state.compute_event is None:
|
| 145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
| 146 |
+
comm_stream.wait_event(state.compute_event)
|
| 147 |
+
|
| 148 |
if rank == state.worker_rank:
|
| 149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Clear the gathered gradient to free memory
|
| 151 |
state.gathered_grad = None
|
| 152 |
|
|
|
|
| 156 |
else:
|
| 157 |
scatter_list = None
|
| 158 |
|
|
|
|
| 159 |
torch.distributed.scatter(
|
| 160 |
+
state.scattered_u,
|
| 161 |
scatter_list=scatter_list,
|
| 162 |
src=state.worker_rank,
|
| 163 |
group=state.process_group,
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
state.scatter_event = torch.cuda.Event()
|
| 166 |
state.scatter_event.record()
|
| 167 |
+
scatter_list = None
|
| 168 |
|
| 169 |
|
| 170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 177 |
if state.scatter_event is None:
|
| 178 |
raise RuntimeError("Scatter event must be set before update")
|
| 179 |
compute_stream.wait_event(state.scatter_event)
|
| 180 |
+
u_dtensor = DTensor.from_local(
|
| 181 |
+
state.scattered_u,
|
| 182 |
+
placements=p.placements,
|
| 183 |
+
device_mesh=p.device_mesh,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
state.scattered_u = u_dtensor
|
| 187 |
+
|
| 188 |
if rank == state.worker_rank:
|
| 189 |
# Free computed_u
|
| 190 |
state.computed_u = None
|
| 191 |
|
| 192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
| 193 |
+
state.scattered_u = None
|
| 194 |
+
u_dtensor = None
|
| 195 |
|
| 196 |
|
| 197 |
def default_is_muon(name, x):
|
|
|
|
| 390 |
else:
|
| 391 |
g = buf
|
| 392 |
|
| 393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
| 394 |
+
steps=group["ns_steps"])
|
| 395 |
|
| 396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 449 |
def enqueue_computes(start_idx, chunk_size):
|
| 450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 451 |
state = param_to_state[id(p)]
|
| 452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 453 |
self.compute_stream)
|
| 454 |
|
| 455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
|
| 482 |
# Wait the last update_param to finish
|
| 483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 484 |
|
| 485 |
+
@staticmethod
|
| 486 |
+
def _fused_adamw(
|
| 487 |
+
params: list[torch.Tensor],
|
| 488 |
+
grads: list[torch.Tensor],
|
| 489 |
+
exp_avgs: list[torch.Tensor],
|
| 490 |
+
exp_avg_sqs: list[torch.Tensor],
|
| 491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
| 492 |
+
state_steps: list[torch.Tensor],
|
| 493 |
+
amsgrad: bool,
|
| 494 |
+
beta1: float,
|
| 495 |
+
beta2: float,
|
| 496 |
+
lr: Union[float, torch.Tensor],
|
| 497 |
+
weight_decay: float,
|
| 498 |
+
eps: float,
|
| 499 |
+
maximize: bool,
|
| 500 |
+
) -> None:
|
| 501 |
+
if not params:
|
| 502 |
+
return
|
| 503 |
+
|
| 504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
| 505 |
+
# treating it as a scalar.
|
| 506 |
+
lr_dict: Optional[DeviceDict] = ({
|
| 507 |
+
lr.device: lr
|
| 508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
| 509 |
+
None)
|
| 510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
| 511 |
+
[
|
| 512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
| 513 |
+
state_steps
|
| 514 |
+
] # type: ignore[list-item]
|
| 515 |
+
)
|
| 516 |
+
for (device, _), (
|
| 517 |
+
(
|
| 518 |
+
device_params_,
|
| 519 |
+
device_grads_,
|
| 520 |
+
device_exp_avgs_,
|
| 521 |
+
device_exp_avg_sqs_,
|
| 522 |
+
device_max_exp_avg_sqs,
|
| 523 |
+
device_state_steps_,
|
| 524 |
+
),
|
| 525 |
+
_,
|
| 526 |
+
) in grouped_tensors.items():
|
| 527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
| 528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
| 529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
| 530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
| 531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
| 532 |
+
|
| 533 |
+
if lr_dict is not None and device not in lr_dict:
|
| 534 |
+
lr_dict[device] = lr.to(
|
| 535 |
+
device=device,
|
| 536 |
+
non_blocking=True) # type: ignore[union-attr]
|
| 537 |
+
lr = lr_dict[device]
|
| 538 |
+
torch._foreach_add_(device_state_steps, 1)
|
| 539 |
+
func = torch._fused_adamw_
|
| 540 |
+
func(
|
| 541 |
+
device_params,
|
| 542 |
+
device_grads,
|
| 543 |
+
device_exp_avgs,
|
| 544 |
+
device_exp_avg_sqs,
|
| 545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
| 546 |
+
device_state_steps,
|
| 547 |
+
amsgrad=amsgrad,
|
| 548 |
+
lr=lr, # type: ignore[arg-type]
|
| 549 |
+
beta1=beta1,
|
| 550 |
+
beta2=beta2,
|
| 551 |
+
weight_decay=weight_decay,
|
| 552 |
+
eps=eps,
|
| 553 |
+
maximize=maximize,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
def step(self, closure=None):
|
| 557 |
"""Perform a single optimization step.
|
| 558 |
|
|
|
|
| 629 |
# AdamW backup #
|
| 630 |
############################
|
| 631 |
|
| 632 |
+
params_with_grads = []
|
| 633 |
+
grads = []
|
| 634 |
+
moment1 = []
|
| 635 |
+
moment2 = []
|
| 636 |
+
max_exp_avg_sqs = []
|
| 637 |
+
state_steps = []
|
| 638 |
lr = group["lr"]
|
| 639 |
beta1, beta2 = group["adamw_betas"]
|
| 640 |
eps = group["adamw_eps"]
|
|
|
|
| 645 |
if g is None:
|
| 646 |
continue
|
| 647 |
state = self.state[p]
|
| 648 |
+
params_with_grads.append(p)
|
| 649 |
+
grads.append(g)
|
| 650 |
if "step" not in state:
|
| 651 |
+
state["step"] = (torch.zeros((),
|
| 652 |
+
dtype=torch.float32,
|
| 653 |
+
device=p.device))
|
| 654 |
state["moment1"] = torch.zeros_like(g)
|
| 655 |
state["moment2"] = torch.zeros_like(g)
|
| 656 |
+
moment1.append(state["moment1"])
|
| 657 |
+
moment2.append(state["moment2"])
|
| 658 |
+
if not isinstance(state["step"], torch.Tensor):
|
| 659 |
+
step_tensor = torch.tensor(state["step"],
|
| 660 |
+
dtype=torch.float32,
|
| 661 |
+
device=p.device)
|
| 662 |
+
else:
|
| 663 |
+
step_tensor = state["step"]
|
| 664 |
+
state_steps.append(step_tensor)
|
| 665 |
+
|
| 666 |
+
self._fused_adamw(
|
| 667 |
+
params_with_grads,
|
| 668 |
+
grads,
|
| 669 |
+
moment1,
|
| 670 |
+
moment2,
|
| 671 |
+
max_exp_avg_sqs,
|
| 672 |
+
state_steps,
|
| 673 |
+
amsgrad=False,
|
| 674 |
+
beta1=beta1,
|
| 675 |
+
beta2=beta2,
|
| 676 |
+
lr=lr,
|
| 677 |
+
weight_decay=weight_decay,
|
| 678 |
+
eps=eps,
|
| 679 |
+
maximize=False,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
return loss
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_ee6ed44_dirty
|
| 3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f8bf16b0ae5af74852e8c890183c8c32175886c3d0366cfc776fb3e1ee15906
|
| 3 |
+
size 1883352
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
|
|
| 15 |
@torch.no_grad()
|
| 16 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 17 |
"""
|
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 25 |
"""
|
| 26 |
assert len(G.shape) == 2
|
| 27 |
-
|
| 28 |
X = G # no manual typecast
|
|
|
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
# Ensure spectral norm is at most 1
|
| 32 |
X = X / (X.norm() + 1e-7)
|
| 33 |
-
X = X.bfloat16()
|
| 34 |
# Perform the NS iterations
|
| 35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
A = X @ X.T
|
| 37 |
# B = (
|
| 38 |
# b * A + c * A @ A
|
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 43 |
|
| 44 |
if G.size(0) > G.size(1):
|
| 45 |
X = X.T
|
| 46 |
-
return X
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 65 |
Gather the gradients to worker_rank.
|
| 66 |
If none_grad is True, free p.grad after the gather.
|
| 67 |
"""
|
| 68 |
-
g = p.grad
|
| 69 |
-
|
| 70 |
-
if rank == state.worker_rank:
|
| 71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
| 72 |
-
gather_list = [
|
| 73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 74 |
-
]
|
| 75 |
-
else:
|
| 76 |
-
gather_list = None
|
| 77 |
-
|
| 78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
torch.distributed.gather(
|
| 80 |
g.to_local(),
|
| 81 |
dst=state.worker_rank,
|
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 92 |
else:
|
| 93 |
state.gathered_grad = None
|
| 94 |
state.gather_event = None
|
|
|
|
| 95 |
if none_grad:
|
| 96 |
# We can safely free p.grad without calling record_stream:
|
| 97 |
# p.grad.to_local().record_stream(comm_stream)
|
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
| 108 |
"""
|
| 109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 110 |
"""
|
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 115 |
compute_stream.wait_event(state.gather_event)
|
| 116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 117 |
state.computed_u = u
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
@torch.no_grad()
|
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if rank == state.worker_rank:
|
| 133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 134 |
-
if state.compute_event is None:
|
| 135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
| 136 |
-
comm_stream.wait_event(state.compute_event)
|
| 137 |
-
|
| 138 |
# Clear the gathered gradient to free memory
|
| 139 |
state.gathered_grad = None
|
| 140 |
|
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 144 |
else:
|
| 145 |
scatter_list = None
|
| 146 |
|
| 147 |
-
u_received = torch.empty_like(p.to_local())
|
| 148 |
torch.distributed.scatter(
|
| 149 |
-
|
| 150 |
scatter_list=scatter_list,
|
| 151 |
src=state.worker_rank,
|
| 152 |
group=state.process_group,
|
| 153 |
)
|
| 154 |
-
u_dtensor = DTensor.from_local(
|
| 155 |
-
u_received,
|
| 156 |
-
placements=p.placements,
|
| 157 |
-
device_mesh=p.device_mesh,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
state.scattered_u = u_dtensor
|
| 161 |
state.scatter_event = torch.cuda.Event()
|
| 162 |
state.scatter_event.record()
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
| 172 |
if state.scatter_event is None:
|
| 173 |
raise RuntimeError("Scatter event must be set before update")
|
| 174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if rank == state.worker_rank:
|
| 176 |
# Free computed_u
|
| 177 |
state.computed_u = None
|
| 178 |
|
| 179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def default_is_muon(name, x):
|
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 375 |
else:
|
| 376 |
g = buf
|
| 377 |
|
| 378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
|
| 379 |
|
| 380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 433 |
def enqueue_computes(start_idx, chunk_size):
|
| 434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 435 |
state = param_to_state[id(p)]
|
| 436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
| 437 |
self.compute_stream)
|
| 438 |
|
| 439 |
def enqueue_scatters(start_idx, chunk_size):
|
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
| 466 |
# Wait the last update_param to finish
|
| 467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
def step(self, closure=None):
|
| 470 |
"""Perform a single optimization step.
|
| 471 |
|
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 542 |
# AdamW backup #
|
| 543 |
############################
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
lr = group["lr"]
|
| 546 |
beta1, beta2 = group["adamw_betas"]
|
| 547 |
eps = group["adamw_eps"]
|
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
| 552 |
if g is None:
|
| 553 |
continue
|
| 554 |
state = self.state[p]
|
|
|
|
|
|
|
| 555 |
if "step" not in state:
|
| 556 |
-
state["step"] =
|
|
|
|
|
|
|
| 557 |
state["moment1"] = torch.zeros_like(g)
|
| 558 |
state["moment2"] = torch.zeros_like(g)
|
| 559 |
-
state["
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
return loss
|
|
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Union, cast
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.distributed as dist
|
|
|
|
| 13 |
|
| 14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
| 16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
|
|
|
| 27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
+
assert G.dtype == torch.bfloat16
|
| 31 |
X = G # no manual typecast
|
| 32 |
+
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
+
for a, b, c in [
|
| 39 |
+
(4.0848, -6.8946, 2.9270),
|
| 40 |
+
(3.9505, -6.3029, 2.6377),
|
| 41 |
+
(3.7418, -5.5913, 2.3037),
|
| 42 |
+
(2.8769, -3.1427, 1.2046),
|
| 43 |
+
(2.8366, -3.0525, 1.2012),
|
| 44 |
+
]:
|
| 45 |
A = X @ X.T
|
| 46 |
# B = (
|
| 47 |
# b * A + c * A @ A
|
|
|
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
| 55 |
+
return X
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 74 |
Gather the gradients to worker_rank.
|
| 75 |
If none_grad is True, free p.grad after the gather.
|
| 76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.cuda.stream(comm_stream):
|
| 78 |
+
g = p.grad
|
| 79 |
+
|
| 80 |
+
if rank == state.worker_rank:
|
| 81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 82 |
+
gather_list = [
|
| 83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
| 84 |
+
for _ in range(num_ranks)
|
| 85 |
+
]
|
| 86 |
+
else:
|
| 87 |
+
gather_list = None
|
| 88 |
+
|
| 89 |
+
g = g.to(torch.bfloat16)
|
| 90 |
torch.distributed.gather(
|
| 91 |
g.to_local(),
|
| 92 |
dst=state.worker_rank,
|
|
|
|
| 103 |
else:
|
| 104 |
state.gathered_grad = None
|
| 105 |
state.gather_event = None
|
| 106 |
+
gather_list = None
|
| 107 |
if none_grad:
|
| 108 |
# We can safely free p.grad without calling record_stream:
|
| 109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
@torch.no_grad()
|
| 119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
| 120 |
"""
|
| 121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 122 |
"""
|
|
|
|
| 127 |
compute_stream.wait_event(state.gather_event)
|
| 128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 129 |
state.computed_u = u
|
| 130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 131 |
+
dtype=torch.bfloat16)
|
| 132 |
+
state.compute_event = torch.cuda.Event()
|
| 133 |
+
state.compute_event.record()
|
| 134 |
+
u = None
|
| 135 |
|
| 136 |
|
| 137 |
@torch.no_grad()
|
|
|
|
| 141 |
"""
|
| 142 |
|
| 143 |
with torch.cuda.stream(comm_stream):
|
| 144 |
+
if state.compute_event is None:
|
| 145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
| 146 |
+
comm_stream.wait_event(state.compute_event)
|
| 147 |
+
|
| 148 |
if rank == state.worker_rank:
|
| 149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Clear the gathered gradient to free memory
|
| 151 |
state.gathered_grad = None
|
| 152 |
|
|
|
|
| 156 |
else:
|
| 157 |
scatter_list = None
|
| 158 |
|
|
|
|
| 159 |
torch.distributed.scatter(
|
| 160 |
+
state.scattered_u,
|
| 161 |
scatter_list=scatter_list,
|
| 162 |
src=state.worker_rank,
|
| 163 |
group=state.process_group,
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
state.scatter_event = torch.cuda.Event()
|
| 166 |
state.scatter_event.record()
|
| 167 |
+
scatter_list = None
|
| 168 |
|
| 169 |
|
| 170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 177 |
if state.scatter_event is None:
|
| 178 |
raise RuntimeError("Scatter event must be set before update")
|
| 179 |
compute_stream.wait_event(state.scatter_event)
|
| 180 |
+
u_dtensor = DTensor.from_local(
|
| 181 |
+
state.scattered_u,
|
| 182 |
+
placements=p.placements,
|
| 183 |
+
device_mesh=p.device_mesh,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
state.scattered_u = u_dtensor
|
| 187 |
+
|
| 188 |
if rank == state.worker_rank:
|
| 189 |
# Free computed_u
|
| 190 |
state.computed_u = None
|
| 191 |
|
| 192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
| 193 |
+
state.scattered_u = None
|
| 194 |
+
u_dtensor = None
|
| 195 |
|
| 196 |
|
| 197 |
def default_is_muon(name, x):
|
|
|
|
| 390 |
else:
|
| 391 |
g = buf
|
| 392 |
|
| 393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
| 394 |
+
steps=group["ns_steps"])
|
| 395 |
|
| 396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 449 |
def enqueue_computes(start_idx, chunk_size):
|
| 450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 451 |
state = param_to_state[id(p)]
|
| 452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 453 |
self.compute_stream)
|
| 454 |
|
| 455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
|
| 482 |
# Wait the last update_param to finish
|
| 483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 484 |
|
| 485 |
+
@staticmethod
|
| 486 |
+
def _fused_adamw(
|
| 487 |
+
params: list[torch.Tensor],
|
| 488 |
+
grads: list[torch.Tensor],
|
| 489 |
+
exp_avgs: list[torch.Tensor],
|
| 490 |
+
exp_avg_sqs: list[torch.Tensor],
|
| 491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
| 492 |
+
state_steps: list[torch.Tensor],
|
| 493 |
+
amsgrad: bool,
|
| 494 |
+
beta1: float,
|
| 495 |
+
beta2: float,
|
| 496 |
+
lr: Union[float, torch.Tensor],
|
| 497 |
+
weight_decay: float,
|
| 498 |
+
eps: float,
|
| 499 |
+
maximize: bool,
|
| 500 |
+
) -> None:
|
| 501 |
+
if not params:
|
| 502 |
+
return
|
| 503 |
+
|
| 504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
| 505 |
+
# treating it as a scalar.
|
| 506 |
+
lr_dict: Optional[DeviceDict] = ({
|
| 507 |
+
lr.device: lr
|
| 508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
| 509 |
+
None)
|
| 510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
| 511 |
+
[
|
| 512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
| 513 |
+
state_steps
|
| 514 |
+
] # type: ignore[list-item]
|
| 515 |
+
)
|
| 516 |
+
for (device, _), (
|
| 517 |
+
(
|
| 518 |
+
device_params_,
|
| 519 |
+
device_grads_,
|
| 520 |
+
device_exp_avgs_,
|
| 521 |
+
device_exp_avg_sqs_,
|
| 522 |
+
device_max_exp_avg_sqs,
|
| 523 |
+
device_state_steps_,
|
| 524 |
+
),
|
| 525 |
+
_,
|
| 526 |
+
) in grouped_tensors.items():
|
| 527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
| 528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
| 529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
| 530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
| 531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
| 532 |
+
|
| 533 |
+
if lr_dict is not None and device not in lr_dict:
|
| 534 |
+
lr_dict[device] = lr.to(
|
| 535 |
+
device=device,
|
| 536 |
+
non_blocking=True) # type: ignore[union-attr]
|
| 537 |
+
lr = lr_dict[device]
|
| 538 |
+
torch._foreach_add_(device_state_steps, 1)
|
| 539 |
+
func = torch._fused_adamw_
|
| 540 |
+
func(
|
| 541 |
+
device_params,
|
| 542 |
+
device_grads,
|
| 543 |
+
device_exp_avgs,
|
| 544 |
+
device_exp_avg_sqs,
|
| 545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
| 546 |
+
device_state_steps,
|
| 547 |
+
amsgrad=amsgrad,
|
| 548 |
+
lr=lr, # type: ignore[arg-type]
|
| 549 |
+
beta1=beta1,
|
| 550 |
+
beta2=beta2,
|
| 551 |
+
weight_decay=weight_decay,
|
| 552 |
+
eps=eps,
|
| 553 |
+
maximize=maximize,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
def step(self, closure=None):
|
| 557 |
"""Perform a single optimization step.
|
| 558 |
|
|
|
|
| 629 |
# AdamW backup #
|
| 630 |
############################
|
| 631 |
|
| 632 |
+
params_with_grads = []
|
| 633 |
+
grads = []
|
| 634 |
+
moment1 = []
|
| 635 |
+
moment2 = []
|
| 636 |
+
max_exp_avg_sqs = []
|
| 637 |
+
state_steps = []
|
| 638 |
lr = group["lr"]
|
| 639 |
beta1, beta2 = group["adamw_betas"]
|
| 640 |
eps = group["adamw_eps"]
|
|
|
|
| 645 |
if g is None:
|
| 646 |
continue
|
| 647 |
state = self.state[p]
|
| 648 |
+
params_with_grads.append(p)
|
| 649 |
+
grads.append(g)
|
| 650 |
if "step" not in state:
|
| 651 |
+
state["step"] = (torch.zeros((),
|
| 652 |
+
dtype=torch.float32,
|
| 653 |
+
device=p.device))
|
| 654 |
state["moment1"] = torch.zeros_like(g)
|
| 655 |
state["moment2"] = torch.zeros_like(g)
|
| 656 |
+
moment1.append(state["moment1"])
|
| 657 |
+
moment2.append(state["moment2"])
|
| 658 |
+
if not isinstance(state["step"], torch.Tensor):
|
| 659 |
+
step_tensor = torch.tensor(state["step"],
|
| 660 |
+
dtype=torch.float32,
|
| 661 |
+
device=p.device)
|
| 662 |
+
else:
|
| 663 |
+
step_tensor = state["step"]
|
| 664 |
+
state_steps.append(step_tensor)
|
| 665 |
+
|
| 666 |
+
self._fused_adamw(
|
| 667 |
+
params_with_grads,
|
| 668 |
+
grads,
|
| 669 |
+
moment1,
|
| 670 |
+
moment2,
|
| 671 |
+
max_exp_avg_sqs,
|
| 672 |
+
state_steps,
|
| 673 |
+
amsgrad=False,
|
| 674 |
+
beta1=beta1,
|
| 675 |
+
beta2=beta2,
|
| 676 |
+
lr=lr,
|
| 677 |
+
weight_decay=weight_decay,
|
| 678 |
+
eps=eps,
|
| 679 |
+
maximize=False,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
return loss
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_ee6ed44_dirty
|
| 3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d50267ec23db9512ae1d82c99012901d58e50dee9bf34346702561a5d3e6d9e7
|
| 3 |
+
size 1749840
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
|
|
| 15 |
@torch.no_grad()
|
| 16 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 17 |
"""
|
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 25 |
"""
|
| 26 |
assert len(G.shape) == 2
|
| 27 |
-
|
| 28 |
X = G # no manual typecast
|
|
|
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
# Ensure spectral norm is at most 1
|
| 32 |
X = X / (X.norm() + 1e-7)
|
| 33 |
-
X = X.bfloat16()
|
| 34 |
# Perform the NS iterations
|
| 35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
A = X @ X.T
|
| 37 |
# B = (
|
| 38 |
# b * A + c * A @ A
|
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 43 |
|
| 44 |
if G.size(0) > G.size(1):
|
| 45 |
X = X.T
|
| 46 |
-
return X
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 65 |
Gather the gradients to worker_rank.
|
| 66 |
If none_grad is True, free p.grad after the gather.
|
| 67 |
"""
|
| 68 |
-
g = p.grad
|
| 69 |
-
|
| 70 |
-
if rank == state.worker_rank:
|
| 71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
| 72 |
-
gather_list = [
|
| 73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 74 |
-
]
|
| 75 |
-
else:
|
| 76 |
-
gather_list = None
|
| 77 |
-
|
| 78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
torch.distributed.gather(
|
| 80 |
g.to_local(),
|
| 81 |
dst=state.worker_rank,
|
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 92 |
else:
|
| 93 |
state.gathered_grad = None
|
| 94 |
state.gather_event = None
|
|
|
|
| 95 |
if none_grad:
|
| 96 |
# We can safely free p.grad without calling record_stream:
|
| 97 |
# p.grad.to_local().record_stream(comm_stream)
|
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
| 108 |
"""
|
| 109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 110 |
"""
|
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 115 |
compute_stream.wait_event(state.gather_event)
|
| 116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 117 |
state.computed_u = u
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
@torch.no_grad()
|
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if rank == state.worker_rank:
|
| 133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 134 |
-
if state.compute_event is None:
|
| 135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
| 136 |
-
comm_stream.wait_event(state.compute_event)
|
| 137 |
-
|
| 138 |
# Clear the gathered gradient to free memory
|
| 139 |
state.gathered_grad = None
|
| 140 |
|
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 144 |
else:
|
| 145 |
scatter_list = None
|
| 146 |
|
| 147 |
-
u_received = torch.empty_like(p.to_local())
|
| 148 |
torch.distributed.scatter(
|
| 149 |
-
|
| 150 |
scatter_list=scatter_list,
|
| 151 |
src=state.worker_rank,
|
| 152 |
group=state.process_group,
|
| 153 |
)
|
| 154 |
-
u_dtensor = DTensor.from_local(
|
| 155 |
-
u_received,
|
| 156 |
-
placements=p.placements,
|
| 157 |
-
device_mesh=p.device_mesh,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
state.scattered_u = u_dtensor
|
| 161 |
state.scatter_event = torch.cuda.Event()
|
| 162 |
state.scatter_event.record()
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
| 172 |
if state.scatter_event is None:
|
| 173 |
raise RuntimeError("Scatter event must be set before update")
|
| 174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if rank == state.worker_rank:
|
| 176 |
# Free computed_u
|
| 177 |
state.computed_u = None
|
| 178 |
|
| 179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def default_is_muon(name, x):
|
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 375 |
else:
|
| 376 |
g = buf
|
| 377 |
|
| 378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
|
| 379 |
|
| 380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 433 |
def enqueue_computes(start_idx, chunk_size):
|
| 434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 435 |
state = param_to_state[id(p)]
|
| 436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
| 437 |
self.compute_stream)
|
| 438 |
|
| 439 |
def enqueue_scatters(start_idx, chunk_size):
|
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
| 466 |
# Wait the last update_param to finish
|
| 467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
def step(self, closure=None):
|
| 470 |
"""Perform a single optimization step.
|
| 471 |
|
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 542 |
# AdamW backup #
|
| 543 |
############################
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
lr = group["lr"]
|
| 546 |
beta1, beta2 = group["adamw_betas"]
|
| 547 |
eps = group["adamw_eps"]
|
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
| 552 |
if g is None:
|
| 553 |
continue
|
| 554 |
state = self.state[p]
|
|
|
|
|
|
|
| 555 |
if "step" not in state:
|
| 556 |
-
state["step"] =
|
|
|
|
|
|
|
| 557 |
state["moment1"] = torch.zeros_like(g)
|
| 558 |
state["moment2"] = torch.zeros_like(g)
|
| 559 |
-
state["
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
return loss
|
|
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Union, cast
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.distributed as dist
|
|
|
|
| 13 |
|
| 14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
| 16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
|
|
|
| 27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
+
assert G.dtype == torch.bfloat16
|
| 31 |
X = G # no manual typecast
|
| 32 |
+
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
+
for a, b, c in [
|
| 39 |
+
(4.0848, -6.8946, 2.9270),
|
| 40 |
+
(3.9505, -6.3029, 2.6377),
|
| 41 |
+
(3.7418, -5.5913, 2.3037),
|
| 42 |
+
(2.8769, -3.1427, 1.2046),
|
| 43 |
+
(2.8366, -3.0525, 1.2012),
|
| 44 |
+
]:
|
| 45 |
A = X @ X.T
|
| 46 |
# B = (
|
| 47 |
# b * A + c * A @ A
|
|
|
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
| 55 |
+
return X
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 74 |
Gather the gradients to worker_rank.
|
| 75 |
If none_grad is True, free p.grad after the gather.
|
| 76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.cuda.stream(comm_stream):
|
| 78 |
+
g = p.grad
|
| 79 |
+
|
| 80 |
+
if rank == state.worker_rank:
|
| 81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 82 |
+
gather_list = [
|
| 83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
| 84 |
+
for _ in range(num_ranks)
|
| 85 |
+
]
|
| 86 |
+
else:
|
| 87 |
+
gather_list = None
|
| 88 |
+
|
| 89 |
+
g = g.to(torch.bfloat16)
|
| 90 |
torch.distributed.gather(
|
| 91 |
g.to_local(),
|
| 92 |
dst=state.worker_rank,
|
|
|
|
| 103 |
else:
|
| 104 |
state.gathered_grad = None
|
| 105 |
state.gather_event = None
|
| 106 |
+
gather_list = None
|
| 107 |
if none_grad:
|
| 108 |
# We can safely free p.grad without calling record_stream:
|
| 109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
@torch.no_grad()
|
| 119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
| 120 |
"""
|
| 121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 122 |
"""
|
|
|
|
| 127 |
compute_stream.wait_event(state.gather_event)
|
| 128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 129 |
state.computed_u = u
|
| 130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 131 |
+
dtype=torch.bfloat16)
|
| 132 |
+
state.compute_event = torch.cuda.Event()
|
| 133 |
+
state.compute_event.record()
|
| 134 |
+
u = None
|
| 135 |
|
| 136 |
|
| 137 |
@torch.no_grad()
|
|
|
|
| 141 |
"""
|
| 142 |
|
| 143 |
with torch.cuda.stream(comm_stream):
|
| 144 |
+
if state.compute_event is None:
|
| 145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
| 146 |
+
comm_stream.wait_event(state.compute_event)
|
| 147 |
+
|
| 148 |
if rank == state.worker_rank:
|
| 149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Clear the gathered gradient to free memory
|
| 151 |
state.gathered_grad = None
|
| 152 |
|
|
|
|
| 156 |
else:
|
| 157 |
scatter_list = None
|
| 158 |
|
|
|
|
| 159 |
torch.distributed.scatter(
|
| 160 |
+
state.scattered_u,
|
| 161 |
scatter_list=scatter_list,
|
| 162 |
src=state.worker_rank,
|
| 163 |
group=state.process_group,
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
state.scatter_event = torch.cuda.Event()
|
| 166 |
state.scatter_event.record()
|
| 167 |
+
scatter_list = None
|
| 168 |
|
| 169 |
|
| 170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 177 |
if state.scatter_event is None:
|
| 178 |
raise RuntimeError("Scatter event must be set before update")
|
| 179 |
compute_stream.wait_event(state.scatter_event)
|
| 180 |
+
u_dtensor = DTensor.from_local(
|
| 181 |
+
state.scattered_u,
|
| 182 |
+
placements=p.placements,
|
| 183 |
+
device_mesh=p.device_mesh,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
state.scattered_u = u_dtensor
|
| 187 |
+
|
| 188 |
if rank == state.worker_rank:
|
| 189 |
# Free computed_u
|
| 190 |
state.computed_u = None
|
| 191 |
|
| 192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
| 193 |
+
state.scattered_u = None
|
| 194 |
+
u_dtensor = None
|
| 195 |
|
| 196 |
|
| 197 |
def default_is_muon(name, x):
|
|
|
|
| 390 |
else:
|
| 391 |
g = buf
|
| 392 |
|
| 393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
| 394 |
+
steps=group["ns_steps"])
|
| 395 |
|
| 396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 449 |
def enqueue_computes(start_idx, chunk_size):
|
| 450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 451 |
state = param_to_state[id(p)]
|
| 452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 453 |
self.compute_stream)
|
| 454 |
|
| 455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
|
| 482 |
# Wait the last update_param to finish
|
| 483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 484 |
|
| 485 |
+
@staticmethod
|
| 486 |
+
def _fused_adamw(
|
| 487 |
+
params: list[torch.Tensor],
|
| 488 |
+
grads: list[torch.Tensor],
|
| 489 |
+
exp_avgs: list[torch.Tensor],
|
| 490 |
+
exp_avg_sqs: list[torch.Tensor],
|
| 491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
| 492 |
+
state_steps: list[torch.Tensor],
|
| 493 |
+
amsgrad: bool,
|
| 494 |
+
beta1: float,
|
| 495 |
+
beta2: float,
|
| 496 |
+
lr: Union[float, torch.Tensor],
|
| 497 |
+
weight_decay: float,
|
| 498 |
+
eps: float,
|
| 499 |
+
maximize: bool,
|
| 500 |
+
) -> None:
|
| 501 |
+
if not params:
|
| 502 |
+
return
|
| 503 |
+
|
| 504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
| 505 |
+
# treating it as a scalar.
|
| 506 |
+
lr_dict: Optional[DeviceDict] = ({
|
| 507 |
+
lr.device: lr
|
| 508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
| 509 |
+
None)
|
| 510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
| 511 |
+
[
|
| 512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
| 513 |
+
state_steps
|
| 514 |
+
] # type: ignore[list-item]
|
| 515 |
+
)
|
| 516 |
+
for (device, _), (
|
| 517 |
+
(
|
| 518 |
+
device_params_,
|
| 519 |
+
device_grads_,
|
| 520 |
+
device_exp_avgs_,
|
| 521 |
+
device_exp_avg_sqs_,
|
| 522 |
+
device_max_exp_avg_sqs,
|
| 523 |
+
device_state_steps_,
|
| 524 |
+
),
|
| 525 |
+
_,
|
| 526 |
+
) in grouped_tensors.items():
|
| 527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
| 528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
| 529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
| 530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
| 531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
| 532 |
+
|
| 533 |
+
if lr_dict is not None and device not in lr_dict:
|
| 534 |
+
lr_dict[device] = lr.to(
|
| 535 |
+
device=device,
|
| 536 |
+
non_blocking=True) # type: ignore[union-attr]
|
| 537 |
+
lr = lr_dict[device]
|
| 538 |
+
torch._foreach_add_(device_state_steps, 1)
|
| 539 |
+
func = torch._fused_adamw_
|
| 540 |
+
func(
|
| 541 |
+
device_params,
|
| 542 |
+
device_grads,
|
| 543 |
+
device_exp_avgs,
|
| 544 |
+
device_exp_avg_sqs,
|
| 545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
| 546 |
+
device_state_steps,
|
| 547 |
+
amsgrad=amsgrad,
|
| 548 |
+
lr=lr, # type: ignore[arg-type]
|
| 549 |
+
beta1=beta1,
|
| 550 |
+
beta2=beta2,
|
| 551 |
+
weight_decay=weight_decay,
|
| 552 |
+
eps=eps,
|
| 553 |
+
maximize=maximize,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
def step(self, closure=None):
|
| 557 |
"""Perform a single optimization step.
|
| 558 |
|
|
|
|
| 629 |
# AdamW backup #
|
| 630 |
############################
|
| 631 |
|
| 632 |
+
params_with_grads = []
|
| 633 |
+
grads = []
|
| 634 |
+
moment1 = []
|
| 635 |
+
moment2 = []
|
| 636 |
+
max_exp_avg_sqs = []
|
| 637 |
+
state_steps = []
|
| 638 |
lr = group["lr"]
|
| 639 |
beta1, beta2 = group["adamw_betas"]
|
| 640 |
eps = group["adamw_eps"]
|
|
|
|
| 645 |
if g is None:
|
| 646 |
continue
|
| 647 |
state = self.state[p]
|
| 648 |
+
params_with_grads.append(p)
|
| 649 |
+
grads.append(g)
|
| 650 |
if "step" not in state:
|
| 651 |
+
state["step"] = (torch.zeros((),
|
| 652 |
+
dtype=torch.float32,
|
| 653 |
+
device=p.device))
|
| 654 |
state["moment1"] = torch.zeros_like(g)
|
| 655 |
state["moment2"] = torch.zeros_like(g)
|
| 656 |
+
moment1.append(state["moment1"])
|
| 657 |
+
moment2.append(state["moment2"])
|
| 658 |
+
if not isinstance(state["step"], torch.Tensor):
|
| 659 |
+
step_tensor = torch.tensor(state["step"],
|
| 660 |
+
dtype=torch.float32,
|
| 661 |
+
device=p.device)
|
| 662 |
+
else:
|
| 663 |
+
step_tensor = state["step"]
|
| 664 |
+
state_steps.append(step_tensor)
|
| 665 |
+
|
| 666 |
+
self._fused_adamw(
|
| 667 |
+
params_with_grads,
|
| 668 |
+
grads,
|
| 669 |
+
moment1,
|
| 670 |
+
moment2,
|
| 671 |
+
max_exp_avg_sqs,
|
| 672 |
+
state_steps,
|
| 673 |
+
amsgrad=False,
|
| 674 |
+
beta1=beta1,
|
| 675 |
+
beta2=beta2,
|
| 676 |
+
lr=lr,
|
| 677 |
+
weight_decay=weight_decay,
|
| 678 |
+
eps=eps,
|
| 679 |
+
maximize=False,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
return loss
|
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_ee6ed44_dirty
|
| 3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1824264
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:80ce6b0d62167a8ea10b6e2a1f90df70aa108997570c0ed210f458debd26f32f
|
| 3 |
size 1824264
|
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
|
|
| 15 |
@torch.no_grad()
|
| 16 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 17 |
"""
|
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 25 |
"""
|
| 26 |
assert len(G.shape) == 2
|
| 27 |
-
|
| 28 |
X = G # no manual typecast
|
|
|
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
# Ensure spectral norm is at most 1
|
| 32 |
X = X / (X.norm() + 1e-7)
|
| 33 |
-
X = X.bfloat16()
|
| 34 |
# Perform the NS iterations
|
| 35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
A = X @ X.T
|
| 37 |
# B = (
|
| 38 |
# b * A + c * A @ A
|
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 43 |
|
| 44 |
if G.size(0) > G.size(1):
|
| 45 |
X = X.T
|
| 46 |
-
return X
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 65 |
Gather the gradients to worker_rank.
|
| 66 |
If none_grad is True, free p.grad after the gather.
|
| 67 |
"""
|
| 68 |
-
g = p.grad
|
| 69 |
-
|
| 70 |
-
if rank == state.worker_rank:
|
| 71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
| 72 |
-
gather_list = [
|
| 73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 74 |
-
]
|
| 75 |
-
else:
|
| 76 |
-
gather_list = None
|
| 77 |
-
|
| 78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
torch.distributed.gather(
|
| 80 |
g.to_local(),
|
| 81 |
dst=state.worker_rank,
|
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 92 |
else:
|
| 93 |
state.gathered_grad = None
|
| 94 |
state.gather_event = None
|
|
|
|
| 95 |
if none_grad:
|
| 96 |
# We can safely free p.grad without calling record_stream:
|
| 97 |
# p.grad.to_local().record_stream(comm_stream)
|
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
| 108 |
"""
|
| 109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 110 |
"""
|
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 115 |
compute_stream.wait_event(state.gather_event)
|
| 116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 117 |
state.computed_u = u
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
@torch.no_grad()
|
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if rank == state.worker_rank:
|
| 133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 134 |
-
if state.compute_event is None:
|
| 135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
| 136 |
-
comm_stream.wait_event(state.compute_event)
|
| 137 |
-
|
| 138 |
# Clear the gathered gradient to free memory
|
| 139 |
state.gathered_grad = None
|
| 140 |
|
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 144 |
else:
|
| 145 |
scatter_list = None
|
| 146 |
|
| 147 |
-
u_received = torch.empty_like(p.to_local())
|
| 148 |
torch.distributed.scatter(
|
| 149 |
-
|
| 150 |
scatter_list=scatter_list,
|
| 151 |
src=state.worker_rank,
|
| 152 |
group=state.process_group,
|
| 153 |
)
|
| 154 |
-
u_dtensor = DTensor.from_local(
|
| 155 |
-
u_received,
|
| 156 |
-
placements=p.placements,
|
| 157 |
-
device_mesh=p.device_mesh,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
state.scattered_u = u_dtensor
|
| 161 |
state.scatter_event = torch.cuda.Event()
|
| 162 |
state.scatter_event.record()
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
| 172 |
if state.scatter_event is None:
|
| 173 |
raise RuntimeError("Scatter event must be set before update")
|
| 174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if rank == state.worker_rank:
|
| 176 |
# Free computed_u
|
| 177 |
state.computed_u = None
|
| 178 |
|
| 179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def default_is_muon(name, x):
|
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 375 |
else:
|
| 376 |
g = buf
|
| 377 |
|
| 378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
|
| 379 |
|
| 380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 433 |
def enqueue_computes(start_idx, chunk_size):
|
| 434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 435 |
state = param_to_state[id(p)]
|
| 436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
| 437 |
self.compute_stream)
|
| 438 |
|
| 439 |
def enqueue_scatters(start_idx, chunk_size):
|
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
| 466 |
# Wait the last update_param to finish
|
| 467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
def step(self, closure=None):
|
| 470 |
"""Perform a single optimization step.
|
| 471 |
|
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 542 |
# AdamW backup #
|
| 543 |
############################
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
lr = group["lr"]
|
| 546 |
beta1, beta2 = group["adamw_betas"]
|
| 547 |
eps = group["adamw_eps"]
|
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
| 552 |
if g is None:
|
| 553 |
continue
|
| 554 |
state = self.state[p]
|
|
|
|
|
|
|
| 555 |
if "step" not in state:
|
| 556 |
-
state["step"] =
|
|
|
|
|
|
|
| 557 |
state["moment1"] = torch.zeros_like(g)
|
| 558 |
state["moment2"] = torch.zeros_like(g)
|
| 559 |
-
state["
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
return loss
|
|
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Union, cast
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.distributed as dist
|
|
|
|
| 13 |
|
| 14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
| 16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
|
|
|
| 27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
+
assert G.dtype == torch.bfloat16
|
| 31 |
X = G # no manual typecast
|
| 32 |
+
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
+
for a, b, c in [
|
| 39 |
+
(4.0848, -6.8946, 2.9270),
|
| 40 |
+
(3.9505, -6.3029, 2.6377),
|
| 41 |
+
(3.7418, -5.5913, 2.3037),
|
| 42 |
+
(2.8769, -3.1427, 1.2046),
|
| 43 |
+
(2.8366, -3.0525, 1.2012),
|
| 44 |
+
]:
|
| 45 |
A = X @ X.T
|
| 46 |
# B = (
|
| 47 |
# b * A + c * A @ A
|
|
|
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
| 55 |
+
return X
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 74 |
Gather the gradients to worker_rank.
|
| 75 |
If none_grad is True, free p.grad after the gather.
|
| 76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.cuda.stream(comm_stream):
|
| 78 |
+
g = p.grad
|
| 79 |
+
|
| 80 |
+
if rank == state.worker_rank:
|
| 81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 82 |
+
gather_list = [
|
| 83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
| 84 |
+
for _ in range(num_ranks)
|
| 85 |
+
]
|
| 86 |
+
else:
|
| 87 |
+
gather_list = None
|
| 88 |
+
|
| 89 |
+
g = g.to(torch.bfloat16)
|
| 90 |
torch.distributed.gather(
|
| 91 |
g.to_local(),
|
| 92 |
dst=state.worker_rank,
|
|
|
|
| 103 |
else:
|
| 104 |
state.gathered_grad = None
|
| 105 |
state.gather_event = None
|
| 106 |
+
gather_list = None
|
| 107 |
if none_grad:
|
| 108 |
# We can safely free p.grad without calling record_stream:
|
| 109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
@torch.no_grad()
|
| 119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
| 120 |
"""
|
| 121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 122 |
"""
|
|
|
|
| 127 |
compute_stream.wait_event(state.gather_event)
|
| 128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 129 |
state.computed_u = u
|
| 130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 131 |
+
dtype=torch.bfloat16)
|
| 132 |
+
state.compute_event = torch.cuda.Event()
|
| 133 |
+
state.compute_event.record()
|
| 134 |
+
u = None
|
| 135 |
|
| 136 |
|
| 137 |
@torch.no_grad()
|
|
|
|
| 141 |
"""
|
| 142 |
|
| 143 |
with torch.cuda.stream(comm_stream):
|
| 144 |
+
if state.compute_event is None:
|
| 145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
| 146 |
+
comm_stream.wait_event(state.compute_event)
|
| 147 |
+
|
| 148 |
if rank == state.worker_rank:
|
| 149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Clear the gathered gradient to free memory
|
| 151 |
state.gathered_grad = None
|
| 152 |
|
|
|
|
| 156 |
else:
|
| 157 |
scatter_list = None
|
| 158 |
|
|
|
|
| 159 |
torch.distributed.scatter(
|
| 160 |
+
state.scattered_u,
|
| 161 |
scatter_list=scatter_list,
|
| 162 |
src=state.worker_rank,
|
| 163 |
group=state.process_group,
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
state.scatter_event = torch.cuda.Event()
|
| 166 |
state.scatter_event.record()
|
| 167 |
+
scatter_list = None
|
| 168 |
|
| 169 |
|
| 170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 177 |
if state.scatter_event is None:
|
| 178 |
raise RuntimeError("Scatter event must be set before update")
|
| 179 |
compute_stream.wait_event(state.scatter_event)
|
| 180 |
+
u_dtensor = DTensor.from_local(
|
| 181 |
+
state.scattered_u,
|
| 182 |
+
placements=p.placements,
|
| 183 |
+
device_mesh=p.device_mesh,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
state.scattered_u = u_dtensor
|
| 187 |
+
|
| 188 |
if rank == state.worker_rank:
|
| 189 |
# Free computed_u
|
| 190 |
state.computed_u = None
|
| 191 |
|
| 192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
| 193 |
+
state.scattered_u = None
|
| 194 |
+
u_dtensor = None
|
| 195 |
|
| 196 |
|
| 197 |
def default_is_muon(name, x):
|
|
|
|
| 390 |
else:
|
| 391 |
g = buf
|
| 392 |
|
| 393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
| 394 |
+
steps=group["ns_steps"])
|
| 395 |
|
| 396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 449 |
def enqueue_computes(start_idx, chunk_size):
|
| 450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 451 |
state = param_to_state[id(p)]
|
| 452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 453 |
self.compute_stream)
|
| 454 |
|
| 455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
|
| 482 |
# Wait the last update_param to finish
|
| 483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 484 |
|
| 485 |
+
@staticmethod
|
| 486 |
+
def _fused_adamw(
|
| 487 |
+
params: list[torch.Tensor],
|
| 488 |
+
grads: list[torch.Tensor],
|
| 489 |
+
exp_avgs: list[torch.Tensor],
|
| 490 |
+
exp_avg_sqs: list[torch.Tensor],
|
| 491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
| 492 |
+
state_steps: list[torch.Tensor],
|
| 493 |
+
amsgrad: bool,
|
| 494 |
+
beta1: float,
|
| 495 |
+
beta2: float,
|
| 496 |
+
lr: Union[float, torch.Tensor],
|
| 497 |
+
weight_decay: float,
|
| 498 |
+
eps: float,
|
| 499 |
+
maximize: bool,
|
| 500 |
+
) -> None:
|
| 501 |
+
if not params:
|
| 502 |
+
return
|
| 503 |
+
|
| 504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
| 505 |
+
# treating it as a scalar.
|
| 506 |
+
lr_dict: Optional[DeviceDict] = ({
|
| 507 |
+
lr.device: lr
|
| 508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
| 509 |
+
None)
|
| 510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
| 511 |
+
[
|
| 512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
| 513 |
+
state_steps
|
| 514 |
+
] # type: ignore[list-item]
|
| 515 |
+
)
|
| 516 |
+
for (device, _), (
|
| 517 |
+
(
|
| 518 |
+
device_params_,
|
| 519 |
+
device_grads_,
|
| 520 |
+
device_exp_avgs_,
|
| 521 |
+
device_exp_avg_sqs_,
|
| 522 |
+
device_max_exp_avg_sqs,
|
| 523 |
+
device_state_steps_,
|
| 524 |
+
),
|
| 525 |
+
_,
|
| 526 |
+
) in grouped_tensors.items():
|
| 527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
| 528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
| 529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
| 530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
| 531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
| 532 |
+
|
| 533 |
+
if lr_dict is not None and device not in lr_dict:
|
| 534 |
+
lr_dict[device] = lr.to(
|
| 535 |
+
device=device,
|
| 536 |
+
non_blocking=True) # type: ignore[union-attr]
|
| 537 |
+
lr = lr_dict[device]
|
| 538 |
+
torch._foreach_add_(device_state_steps, 1)
|
| 539 |
+
func = torch._fused_adamw_
|
| 540 |
+
func(
|
| 541 |
+
device_params,
|
| 542 |
+
device_grads,
|
| 543 |
+
device_exp_avgs,
|
| 544 |
+
device_exp_avg_sqs,
|
| 545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
| 546 |
+
device_state_steps,
|
| 547 |
+
amsgrad=amsgrad,
|
| 548 |
+
lr=lr, # type: ignore[arg-type]
|
| 549 |
+
beta1=beta1,
|
| 550 |
+
beta2=beta2,
|
| 551 |
+
weight_decay=weight_decay,
|
| 552 |
+
eps=eps,
|
| 553 |
+
maximize=maximize,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
def step(self, closure=None):
|
| 557 |
"""Perform a single optimization step.
|
| 558 |
|
|
|
|
| 629 |
# AdamW backup #
|
| 630 |
############################
|
| 631 |
|
| 632 |
+
params_with_grads = []
|
| 633 |
+
grads = []
|
| 634 |
+
moment1 = []
|
| 635 |
+
moment2 = []
|
| 636 |
+
max_exp_avg_sqs = []
|
| 637 |
+
state_steps = []
|
| 638 |
lr = group["lr"]
|
| 639 |
beta1, beta2 = group["adamw_betas"]
|
| 640 |
eps = group["adamw_eps"]
|
|
|
|
| 645 |
if g is None:
|
| 646 |
continue
|
| 647 |
state = self.state[p]
|
| 648 |
+
params_with_grads.append(p)
|
| 649 |
+
grads.append(g)
|
| 650 |
if "step" not in state:
|
| 651 |
+
state["step"] = (torch.zeros((),
|
| 652 |
+
dtype=torch.float32,
|
| 653 |
+
device=p.device))
|
| 654 |
state["moment1"] = torch.zeros_like(g)
|
| 655 |
state["moment2"] = torch.zeros_like(g)
|
| 656 |
+
moment1.append(state["moment1"])
|
| 657 |
+
moment2.append(state["moment2"])
|
| 658 |
+
if not isinstance(state["step"], torch.Tensor):
|
| 659 |
+
step_tensor = torch.tensor(state["step"],
|
| 660 |
+
dtype=torch.float32,
|
| 661 |
+
device=p.device)
|
| 662 |
+
else:
|
| 663 |
+
step_tensor = state["step"]
|
| 664 |
+
state_steps.append(step_tensor)
|
| 665 |
+
|
| 666 |
+
self._fused_adamw(
|
| 667 |
+
params_with_grads,
|
| 668 |
+
grads,
|
| 669 |
+
moment1,
|
| 670 |
+
moment2,
|
| 671 |
+
max_exp_avg_sqs,
|
| 672 |
+
state_steps,
|
| 673 |
+
amsgrad=False,
|
| 674 |
+
beta1=beta1,
|
| 675 |
+
beta2=beta2,
|
| 676 |
+
lr=lr,
|
| 677 |
+
weight_decay=weight_decay,
|
| 678 |
+
eps=eps,
|
| 679 |
+
maximize=False,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
return loss
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_ee6ed44_dirty
|
| 3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:22dc3ab77ab74837126281f79f417c5d55b2cc9885388fd9d3a1c7c824ece2bd
|
| 3 |
-
size 1883360
|
|
|
|
|
|
|
|
|
|
|
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3487612a8f022a1df1353945fc6d65bbd6797179b06c5d3202dc6e2aa6afb27a
|
| 3 |
+
size 1883352
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
|
|
| 15 |
@torch.no_grad()
|
| 16 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 17 |
"""
|
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 25 |
"""
|
| 26 |
assert len(G.shape) == 2
|
| 27 |
-
|
| 28 |
X = G # no manual typecast
|
|
|
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
# Ensure spectral norm is at most 1
|
| 32 |
X = X / (X.norm() + 1e-7)
|
| 33 |
-
X = X.bfloat16()
|
| 34 |
# Perform the NS iterations
|
| 35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
A = X @ X.T
|
| 37 |
# B = (
|
| 38 |
# b * A + c * A @ A
|
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 43 |
|
| 44 |
if G.size(0) > G.size(1):
|
| 45 |
X = X.T
|
| 46 |
-
return X
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 65 |
Gather the gradients to worker_rank.
|
| 66 |
If none_grad is True, free p.grad after the gather.
|
| 67 |
"""
|
| 68 |
-
g = p.grad
|
| 69 |
-
|
| 70 |
-
if rank == state.worker_rank:
|
| 71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
| 72 |
-
gather_list = [
|
| 73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 74 |
-
]
|
| 75 |
-
else:
|
| 76 |
-
gather_list = None
|
| 77 |
-
|
| 78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
torch.distributed.gather(
|
| 80 |
g.to_local(),
|
| 81 |
dst=state.worker_rank,
|
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 92 |
else:
|
| 93 |
state.gathered_grad = None
|
| 94 |
state.gather_event = None
|
|
|
|
| 95 |
if none_grad:
|
| 96 |
# We can safely free p.grad without calling record_stream:
|
| 97 |
# p.grad.to_local().record_stream(comm_stream)
|
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
| 108 |
"""
|
| 109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 110 |
"""
|
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 115 |
compute_stream.wait_event(state.gather_event)
|
| 116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 117 |
state.computed_u = u
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
@torch.no_grad()
|
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if rank == state.worker_rank:
|
| 133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 134 |
-
if state.compute_event is None:
|
| 135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
| 136 |
-
comm_stream.wait_event(state.compute_event)
|
| 137 |
-
|
| 138 |
# Clear the gathered gradient to free memory
|
| 139 |
state.gathered_grad = None
|
| 140 |
|
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 144 |
else:
|
| 145 |
scatter_list = None
|
| 146 |
|
| 147 |
-
u_received = torch.empty_like(p.to_local())
|
| 148 |
torch.distributed.scatter(
|
| 149 |
-
|
| 150 |
scatter_list=scatter_list,
|
| 151 |
src=state.worker_rank,
|
| 152 |
group=state.process_group,
|
| 153 |
)
|
| 154 |
-
u_dtensor = DTensor.from_local(
|
| 155 |
-
u_received,
|
| 156 |
-
placements=p.placements,
|
| 157 |
-
device_mesh=p.device_mesh,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
state.scattered_u = u_dtensor
|
| 161 |
state.scatter_event = torch.cuda.Event()
|
| 162 |
state.scatter_event.record()
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
| 172 |
if state.scatter_event is None:
|
| 173 |
raise RuntimeError("Scatter event must be set before update")
|
| 174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if rank == state.worker_rank:
|
| 176 |
# Free computed_u
|
| 177 |
state.computed_u = None
|
| 178 |
|
| 179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def default_is_muon(name, x):
|
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 375 |
else:
|
| 376 |
g = buf
|
| 377 |
|
| 378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
|
| 379 |
|
| 380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 433 |
def enqueue_computes(start_idx, chunk_size):
|
| 434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 435 |
state = param_to_state[id(p)]
|
| 436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
| 437 |
self.compute_stream)
|
| 438 |
|
| 439 |
def enqueue_scatters(start_idx, chunk_size):
|
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
| 466 |
# Wait the last update_param to finish
|
| 467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
def step(self, closure=None):
|
| 470 |
"""Perform a single optimization step.
|
| 471 |
|
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 542 |
# AdamW backup #
|
| 543 |
############################
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
lr = group["lr"]
|
| 546 |
beta1, beta2 = group["adamw_betas"]
|
| 547 |
eps = group["adamw_eps"]
|
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
| 552 |
if g is None:
|
| 553 |
continue
|
| 554 |
state = self.state[p]
|
|
|
|
|
|
|
| 555 |
if "step" not in state:
|
| 556 |
-
state["step"] =
|
|
|
|
|
|
|
| 557 |
state["moment1"] = torch.zeros_like(g)
|
| 558 |
state["moment2"] = torch.zeros_like(g)
|
| 559 |
-
state["
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
return loss
|
|
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Union, cast
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.distributed as dist
|
|
|
|
| 13 |
|
| 14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
| 16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
|
|
|
| 27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
+
assert G.dtype == torch.bfloat16
|
| 31 |
X = G # no manual typecast
|
| 32 |
+
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
+
for a, b, c in [
|
| 39 |
+
(4.0848, -6.8946, 2.9270),
|
| 40 |
+
(3.9505, -6.3029, 2.6377),
|
| 41 |
+
(3.7418, -5.5913, 2.3037),
|
| 42 |
+
(2.8769, -3.1427, 1.2046),
|
| 43 |
+
(2.8366, -3.0525, 1.2012),
|
| 44 |
+
]:
|
| 45 |
A = X @ X.T
|
| 46 |
# B = (
|
| 47 |
# b * A + c * A @ A
|
|
|
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
| 55 |
+
return X
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 74 |
Gather the gradients to worker_rank.
|
| 75 |
If none_grad is True, free p.grad after the gather.
|
| 76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.cuda.stream(comm_stream):
|
| 78 |
+
g = p.grad
|
| 79 |
+
|
| 80 |
+
if rank == state.worker_rank:
|
| 81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 82 |
+
gather_list = [
|
| 83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
| 84 |
+
for _ in range(num_ranks)
|
| 85 |
+
]
|
| 86 |
+
else:
|
| 87 |
+
gather_list = None
|
| 88 |
+
|
| 89 |
+
g = g.to(torch.bfloat16)
|
| 90 |
torch.distributed.gather(
|
| 91 |
g.to_local(),
|
| 92 |
dst=state.worker_rank,
|
|
|
|
| 103 |
else:
|
| 104 |
state.gathered_grad = None
|
| 105 |
state.gather_event = None
|
| 106 |
+
gather_list = None
|
| 107 |
if none_grad:
|
| 108 |
# We can safely free p.grad without calling record_stream:
|
| 109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
@torch.no_grad()
|
| 119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
| 120 |
"""
|
| 121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 122 |
"""
|
|
|
|
| 127 |
compute_stream.wait_event(state.gather_event)
|
| 128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 129 |
state.computed_u = u
|
| 130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 131 |
+
dtype=torch.bfloat16)
|
| 132 |
+
state.compute_event = torch.cuda.Event()
|
| 133 |
+
state.compute_event.record()
|
| 134 |
+
u = None
|
| 135 |
|
| 136 |
|
| 137 |
@torch.no_grad()
|
|
|
|
| 141 |
"""
|
| 142 |
|
| 143 |
with torch.cuda.stream(comm_stream):
|
| 144 |
+
if state.compute_event is None:
|
| 145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
| 146 |
+
comm_stream.wait_event(state.compute_event)
|
| 147 |
+
|
| 148 |
if rank == state.worker_rank:
|
| 149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Clear the gathered gradient to free memory
|
| 151 |
state.gathered_grad = None
|
| 152 |
|
|
|
|
| 156 |
else:
|
| 157 |
scatter_list = None
|
| 158 |
|
|
|
|
| 159 |
torch.distributed.scatter(
|
| 160 |
+
state.scattered_u,
|
| 161 |
scatter_list=scatter_list,
|
| 162 |
src=state.worker_rank,
|
| 163 |
group=state.process_group,
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
state.scatter_event = torch.cuda.Event()
|
| 166 |
state.scatter_event.record()
|
| 167 |
+
scatter_list = None
|
| 168 |
|
| 169 |
|
| 170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 177 |
if state.scatter_event is None:
|
| 178 |
raise RuntimeError("Scatter event must be set before update")
|
| 179 |
compute_stream.wait_event(state.scatter_event)
|
| 180 |
+
u_dtensor = DTensor.from_local(
|
| 181 |
+
state.scattered_u,
|
| 182 |
+
placements=p.placements,
|
| 183 |
+
device_mesh=p.device_mesh,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
state.scattered_u = u_dtensor
|
| 187 |
+
|
| 188 |
if rank == state.worker_rank:
|
| 189 |
# Free computed_u
|
| 190 |
state.computed_u = None
|
| 191 |
|
| 192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
| 193 |
+
state.scattered_u = None
|
| 194 |
+
u_dtensor = None
|
| 195 |
|
| 196 |
|
| 197 |
def default_is_muon(name, x):
|
|
|
|
| 390 |
else:
|
| 391 |
g = buf
|
| 392 |
|
| 393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
| 394 |
+
steps=group["ns_steps"])
|
| 395 |
|
| 396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 449 |
def enqueue_computes(start_idx, chunk_size):
|
| 450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 451 |
state = param_to_state[id(p)]
|
| 452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 453 |
self.compute_stream)
|
| 454 |
|
| 455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
|
| 482 |
# Wait the last update_param to finish
|
| 483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 484 |
|
| 485 |
+
@staticmethod
|
| 486 |
+
def _fused_adamw(
|
| 487 |
+
params: list[torch.Tensor],
|
| 488 |
+
grads: list[torch.Tensor],
|
| 489 |
+
exp_avgs: list[torch.Tensor],
|
| 490 |
+
exp_avg_sqs: list[torch.Tensor],
|
| 491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
| 492 |
+
state_steps: list[torch.Tensor],
|
| 493 |
+
amsgrad: bool,
|
| 494 |
+
beta1: float,
|
| 495 |
+
beta2: float,
|
| 496 |
+
lr: Union[float, torch.Tensor],
|
| 497 |
+
weight_decay: float,
|
| 498 |
+
eps: float,
|
| 499 |
+
maximize: bool,
|
| 500 |
+
) -> None:
|
| 501 |
+
if not params:
|
| 502 |
+
return
|
| 503 |
+
|
| 504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
| 505 |
+
# treating it as a scalar.
|
| 506 |
+
lr_dict: Optional[DeviceDict] = ({
|
| 507 |
+
lr.device: lr
|
| 508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
| 509 |
+
None)
|
| 510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
| 511 |
+
[
|
| 512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
| 513 |
+
state_steps
|
| 514 |
+
] # type: ignore[list-item]
|
| 515 |
+
)
|
| 516 |
+
for (device, _), (
|
| 517 |
+
(
|
| 518 |
+
device_params_,
|
| 519 |
+
device_grads_,
|
| 520 |
+
device_exp_avgs_,
|
| 521 |
+
device_exp_avg_sqs_,
|
| 522 |
+
device_max_exp_avg_sqs,
|
| 523 |
+
device_state_steps_,
|
| 524 |
+
),
|
| 525 |
+
_,
|
| 526 |
+
) in grouped_tensors.items():
|
| 527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
| 528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
| 529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
| 530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
| 531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
| 532 |
+
|
| 533 |
+
if lr_dict is not None and device not in lr_dict:
|
| 534 |
+
lr_dict[device] = lr.to(
|
| 535 |
+
device=device,
|
| 536 |
+
non_blocking=True) # type: ignore[union-attr]
|
| 537 |
+
lr = lr_dict[device]
|
| 538 |
+
torch._foreach_add_(device_state_steps, 1)
|
| 539 |
+
func = torch._fused_adamw_
|
| 540 |
+
func(
|
| 541 |
+
device_params,
|
| 542 |
+
device_grads,
|
| 543 |
+
device_exp_avgs,
|
| 544 |
+
device_exp_avg_sqs,
|
| 545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
| 546 |
+
device_state_steps,
|
| 547 |
+
amsgrad=amsgrad,
|
| 548 |
+
lr=lr, # type: ignore[arg-type]
|
| 549 |
+
beta1=beta1,
|
| 550 |
+
beta2=beta2,
|
| 551 |
+
weight_decay=weight_decay,
|
| 552 |
+
eps=eps,
|
| 553 |
+
maximize=maximize,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
def step(self, closure=None):
|
| 557 |
"""Perform a single optimization step.
|
| 558 |
|
|
|
|
| 629 |
# AdamW backup #
|
| 630 |
############################
|
| 631 |
|
| 632 |
+
params_with_grads = []
|
| 633 |
+
grads = []
|
| 634 |
+
moment1 = []
|
| 635 |
+
moment2 = []
|
| 636 |
+
max_exp_avg_sqs = []
|
| 637 |
+
state_steps = []
|
| 638 |
lr = group["lr"]
|
| 639 |
beta1, beta2 = group["adamw_betas"]
|
| 640 |
eps = group["adamw_eps"]
|
|
|
|
| 645 |
if g is None:
|
| 646 |
continue
|
| 647 |
state = self.state[p]
|
| 648 |
+
params_with_grads.append(p)
|
| 649 |
+
grads.append(g)
|
| 650 |
if "step" not in state:
|
| 651 |
+
state["step"] = (torch.zeros((),
|
| 652 |
+
dtype=torch.float32,
|
| 653 |
+
device=p.device))
|
| 654 |
state["moment1"] = torch.zeros_like(g)
|
| 655 |
state["moment2"] = torch.zeros_like(g)
|
| 656 |
+
moment1.append(state["moment1"])
|
| 657 |
+
moment2.append(state["moment2"])
|
| 658 |
+
if not isinstance(state["step"], torch.Tensor):
|
| 659 |
+
step_tensor = torch.tensor(state["step"],
|
| 660 |
+
dtype=torch.float32,
|
| 661 |
+
device=p.device)
|
| 662 |
+
else:
|
| 663 |
+
step_tensor = state["step"]
|
| 664 |
+
state_steps.append(step_tensor)
|
| 665 |
+
|
| 666 |
+
self._fused_adamw(
|
| 667 |
+
params_with_grads,
|
| 668 |
+
grads,
|
| 669 |
+
moment1,
|
| 670 |
+
moment2,
|
| 671 |
+
max_exp_avg_sqs,
|
| 672 |
+
state_steps,
|
| 673 |
+
amsgrad=False,
|
| 674 |
+
beta1=beta1,
|
| 675 |
+
beta2=beta2,
|
| 676 |
+
lr=lr,
|
| 677 |
+
weight_decay=weight_decay,
|
| 678 |
+
eps=eps,
|
| 679 |
+
maximize=False,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
return loss
|
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_ee6ed44_dirty
|
| 3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:62ecfc7e6a1ab0c4ada19ed7aea40fc0a431c4ceb1729666efa98ac0e407f9c8
|
| 3 |
-
size 1883360
|
|
|
|
|
|
|
|
|
|
|
|
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5e375def39d93758b60534cef504ae75d9c13e0d86da5dcf7642f1f90b77f52
|
| 3 |
+
size 1883352
|
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
|
|
| 15 |
@torch.no_grad()
|
| 16 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 17 |
"""
|
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 25 |
"""
|
| 26 |
assert len(G.shape) == 2
|
| 27 |
-
|
| 28 |
X = G # no manual typecast
|
|
|
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
# Ensure spectral norm is at most 1
|
| 32 |
X = X / (X.norm() + 1e-7)
|
| 33 |
-
X = X.bfloat16()
|
| 34 |
# Perform the NS iterations
|
| 35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
A = X @ X.T
|
| 37 |
# B = (
|
| 38 |
# b * A + c * A @ A
|
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 43 |
|
| 44 |
if G.size(0) > G.size(1):
|
| 45 |
X = X.T
|
| 46 |
-
return X
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 65 |
Gather the gradients to worker_rank.
|
| 66 |
If none_grad is True, free p.grad after the gather.
|
| 67 |
"""
|
| 68 |
-
g = p.grad
|
| 69 |
-
|
| 70 |
-
if rank == state.worker_rank:
|
| 71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
| 72 |
-
gather_list = [
|
| 73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 74 |
-
]
|
| 75 |
-
else:
|
| 76 |
-
gather_list = None
|
| 77 |
-
|
| 78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
torch.distributed.gather(
|
| 80 |
g.to_local(),
|
| 81 |
dst=state.worker_rank,
|
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 92 |
else:
|
| 93 |
state.gathered_grad = None
|
| 94 |
state.gather_event = None
|
|
|
|
| 95 |
if none_grad:
|
| 96 |
# We can safely free p.grad without calling record_stream:
|
| 97 |
# p.grad.to_local().record_stream(comm_stream)
|
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
| 108 |
"""
|
| 109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 110 |
"""
|
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 115 |
compute_stream.wait_event(state.gather_event)
|
| 116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 117 |
state.computed_u = u
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
@torch.no_grad()
|
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if rank == state.worker_rank:
|
| 133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 134 |
-
if state.compute_event is None:
|
| 135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
| 136 |
-
comm_stream.wait_event(state.compute_event)
|
| 137 |
-
|
| 138 |
# Clear the gathered gradient to free memory
|
| 139 |
state.gathered_grad = None
|
| 140 |
|
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 144 |
else:
|
| 145 |
scatter_list = None
|
| 146 |
|
| 147 |
-
u_received = torch.empty_like(p.to_local())
|
| 148 |
torch.distributed.scatter(
|
| 149 |
-
|
| 150 |
scatter_list=scatter_list,
|
| 151 |
src=state.worker_rank,
|
| 152 |
group=state.process_group,
|
| 153 |
)
|
| 154 |
-
u_dtensor = DTensor.from_local(
|
| 155 |
-
u_received,
|
| 156 |
-
placements=p.placements,
|
| 157 |
-
device_mesh=p.device_mesh,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
state.scattered_u = u_dtensor
|
| 161 |
state.scatter_event = torch.cuda.Event()
|
| 162 |
state.scatter_event.record()
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
| 172 |
if state.scatter_event is None:
|
| 173 |
raise RuntimeError("Scatter event must be set before update")
|
| 174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if rank == state.worker_rank:
|
| 176 |
# Free computed_u
|
| 177 |
state.computed_u = None
|
| 178 |
|
| 179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def default_is_muon(name, x):
|
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 375 |
else:
|
| 376 |
g = buf
|
| 377 |
|
| 378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
|
| 379 |
|
| 380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 433 |
def enqueue_computes(start_idx, chunk_size):
|
| 434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 435 |
state = param_to_state[id(p)]
|
| 436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
| 437 |
self.compute_stream)
|
| 438 |
|
| 439 |
def enqueue_scatters(start_idx, chunk_size):
|
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
| 466 |
# Wait the last update_param to finish
|
| 467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
def step(self, closure=None):
|
| 470 |
"""Perform a single optimization step.
|
| 471 |
|
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 542 |
# AdamW backup #
|
| 543 |
############################
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
lr = group["lr"]
|
| 546 |
beta1, beta2 = group["adamw_betas"]
|
| 547 |
eps = group["adamw_eps"]
|
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
| 552 |
if g is None:
|
| 553 |
continue
|
| 554 |
state = self.state[p]
|
|
|
|
|
|
|
| 555 |
if "step" not in state:
|
| 556 |
-
state["step"] =
|
|
|
|
|
|
|
| 557 |
state["moment1"] = torch.zeros_like(g)
|
| 558 |
state["moment2"] = torch.zeros_like(g)
|
| 559 |
-
state["
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
return loss
|
|
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Union, cast
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.distributed as dist
|
|
|
|
| 13 |
|
| 14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
| 16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
|
|
|
| 27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
+
assert G.dtype == torch.bfloat16
|
| 31 |
X = G # no manual typecast
|
| 32 |
+
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
+
for a, b, c in [
|
| 39 |
+
(4.0848, -6.8946, 2.9270),
|
| 40 |
+
(3.9505, -6.3029, 2.6377),
|
| 41 |
+
(3.7418, -5.5913, 2.3037),
|
| 42 |
+
(2.8769, -3.1427, 1.2046),
|
| 43 |
+
(2.8366, -3.0525, 1.2012),
|
| 44 |
+
]:
|
| 45 |
A = X @ X.T
|
| 46 |
# B = (
|
| 47 |
# b * A + c * A @ A
|
|
|
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
| 55 |
+
return X
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 74 |
Gather the gradients to worker_rank.
|
| 75 |
If none_grad is True, free p.grad after the gather.
|
| 76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.cuda.stream(comm_stream):
|
| 78 |
+
g = p.grad
|
| 79 |
+
|
| 80 |
+
if rank == state.worker_rank:
|
| 81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 82 |
+
gather_list = [
|
| 83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
| 84 |
+
for _ in range(num_ranks)
|
| 85 |
+
]
|
| 86 |
+
else:
|
| 87 |
+
gather_list = None
|
| 88 |
+
|
| 89 |
+
g = g.to(torch.bfloat16)
|
| 90 |
torch.distributed.gather(
|
| 91 |
g.to_local(),
|
| 92 |
dst=state.worker_rank,
|
|
|
|
| 103 |
else:
|
| 104 |
state.gathered_grad = None
|
| 105 |
state.gather_event = None
|
| 106 |
+
gather_list = None
|
| 107 |
if none_grad:
|
| 108 |
# We can safely free p.grad without calling record_stream:
|
| 109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
@torch.no_grad()
|
| 119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
| 120 |
"""
|
| 121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 122 |
"""
|
|
|
|
| 127 |
compute_stream.wait_event(state.gather_event)
|
| 128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 129 |
state.computed_u = u
|
| 130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 131 |
+
dtype=torch.bfloat16)
|
| 132 |
+
state.compute_event = torch.cuda.Event()
|
| 133 |
+
state.compute_event.record()
|
| 134 |
+
u = None
|
| 135 |
|
| 136 |
|
| 137 |
@torch.no_grad()
|
|
|
|
| 141 |
"""
|
| 142 |
|
| 143 |
with torch.cuda.stream(comm_stream):
|
| 144 |
+
if state.compute_event is None:
|
| 145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
| 146 |
+
comm_stream.wait_event(state.compute_event)
|
| 147 |
+
|
| 148 |
if rank == state.worker_rank:
|
| 149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Clear the gathered gradient to free memory
|
| 151 |
state.gathered_grad = None
|
| 152 |
|
|
|
|
| 156 |
else:
|
| 157 |
scatter_list = None
|
| 158 |
|
|
|
|
| 159 |
torch.distributed.scatter(
|
| 160 |
+
state.scattered_u,
|
| 161 |
scatter_list=scatter_list,
|
| 162 |
src=state.worker_rank,
|
| 163 |
group=state.process_group,
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
state.scatter_event = torch.cuda.Event()
|
| 166 |
state.scatter_event.record()
|
| 167 |
+
scatter_list = None
|
| 168 |
|
| 169 |
|
| 170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 177 |
if state.scatter_event is None:
|
| 178 |
raise RuntimeError("Scatter event must be set before update")
|
| 179 |
compute_stream.wait_event(state.scatter_event)
|
| 180 |
+
u_dtensor = DTensor.from_local(
|
| 181 |
+
state.scattered_u,
|
| 182 |
+
placements=p.placements,
|
| 183 |
+
device_mesh=p.device_mesh,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
state.scattered_u = u_dtensor
|
| 187 |
+
|
| 188 |
if rank == state.worker_rank:
|
| 189 |
# Free computed_u
|
| 190 |
state.computed_u = None
|
| 191 |
|
| 192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
| 193 |
+
state.scattered_u = None
|
| 194 |
+
u_dtensor = None
|
| 195 |
|
| 196 |
|
| 197 |
def default_is_muon(name, x):
|
|
|
|
| 390 |
else:
|
| 391 |
g = buf
|
| 392 |
|
| 393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
| 394 |
+
steps=group["ns_steps"])
|
| 395 |
|
| 396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 449 |
def enqueue_computes(start_idx, chunk_size):
|
| 450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 451 |
state = param_to_state[id(p)]
|
| 452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 453 |
self.compute_stream)
|
| 454 |
|
| 455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
|
| 482 |
# Wait the last update_param to finish
|
| 483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 484 |
|
| 485 |
+
@staticmethod
|
| 486 |
+
def _fused_adamw(
|
| 487 |
+
params: list[torch.Tensor],
|
| 488 |
+
grads: list[torch.Tensor],
|
| 489 |
+
exp_avgs: list[torch.Tensor],
|
| 490 |
+
exp_avg_sqs: list[torch.Tensor],
|
| 491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
| 492 |
+
state_steps: list[torch.Tensor],
|
| 493 |
+
amsgrad: bool,
|
| 494 |
+
beta1: float,
|
| 495 |
+
beta2: float,
|
| 496 |
+
lr: Union[float, torch.Tensor],
|
| 497 |
+
weight_decay: float,
|
| 498 |
+
eps: float,
|
| 499 |
+
maximize: bool,
|
| 500 |
+
) -> None:
|
| 501 |
+
if not params:
|
| 502 |
+
return
|
| 503 |
+
|
| 504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
| 505 |
+
# treating it as a scalar.
|
| 506 |
+
lr_dict: Optional[DeviceDict] = ({
|
| 507 |
+
lr.device: lr
|
| 508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
| 509 |
+
None)
|
| 510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
| 511 |
+
[
|
| 512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
| 513 |
+
state_steps
|
| 514 |
+
] # type: ignore[list-item]
|
| 515 |
+
)
|
| 516 |
+
for (device, _), (
|
| 517 |
+
(
|
| 518 |
+
device_params_,
|
| 519 |
+
device_grads_,
|
| 520 |
+
device_exp_avgs_,
|
| 521 |
+
device_exp_avg_sqs_,
|
| 522 |
+
device_max_exp_avg_sqs,
|
| 523 |
+
device_state_steps_,
|
| 524 |
+
),
|
| 525 |
+
_,
|
| 526 |
+
) in grouped_tensors.items():
|
| 527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
| 528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
| 529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
| 530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
| 531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
| 532 |
+
|
| 533 |
+
if lr_dict is not None and device not in lr_dict:
|
| 534 |
+
lr_dict[device] = lr.to(
|
| 535 |
+
device=device,
|
| 536 |
+
non_blocking=True) # type: ignore[union-attr]
|
| 537 |
+
lr = lr_dict[device]
|
| 538 |
+
torch._foreach_add_(device_state_steps, 1)
|
| 539 |
+
func = torch._fused_adamw_
|
| 540 |
+
func(
|
| 541 |
+
device_params,
|
| 542 |
+
device_grads,
|
| 543 |
+
device_exp_avgs,
|
| 544 |
+
device_exp_avg_sqs,
|
| 545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
| 546 |
+
device_state_steps,
|
| 547 |
+
amsgrad=amsgrad,
|
| 548 |
+
lr=lr, # type: ignore[arg-type]
|
| 549 |
+
beta1=beta1,
|
| 550 |
+
beta2=beta2,
|
| 551 |
+
weight_decay=weight_decay,
|
| 552 |
+
eps=eps,
|
| 553 |
+
maximize=maximize,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
def step(self, closure=None):
|
| 557 |
"""Perform a single optimization step.
|
| 558 |
|
|
|
|
| 629 |
# AdamW backup #
|
| 630 |
############################
|
| 631 |
|
| 632 |
+
params_with_grads = []
|
| 633 |
+
grads = []
|
| 634 |
+
moment1 = []
|
| 635 |
+
moment2 = []
|
| 636 |
+
max_exp_avg_sqs = []
|
| 637 |
+
state_steps = []
|
| 638 |
lr = group["lr"]
|
| 639 |
beta1, beta2 = group["adamw_betas"]
|
| 640 |
eps = group["adamw_eps"]
|
|
|
|
| 645 |
if g is None:
|
| 646 |
continue
|
| 647 |
state = self.state[p]
|
| 648 |
+
params_with_grads.append(p)
|
| 649 |
+
grads.append(g)
|
| 650 |
if "step" not in state:
|
| 651 |
+
state["step"] = (torch.zeros((),
|
| 652 |
+
dtype=torch.float32,
|
| 653 |
+
device=p.device))
|
| 654 |
state["moment1"] = torch.zeros_like(g)
|
| 655 |
state["moment2"] = torch.zeros_like(g)
|
| 656 |
+
moment1.append(state["moment1"])
|
| 657 |
+
moment2.append(state["moment2"])
|
| 658 |
+
if not isinstance(state["step"], torch.Tensor):
|
| 659 |
+
step_tensor = torch.tensor(state["step"],
|
| 660 |
+
dtype=torch.float32,
|
| 661 |
+
device=p.device)
|
| 662 |
+
else:
|
| 663 |
+
step_tensor = state["step"]
|
| 664 |
+
state_steps.append(step_tensor)
|
| 665 |
+
|
| 666 |
+
self._fused_adamw(
|
| 667 |
+
params_with_grads,
|
| 668 |
+
grads,
|
| 669 |
+
moment1,
|
| 670 |
+
moment2,
|
| 671 |
+
max_exp_avg_sqs,
|
| 672 |
+
state_steps,
|
| 673 |
+
amsgrad=False,
|
| 674 |
+
beta1=beta1,
|
| 675 |
+
beta2=beta2,
|
| 676 |
+
lr=lr,
|
| 677 |
+
weight_decay=weight_decay,
|
| 678 |
+
eps=eps,
|
| 679 |
+
maximize=False,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
return loss
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_ee6ed44_dirty
|
| 3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:37e389c650fc1fcbc9fbd68f1e7c1a768b08e90509fd8a5d87879655726f2db2
|
| 3 |
-
size 1750040
|
|
|
|
|
|
|
|
|
|
|
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:33e0d50fbf340612b0e1129717e4116197c8562592e5920f2dedc718ce9a0585
|
| 3 |
+
size 1750000
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
|
|
| 15 |
@torch.no_grad()
|
| 16 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 17 |
"""
|
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 25 |
"""
|
| 26 |
assert len(G.shape) == 2
|
| 27 |
-
|
| 28 |
X = G # no manual typecast
|
|
|
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
# Ensure spectral norm is at most 1
|
| 32 |
X = X / (X.norm() + 1e-7)
|
| 33 |
-
X = X.bfloat16()
|
| 34 |
# Perform the NS iterations
|
| 35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
A = X @ X.T
|
| 37 |
# B = (
|
| 38 |
# b * A + c * A @ A
|
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 43 |
|
| 44 |
if G.size(0) > G.size(1):
|
| 45 |
X = X.T
|
| 46 |
-
return X
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 65 |
Gather the gradients to worker_rank.
|
| 66 |
If none_grad is True, free p.grad after the gather.
|
| 67 |
"""
|
| 68 |
-
g = p.grad
|
| 69 |
-
|
| 70 |
-
if rank == state.worker_rank:
|
| 71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
| 72 |
-
gather_list = [
|
| 73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 74 |
-
]
|
| 75 |
-
else:
|
| 76 |
-
gather_list = None
|
| 77 |
-
|
| 78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
torch.distributed.gather(
|
| 80 |
g.to_local(),
|
| 81 |
dst=state.worker_rank,
|
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 92 |
else:
|
| 93 |
state.gathered_grad = None
|
| 94 |
state.gather_event = None
|
|
|
|
| 95 |
if none_grad:
|
| 96 |
# We can safely free p.grad without calling record_stream:
|
| 97 |
# p.grad.to_local().record_stream(comm_stream)
|
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
| 108 |
"""
|
| 109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 110 |
"""
|
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 115 |
compute_stream.wait_event(state.gather_event)
|
| 116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 117 |
state.computed_u = u
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
@torch.no_grad()
|
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if rank == state.worker_rank:
|
| 133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 134 |
-
if state.compute_event is None:
|
| 135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
| 136 |
-
comm_stream.wait_event(state.compute_event)
|
| 137 |
-
|
| 138 |
# Clear the gathered gradient to free memory
|
| 139 |
state.gathered_grad = None
|
| 140 |
|
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 144 |
else:
|
| 145 |
scatter_list = None
|
| 146 |
|
| 147 |
-
u_received = torch.empty_like(p.to_local())
|
| 148 |
torch.distributed.scatter(
|
| 149 |
-
|
| 150 |
scatter_list=scatter_list,
|
| 151 |
src=state.worker_rank,
|
| 152 |
group=state.process_group,
|
| 153 |
)
|
| 154 |
-
u_dtensor = DTensor.from_local(
|
| 155 |
-
u_received,
|
| 156 |
-
placements=p.placements,
|
| 157 |
-
device_mesh=p.device_mesh,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
state.scattered_u = u_dtensor
|
| 161 |
state.scatter_event = torch.cuda.Event()
|
| 162 |
state.scatter_event.record()
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
| 172 |
if state.scatter_event is None:
|
| 173 |
raise RuntimeError("Scatter event must be set before update")
|
| 174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if rank == state.worker_rank:
|
| 176 |
# Free computed_u
|
| 177 |
state.computed_u = None
|
| 178 |
|
| 179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def default_is_muon(name, x):
|
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 375 |
else:
|
| 376 |
g = buf
|
| 377 |
|
| 378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
|
| 379 |
|
| 380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 433 |
def enqueue_computes(start_idx, chunk_size):
|
| 434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 435 |
state = param_to_state[id(p)]
|
| 436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
| 437 |
self.compute_stream)
|
| 438 |
|
| 439 |
def enqueue_scatters(start_idx, chunk_size):
|
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
| 466 |
# Wait the last update_param to finish
|
| 467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
def step(self, closure=None):
|
| 470 |
"""Perform a single optimization step.
|
| 471 |
|
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 542 |
# AdamW backup #
|
| 543 |
############################
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
lr = group["lr"]
|
| 546 |
beta1, beta2 = group["adamw_betas"]
|
| 547 |
eps = group["adamw_eps"]
|
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
| 552 |
if g is None:
|
| 553 |
continue
|
| 554 |
state = self.state[p]
|
|
|
|
|
|
|
| 555 |
if "step" not in state:
|
| 556 |
-
state["step"] =
|
|
|
|
|
|
|
| 557 |
state["moment1"] = torch.zeros_like(g)
|
| 558 |
state["moment2"] = torch.zeros_like(g)
|
| 559 |
-
state["
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
return loss
|
|
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Union, cast
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.distributed as dist
|
|
|
|
| 13 |
|
| 14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
| 16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
|
|
|
| 27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
+
assert G.dtype == torch.bfloat16
|
| 31 |
X = G # no manual typecast
|
| 32 |
+
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
+
for a, b, c in [
|
| 39 |
+
(4.0848, -6.8946, 2.9270),
|
| 40 |
+
(3.9505, -6.3029, 2.6377),
|
| 41 |
+
(3.7418, -5.5913, 2.3037),
|
| 42 |
+
(2.8769, -3.1427, 1.2046),
|
| 43 |
+
(2.8366, -3.0525, 1.2012),
|
| 44 |
+
]:
|
| 45 |
A = X @ X.T
|
| 46 |
# B = (
|
| 47 |
# b * A + c * A @ A
|
|
|
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
| 55 |
+
return X
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 74 |
Gather the gradients to worker_rank.
|
| 75 |
If none_grad is True, free p.grad after the gather.
|
| 76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.cuda.stream(comm_stream):
|
| 78 |
+
g = p.grad
|
| 79 |
+
|
| 80 |
+
if rank == state.worker_rank:
|
| 81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 82 |
+
gather_list = [
|
| 83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
| 84 |
+
for _ in range(num_ranks)
|
| 85 |
+
]
|
| 86 |
+
else:
|
| 87 |
+
gather_list = None
|
| 88 |
+
|
| 89 |
+
g = g.to(torch.bfloat16)
|
| 90 |
torch.distributed.gather(
|
| 91 |
g.to_local(),
|
| 92 |
dst=state.worker_rank,
|
|
|
|
| 103 |
else:
|
| 104 |
state.gathered_grad = None
|
| 105 |
state.gather_event = None
|
| 106 |
+
gather_list = None
|
| 107 |
if none_grad:
|
| 108 |
# We can safely free p.grad without calling record_stream:
|
| 109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
@torch.no_grad()
|
| 119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
| 120 |
"""
|
| 121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 122 |
"""
|
|
|
|
| 127 |
compute_stream.wait_event(state.gather_event)
|
| 128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 129 |
state.computed_u = u
|
| 130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 131 |
+
dtype=torch.bfloat16)
|
| 132 |
+
state.compute_event = torch.cuda.Event()
|
| 133 |
+
state.compute_event.record()
|
| 134 |
+
u = None
|
| 135 |
|
| 136 |
|
| 137 |
@torch.no_grad()
|
|
|
|
| 141 |
"""
|
| 142 |
|
| 143 |
with torch.cuda.stream(comm_stream):
|
| 144 |
+
if state.compute_event is None:
|
| 145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
| 146 |
+
comm_stream.wait_event(state.compute_event)
|
| 147 |
+
|
| 148 |
if rank == state.worker_rank:
|
| 149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Clear the gathered gradient to free memory
|
| 151 |
state.gathered_grad = None
|
| 152 |
|
|
|
|
| 156 |
else:
|
| 157 |
scatter_list = None
|
| 158 |
|
|
|
|
| 159 |
torch.distributed.scatter(
|
| 160 |
+
state.scattered_u,
|
| 161 |
scatter_list=scatter_list,
|
| 162 |
src=state.worker_rank,
|
| 163 |
group=state.process_group,
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
state.scatter_event = torch.cuda.Event()
|
| 166 |
state.scatter_event.record()
|
| 167 |
+
scatter_list = None
|
| 168 |
|
| 169 |
|
| 170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 177 |
if state.scatter_event is None:
|
| 178 |
raise RuntimeError("Scatter event must be set before update")
|
| 179 |
compute_stream.wait_event(state.scatter_event)
|
| 180 |
+
u_dtensor = DTensor.from_local(
|
| 181 |
+
state.scattered_u,
|
| 182 |
+
placements=p.placements,
|
| 183 |
+
device_mesh=p.device_mesh,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
state.scattered_u = u_dtensor
|
| 187 |
+
|
| 188 |
if rank == state.worker_rank:
|
| 189 |
# Free computed_u
|
| 190 |
state.computed_u = None
|
| 191 |
|
| 192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
| 193 |
+
state.scattered_u = None
|
| 194 |
+
u_dtensor = None
|
| 195 |
|
| 196 |
|
| 197 |
def default_is_muon(name, x):
|
|
|
|
| 390 |
else:
|
| 391 |
g = buf
|
| 392 |
|
| 393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
| 394 |
+
steps=group["ns_steps"])
|
| 395 |
|
| 396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 449 |
def enqueue_computes(start_idx, chunk_size):
|
| 450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 451 |
state = param_to_state[id(p)]
|
| 452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 453 |
self.compute_stream)
|
| 454 |
|
| 455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
|
| 482 |
# Wait the last update_param to finish
|
| 483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 484 |
|
| 485 |
+
@staticmethod
|
| 486 |
+
def _fused_adamw(
|
| 487 |
+
params: list[torch.Tensor],
|
| 488 |
+
grads: list[torch.Tensor],
|
| 489 |
+
exp_avgs: list[torch.Tensor],
|
| 490 |
+
exp_avg_sqs: list[torch.Tensor],
|
| 491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
| 492 |
+
state_steps: list[torch.Tensor],
|
| 493 |
+
amsgrad: bool,
|
| 494 |
+
beta1: float,
|
| 495 |
+
beta2: float,
|
| 496 |
+
lr: Union[float, torch.Tensor],
|
| 497 |
+
weight_decay: float,
|
| 498 |
+
eps: float,
|
| 499 |
+
maximize: bool,
|
| 500 |
+
) -> None:
|
| 501 |
+
if not params:
|
| 502 |
+
return
|
| 503 |
+
|
| 504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
| 505 |
+
# treating it as a scalar.
|
| 506 |
+
lr_dict: Optional[DeviceDict] = ({
|
| 507 |
+
lr.device: lr
|
| 508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
| 509 |
+
None)
|
| 510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
| 511 |
+
[
|
| 512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
| 513 |
+
state_steps
|
| 514 |
+
] # type: ignore[list-item]
|
| 515 |
+
)
|
| 516 |
+
for (device, _), (
|
| 517 |
+
(
|
| 518 |
+
device_params_,
|
| 519 |
+
device_grads_,
|
| 520 |
+
device_exp_avgs_,
|
| 521 |
+
device_exp_avg_sqs_,
|
| 522 |
+
device_max_exp_avg_sqs,
|
| 523 |
+
device_state_steps_,
|
| 524 |
+
),
|
| 525 |
+
_,
|
| 526 |
+
) in grouped_tensors.items():
|
| 527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
| 528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
| 529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
| 530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
| 531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
| 532 |
+
|
| 533 |
+
if lr_dict is not None and device not in lr_dict:
|
| 534 |
+
lr_dict[device] = lr.to(
|
| 535 |
+
device=device,
|
| 536 |
+
non_blocking=True) # type: ignore[union-attr]
|
| 537 |
+
lr = lr_dict[device]
|
| 538 |
+
torch._foreach_add_(device_state_steps, 1)
|
| 539 |
+
func = torch._fused_adamw_
|
| 540 |
+
func(
|
| 541 |
+
device_params,
|
| 542 |
+
device_grads,
|
| 543 |
+
device_exp_avgs,
|
| 544 |
+
device_exp_avg_sqs,
|
| 545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
| 546 |
+
device_state_steps,
|
| 547 |
+
amsgrad=amsgrad,
|
| 548 |
+
lr=lr, # type: ignore[arg-type]
|
| 549 |
+
beta1=beta1,
|
| 550 |
+
beta2=beta2,
|
| 551 |
+
weight_decay=weight_decay,
|
| 552 |
+
eps=eps,
|
| 553 |
+
maximize=maximize,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
def step(self, closure=None):
|
| 557 |
"""Perform a single optimization step.
|
| 558 |
|
|
|
|
| 629 |
# AdamW backup #
|
| 630 |
############################
|
| 631 |
|
| 632 |
+
params_with_grads = []
|
| 633 |
+
grads = []
|
| 634 |
+
moment1 = []
|
| 635 |
+
moment2 = []
|
| 636 |
+
max_exp_avg_sqs = []
|
| 637 |
+
state_steps = []
|
| 638 |
lr = group["lr"]
|
| 639 |
beta1, beta2 = group["adamw_betas"]
|
| 640 |
eps = group["adamw_eps"]
|
|
|
|
| 645 |
if g is None:
|
| 646 |
continue
|
| 647 |
state = self.state[p]
|
| 648 |
+
params_with_grads.append(p)
|
| 649 |
+
grads.append(g)
|
| 650 |
if "step" not in state:
|
| 651 |
+
state["step"] = (torch.zeros((),
|
| 652 |
+
dtype=torch.float32,
|
| 653 |
+
device=p.device))
|
| 654 |
state["moment1"] = torch.zeros_like(g)
|
| 655 |
state["moment2"] = torch.zeros_like(g)
|
| 656 |
+
moment1.append(state["moment1"])
|
| 657 |
+
moment2.append(state["moment2"])
|
| 658 |
+
if not isinstance(state["step"], torch.Tensor):
|
| 659 |
+
step_tensor = torch.tensor(state["step"],
|
| 660 |
+
dtype=torch.float32,
|
| 661 |
+
device=p.device)
|
| 662 |
+
else:
|
| 663 |
+
step_tensor = state["step"]
|
| 664 |
+
state_steps.append(step_tensor)
|
| 665 |
+
|
| 666 |
+
self._fused_adamw(
|
| 667 |
+
params_with_grads,
|
| 668 |
+
grads,
|
| 669 |
+
moment1,
|
| 670 |
+
moment2,
|
| 671 |
+
max_exp_avg_sqs,
|
| 672 |
+
state_steps,
|
| 673 |
+
amsgrad=False,
|
| 674 |
+
beta1=beta1,
|
| 675 |
+
beta2=beta2,
|
| 676 |
+
lr=lr,
|
| 677 |
+
weight_decay=weight_decay,
|
| 678 |
+
eps=eps,
|
| 679 |
+
maximize=False,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
return loss
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_ee6ed44_dirty
|
| 3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:e62682b711f002505bb17c170b2bb233f8d389510ff8e2e0a753ee96d11d0746
|
| 3 |
-
size 1750128
|
|
|
|
|
|
|
|
|
|
|
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5eedf56e661a7d314727e40f192236dbd9696f62ba21f11e366643f2662c03a4
|
| 3 |
+
size 1750088
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
|
|
| 15 |
@torch.no_grad()
|
| 16 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 17 |
"""
|
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 25 |
"""
|
| 26 |
assert len(G.shape) == 2
|
| 27 |
-
|
| 28 |
X = G # no manual typecast
|
|
|
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
# Ensure spectral norm is at most 1
|
| 32 |
X = X / (X.norm() + 1e-7)
|
| 33 |
-
X = X.bfloat16()
|
| 34 |
# Perform the NS iterations
|
| 35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
A = X @ X.T
|
| 37 |
# B = (
|
| 38 |
# b * A + c * A @ A
|
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 43 |
|
| 44 |
if G.size(0) > G.size(1):
|
| 45 |
X = X.T
|
| 46 |
-
return X
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 65 |
Gather the gradients to worker_rank.
|
| 66 |
If none_grad is True, free p.grad after the gather.
|
| 67 |
"""
|
| 68 |
-
g = p.grad
|
| 69 |
-
|
| 70 |
-
if rank == state.worker_rank:
|
| 71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
| 72 |
-
gather_list = [
|
| 73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 74 |
-
]
|
| 75 |
-
else:
|
| 76 |
-
gather_list = None
|
| 77 |
-
|
| 78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
torch.distributed.gather(
|
| 80 |
g.to_local(),
|
| 81 |
dst=state.worker_rank,
|
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 92 |
else:
|
| 93 |
state.gathered_grad = None
|
| 94 |
state.gather_event = None
|
|
|
|
| 95 |
if none_grad:
|
| 96 |
# We can safely free p.grad without calling record_stream:
|
| 97 |
# p.grad.to_local().record_stream(comm_stream)
|
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
| 108 |
"""
|
| 109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 110 |
"""
|
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 115 |
compute_stream.wait_event(state.gather_event)
|
| 116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 117 |
state.computed_u = u
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
@torch.no_grad()
|
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if rank == state.worker_rank:
|
| 133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 134 |
-
if state.compute_event is None:
|
| 135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
| 136 |
-
comm_stream.wait_event(state.compute_event)
|
| 137 |
-
|
| 138 |
# Clear the gathered gradient to free memory
|
| 139 |
state.gathered_grad = None
|
| 140 |
|
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 144 |
else:
|
| 145 |
scatter_list = None
|
| 146 |
|
| 147 |
-
u_received = torch.empty_like(p.to_local())
|
| 148 |
torch.distributed.scatter(
|
| 149 |
-
|
| 150 |
scatter_list=scatter_list,
|
| 151 |
src=state.worker_rank,
|
| 152 |
group=state.process_group,
|
| 153 |
)
|
| 154 |
-
u_dtensor = DTensor.from_local(
|
| 155 |
-
u_received,
|
| 156 |
-
placements=p.placements,
|
| 157 |
-
device_mesh=p.device_mesh,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
state.scattered_u = u_dtensor
|
| 161 |
state.scatter_event = torch.cuda.Event()
|
| 162 |
state.scatter_event.record()
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
| 172 |
if state.scatter_event is None:
|
| 173 |
raise RuntimeError("Scatter event must be set before update")
|
| 174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if rank == state.worker_rank:
|
| 176 |
# Free computed_u
|
| 177 |
state.computed_u = None
|
| 178 |
|
| 179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def default_is_muon(name, x):
|
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 375 |
else:
|
| 376 |
g = buf
|
| 377 |
|
| 378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
|
| 379 |
|
| 380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 433 |
def enqueue_computes(start_idx, chunk_size):
|
| 434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 435 |
state = param_to_state[id(p)]
|
| 436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
| 437 |
self.compute_stream)
|
| 438 |
|
| 439 |
def enqueue_scatters(start_idx, chunk_size):
|
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
| 466 |
# Wait the last update_param to finish
|
| 467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
def step(self, closure=None):
|
| 470 |
"""Perform a single optimization step.
|
| 471 |
|
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 542 |
# AdamW backup #
|
| 543 |
############################
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
lr = group["lr"]
|
| 546 |
beta1, beta2 = group["adamw_betas"]
|
| 547 |
eps = group["adamw_eps"]
|
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
| 552 |
if g is None:
|
| 553 |
continue
|
| 554 |
state = self.state[p]
|
|
|
|
|
|
|
| 555 |
if "step" not in state:
|
| 556 |
-
state["step"] =
|
|
|
|
|
|
|
| 557 |
state["moment1"] = torch.zeros_like(g)
|
| 558 |
state["moment2"] = torch.zeros_like(g)
|
| 559 |
-
state["
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
return loss
|
|
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Union, cast
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.distributed as dist
|
|
|
|
| 13 |
|
| 14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
| 16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
|
|
|
| 27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
+
assert G.dtype == torch.bfloat16
|
| 31 |
X = G # no manual typecast
|
| 32 |
+
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
+
for a, b, c in [
|
| 39 |
+
(4.0848, -6.8946, 2.9270),
|
| 40 |
+
(3.9505, -6.3029, 2.6377),
|
| 41 |
+
(3.7418, -5.5913, 2.3037),
|
| 42 |
+
(2.8769, -3.1427, 1.2046),
|
| 43 |
+
(2.8366, -3.0525, 1.2012),
|
| 44 |
+
]:
|
| 45 |
A = X @ X.T
|
| 46 |
# B = (
|
| 47 |
# b * A + c * A @ A
|
|
|
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
| 55 |
+
return X
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 74 |
Gather the gradients to worker_rank.
|
| 75 |
If none_grad is True, free p.grad after the gather.
|
| 76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.cuda.stream(comm_stream):
|
| 78 |
+
g = p.grad
|
| 79 |
+
|
| 80 |
+
if rank == state.worker_rank:
|
| 81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 82 |
+
gather_list = [
|
| 83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
| 84 |
+
for _ in range(num_ranks)
|
| 85 |
+
]
|
| 86 |
+
else:
|
| 87 |
+
gather_list = None
|
| 88 |
+
|
| 89 |
+
g = g.to(torch.bfloat16)
|
| 90 |
torch.distributed.gather(
|
| 91 |
g.to_local(),
|
| 92 |
dst=state.worker_rank,
|
|
|
|
| 103 |
else:
|
| 104 |
state.gathered_grad = None
|
| 105 |
state.gather_event = None
|
| 106 |
+
gather_list = None
|
| 107 |
if none_grad:
|
| 108 |
# We can safely free p.grad without calling record_stream:
|
| 109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
@torch.no_grad()
|
| 119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
| 120 |
"""
|
| 121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 122 |
"""
|
|
|
|
| 127 |
compute_stream.wait_event(state.gather_event)
|
| 128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 129 |
state.computed_u = u
|
| 130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 131 |
+
dtype=torch.bfloat16)
|
| 132 |
+
state.compute_event = torch.cuda.Event()
|
| 133 |
+
state.compute_event.record()
|
| 134 |
+
u = None
|
| 135 |
|
| 136 |
|
| 137 |
@torch.no_grad()
|
|
|
|
| 141 |
"""
|
| 142 |
|
| 143 |
with torch.cuda.stream(comm_stream):
|
| 144 |
+
if state.compute_event is None:
|
| 145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
| 146 |
+
comm_stream.wait_event(state.compute_event)
|
| 147 |
+
|
| 148 |
if rank == state.worker_rank:
|
| 149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Clear the gathered gradient to free memory
|
| 151 |
state.gathered_grad = None
|
| 152 |
|
|
|
|
| 156 |
else:
|
| 157 |
scatter_list = None
|
| 158 |
|
|
|
|
| 159 |
torch.distributed.scatter(
|
| 160 |
+
state.scattered_u,
|
| 161 |
scatter_list=scatter_list,
|
| 162 |
src=state.worker_rank,
|
| 163 |
group=state.process_group,
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
state.scatter_event = torch.cuda.Event()
|
| 166 |
state.scatter_event.record()
|
| 167 |
+
scatter_list = None
|
| 168 |
|
| 169 |
|
| 170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 177 |
if state.scatter_event is None:
|
| 178 |
raise RuntimeError("Scatter event must be set before update")
|
| 179 |
compute_stream.wait_event(state.scatter_event)
|
| 180 |
+
u_dtensor = DTensor.from_local(
|
| 181 |
+
state.scattered_u,
|
| 182 |
+
placements=p.placements,
|
| 183 |
+
device_mesh=p.device_mesh,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
state.scattered_u = u_dtensor
|
| 187 |
+
|
| 188 |
if rank == state.worker_rank:
|
| 189 |
# Free computed_u
|
| 190 |
state.computed_u = None
|
| 191 |
|
| 192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
| 193 |
+
state.scattered_u = None
|
| 194 |
+
u_dtensor = None
|
| 195 |
|
| 196 |
|
| 197 |
def default_is_muon(name, x):
|
|
|
|
| 390 |
else:
|
| 391 |
g = buf
|
| 392 |
|
| 393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
| 394 |
+
steps=group["ns_steps"])
|
| 395 |
|
| 396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 449 |
def enqueue_computes(start_idx, chunk_size):
|
| 450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 451 |
state = param_to_state[id(p)]
|
| 452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 453 |
self.compute_stream)
|
| 454 |
|
| 455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
|
| 482 |
# Wait for the last update_param to finish
|
| 483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 484 |
|
| 485 |
+
@staticmethod
|
| 486 |
+
def _fused_adamw(
|
| 487 |
+
params: list[torch.Tensor],
|
| 488 |
+
grads: list[torch.Tensor],
|
| 489 |
+
exp_avgs: list[torch.Tensor],
|
| 490 |
+
exp_avg_sqs: list[torch.Tensor],
|
| 491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
| 492 |
+
state_steps: list[torch.Tensor],
|
| 493 |
+
amsgrad: bool,
|
| 494 |
+
beta1: float,
|
| 495 |
+
beta2: float,
|
| 496 |
+
lr: Union[float, torch.Tensor],
|
| 497 |
+
weight_decay: float,
|
| 498 |
+
eps: float,
|
| 499 |
+
maximize: bool,
|
| 500 |
+
) -> None:
|
| 501 |
+
if not params:
|
| 502 |
+
return
|
| 503 |
+
|
| 504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
| 505 |
+
# treating it as a scalar.
|
| 506 |
+
lr_dict: Optional[DeviceDict] = ({
|
| 507 |
+
lr.device: lr
|
| 508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
| 509 |
+
None)
|
| 510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
| 511 |
+
[
|
| 512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
| 513 |
+
state_steps
|
| 514 |
+
] # type: ignore[list-item]
|
| 515 |
+
)
|
| 516 |
+
for (device, _), (
|
| 517 |
+
(
|
| 518 |
+
device_params_,
|
| 519 |
+
device_grads_,
|
| 520 |
+
device_exp_avgs_,
|
| 521 |
+
device_exp_avg_sqs_,
|
| 522 |
+
device_max_exp_avg_sqs,
|
| 523 |
+
device_state_steps_,
|
| 524 |
+
),
|
| 525 |
+
_,
|
| 526 |
+
) in grouped_tensors.items():
|
| 527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
| 528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
| 529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
| 530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
| 531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
| 532 |
+
|
| 533 |
+
if lr_dict is not None and device not in lr_dict:
|
| 534 |
+
lr_dict[device] = lr.to(
|
| 535 |
+
device=device,
|
| 536 |
+
non_blocking=True) # type: ignore[union-attr]
|
| 537 |
+
lr = lr_dict[device]
|
| 538 |
+
torch._foreach_add_(device_state_steps, 1)
|
| 539 |
+
func = torch._fused_adamw_
|
| 540 |
+
func(
|
| 541 |
+
device_params,
|
| 542 |
+
device_grads,
|
| 543 |
+
device_exp_avgs,
|
| 544 |
+
device_exp_avg_sqs,
|
| 545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
| 546 |
+
device_state_steps,
|
| 547 |
+
amsgrad=amsgrad,
|
| 548 |
+
lr=lr, # type: ignore[arg-type]
|
| 549 |
+
beta1=beta1,
|
| 550 |
+
beta2=beta2,
|
| 551 |
+
weight_decay=weight_decay,
|
| 552 |
+
eps=eps,
|
| 553 |
+
maximize=maximize,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
def step(self, closure=None):
|
| 557 |
"""Perform a single optimization step.
|
| 558 |
|
|
|
|
| 629 |
# AdamW backup #
|
| 630 |
############################
|
| 631 |
|
| 632 |
+
params_with_grads = []
|
| 633 |
+
grads = []
|
| 634 |
+
moment1 = []
|
| 635 |
+
moment2 = []
|
| 636 |
+
max_exp_avg_sqs = []
|
| 637 |
+
state_steps = []
|
| 638 |
lr = group["lr"]
|
| 639 |
beta1, beta2 = group["adamw_betas"]
|
| 640 |
eps = group["adamw_eps"]
|
|
|
|
| 645 |
if g is None:
|
| 646 |
continue
|
| 647 |
state = self.state[p]
|
| 648 |
+
params_with_grads.append(p)
|
| 649 |
+
grads.append(g)
|
| 650 |
if "step" not in state:
|
| 651 |
+
state["step"] = (torch.zeros((),
|
| 652 |
+
dtype=torch.float32,
|
| 653 |
+
device=p.device))
|
| 654 |
state["moment1"] = torch.zeros_like(g)
|
| 655 |
state["moment2"] = torch.zeros_like(g)
|
| 656 |
+
moment1.append(state["moment1"])
|
| 657 |
+
moment2.append(state["moment2"])
|
| 658 |
+
if not isinstance(state["step"], torch.Tensor):
|
| 659 |
+
step_tensor = torch.tensor(state["step"],
|
| 660 |
+
dtype=torch.float32,
|
| 661 |
+
device=p.device)
|
| 662 |
+
else:
|
| 663 |
+
step_tensor = state["step"]
|
| 664 |
+
state_steps.append(step_tensor)
|
| 665 |
+
|
| 666 |
+
self._fused_adamw(
|
| 667 |
+
params_with_grads,
|
| 668 |
+
grads,
|
| 669 |
+
moment1,
|
| 670 |
+
moment2,
|
| 671 |
+
max_exp_avg_sqs,
|
| 672 |
+
state_steps,
|
| 673 |
+
amsgrad=False,
|
| 674 |
+
beta1=beta1,
|
| 675 |
+
beta2=beta2,
|
| 676 |
+
lr=lr,
|
| 677 |
+
weight_decay=weight_decay,
|
| 678 |
+
eps=eps,
|
| 679 |
+
maximize=False,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
return loss
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
|
|
| 15 |
@torch.no_grad()
|
| 16 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 17 |
"""
|
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 25 |
"""
|
| 26 |
assert len(G.shape) == 2
|
| 27 |
-
|
| 28 |
X = G # no manual typecast
|
|
|
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
# Ensure spectral norm is at most 1
|
| 32 |
X = X / (X.norm() + 1e-7)
|
| 33 |
-
X = X.bfloat16()
|
| 34 |
# Perform the NS iterations
|
| 35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
A = X @ X.T
|
| 37 |
# B = (
|
| 38 |
# b * A + c * A @ A
|
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 43 |
|
| 44 |
if G.size(0) > G.size(1):
|
| 45 |
X = X.T
|
| 46 |
-
return X
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 65 |
Gather the gradients to worker_rank.
|
| 66 |
If none_grad is True, free p.grad after the gather.
|
| 67 |
"""
|
| 68 |
-
g = p.grad
|
| 69 |
-
|
| 70 |
-
if rank == state.worker_rank:
|
| 71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
| 72 |
-
gather_list = [
|
| 73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 74 |
-
]
|
| 75 |
-
else:
|
| 76 |
-
gather_list = None
|
| 77 |
-
|
| 78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
torch.distributed.gather(
|
| 80 |
g.to_local(),
|
| 81 |
dst=state.worker_rank,
|
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 92 |
else:
|
| 93 |
state.gathered_grad = None
|
| 94 |
state.gather_event = None
|
|
|
|
| 95 |
if none_grad:
|
| 96 |
# We can safely free p.grad without calling record_stream:
|
| 97 |
# p.grad.to_local().record_stream(comm_stream)
|
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
| 108 |
"""
|
| 109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 110 |
"""
|
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 115 |
compute_stream.wait_event(state.gather_event)
|
| 116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 117 |
state.computed_u = u
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
|
| 125 |
@torch.no_grad()
|
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if rank == state.worker_rank:
|
| 133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 134 |
-
if state.compute_event is None:
|
| 135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
| 136 |
-
comm_stream.wait_event(state.compute_event)
|
| 137 |
-
|
| 138 |
# Clear the gathered gradient to free memory
|
| 139 |
state.gathered_grad = None
|
| 140 |
|
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 144 |
else:
|
| 145 |
scatter_list = None
|
| 146 |
|
| 147 |
-
u_received = torch.empty_like(p.to_local())
|
| 148 |
torch.distributed.scatter(
|
| 149 |
-
|
| 150 |
scatter_list=scatter_list,
|
| 151 |
src=state.worker_rank,
|
| 152 |
group=state.process_group,
|
| 153 |
)
|
| 154 |
-
u_dtensor = DTensor.from_local(
|
| 155 |
-
u_received,
|
| 156 |
-
placements=p.placements,
|
| 157 |
-
device_mesh=p.device_mesh,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
state.scattered_u = u_dtensor
|
| 161 |
state.scatter_event = torch.cuda.Event()
|
| 162 |
state.scatter_event.record()
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
| 172 |
if state.scatter_event is None:
|
| 173 |
raise RuntimeError("Scatter event must be set before update")
|
| 174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if rank == state.worker_rank:
|
| 176 |
# Free computed_u
|
| 177 |
state.computed_u = None
|
| 178 |
|
| 179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
def default_is_muon(name, x):
|
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 375 |
else:
|
| 376 |
g = buf
|
| 377 |
|
| 378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
|
| 379 |
|
| 380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 433 |
def enqueue_computes(start_idx, chunk_size):
|
| 434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 435 |
state = param_to_state[id(p)]
|
| 436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
| 437 |
self.compute_stream)
|
| 438 |
|
| 439 |
def enqueue_scatters(start_idx, chunk_size):
|
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
| 466 |
# Wait for the last update_param to finish
|
| 467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
def step(self, closure=None):
|
| 470 |
"""Perform a single optimization step.
|
| 471 |
|
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 542 |
# AdamW backup #
|
| 543 |
############################
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
lr = group["lr"]
|
| 546 |
beta1, beta2 = group["adamw_betas"]
|
| 547 |
eps = group["adamw_eps"]
|
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
| 552 |
if g is None:
|
| 553 |
continue
|
| 554 |
state = self.state[p]
|
|
|
|
|
|
|
| 555 |
if "step" not in state:
|
| 556 |
-
state["step"] =
|
|
|
|
|
|
|
| 557 |
state["moment1"] = torch.zeros_like(g)
|
| 558 |
state["moment2"] = torch.zeros_like(g)
|
| 559 |
-
state["
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
return loss
|
|
|
|
| 2 |
import math
|
| 3 |
import types
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Union, cast
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.distributed as dist
|
|
|
|
| 13 |
|
| 14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
| 16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
|
|
|
| 27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
+
assert G.dtype == torch.bfloat16
|
| 31 |
X = G # no manual typecast
|
| 32 |
+
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
+
for a, b, c in [
|
| 39 |
+
(4.0848, -6.8946, 2.9270),
|
| 40 |
+
(3.9505, -6.3029, 2.6377),
|
| 41 |
+
(3.7418, -5.5913, 2.3037),
|
| 42 |
+
(2.8769, -3.1427, 1.2046),
|
| 43 |
+
(2.8366, -3.0525, 1.2012),
|
| 44 |
+
]:
|
| 45 |
A = X @ X.T
|
| 46 |
# B = (
|
| 47 |
# b * A + c * A @ A
|
|
|
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
| 55 |
+
return X
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
|
|
| 74 |
Gather the gradients to worker_rank.
|
| 75 |
If none_grad is True, free p.grad after the gather.
|
| 76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.cuda.stream(comm_stream):
|
| 78 |
+
g = p.grad
|
| 79 |
+
|
| 80 |
+
if rank == state.worker_rank:
|
| 81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 82 |
+
gather_list = [
|
| 83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
| 84 |
+
for _ in range(num_ranks)
|
| 85 |
+
]
|
| 86 |
+
else:
|
| 87 |
+
gather_list = None
|
| 88 |
+
|
| 89 |
+
g = g.to(torch.bfloat16)
|
| 90 |
torch.distributed.gather(
|
| 91 |
g.to_local(),
|
| 92 |
dst=state.worker_rank,
|
|
|
|
| 103 |
else:
|
| 104 |
state.gathered_grad = None
|
| 105 |
state.gather_event = None
|
| 106 |
+
gather_list = None
|
| 107 |
if none_grad:
|
| 108 |
# We can safely free p.grad without calling record_stream:
|
| 109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
@torch.no_grad()
|
| 119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
| 120 |
"""
|
| 121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
| 122 |
"""
|
|
|
|
| 127 |
compute_stream.wait_event(state.gather_event)
|
| 128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 129 |
state.computed_u = u
|
| 130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 131 |
+
dtype=torch.bfloat16)
|
| 132 |
+
state.compute_event = torch.cuda.Event()
|
| 133 |
+
state.compute_event.record()
|
| 134 |
+
u = None
|
| 135 |
|
| 136 |
|
| 137 |
@torch.no_grad()
|
|
|
|
| 141 |
"""
|
| 142 |
|
| 143 |
with torch.cuda.stream(comm_stream):
|
| 144 |
+
if state.compute_event is None:
|
| 145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
| 146 |
+
comm_stream.wait_event(state.compute_event)
|
| 147 |
+
|
| 148 |
if rank == state.worker_rank:
|
| 149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Clear the gathered gradient to free memory
|
| 151 |
state.gathered_grad = None
|
| 152 |
|
|
|
|
| 156 |
else:
|
| 157 |
scatter_list = None
|
| 158 |
|
|
|
|
| 159 |
torch.distributed.scatter(
|
| 160 |
+
state.scattered_u,
|
| 161 |
scatter_list=scatter_list,
|
| 162 |
src=state.worker_rank,
|
| 163 |
group=state.process_group,
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
state.scatter_event = torch.cuda.Event()
|
| 166 |
state.scatter_event.record()
|
| 167 |
+
scatter_list = None
|
| 168 |
|
| 169 |
|
| 170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 177 |
if state.scatter_event is None:
|
| 178 |
raise RuntimeError("Scatter event must be set before update")
|
| 179 |
compute_stream.wait_event(state.scatter_event)
|
| 180 |
+
u_dtensor = DTensor.from_local(
|
| 181 |
+
state.scattered_u,
|
| 182 |
+
placements=p.placements,
|
| 183 |
+
device_mesh=p.device_mesh,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
state.scattered_u = u_dtensor
|
| 187 |
+
|
| 188 |
if rank == state.worker_rank:
|
| 189 |
# Free computed_u
|
| 190 |
state.computed_u = None
|
| 191 |
|
| 192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
| 193 |
+
state.scattered_u = None
|
| 194 |
+
u_dtensor = None
|
| 195 |
|
| 196 |
|
| 197 |
def default_is_muon(name, x):
|
|
|
|
| 390 |
else:
|
| 391 |
g = buf
|
| 392 |
|
| 393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
| 394 |
+
steps=group["ns_steps"])
|
| 395 |
|
| 396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
|
| 449 |
def enqueue_computes(start_idx, chunk_size):
|
| 450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 451 |
state = param_to_state[id(p)]
|
| 452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 453 |
self.compute_stream)
|
| 454 |
|
| 455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
|
| 482 |
# Wait for the last update_param to finish
|
| 483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
| 484 |
|
| 485 |
+
@staticmethod
|
| 486 |
+
def _fused_adamw(
|
| 487 |
+
params: list[torch.Tensor],
|
| 488 |
+
grads: list[torch.Tensor],
|
| 489 |
+
exp_avgs: list[torch.Tensor],
|
| 490 |
+
exp_avg_sqs: list[torch.Tensor],
|
| 491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
| 492 |
+
state_steps: list[torch.Tensor],
|
| 493 |
+
amsgrad: bool,
|
| 494 |
+
beta1: float,
|
| 495 |
+
beta2: float,
|
| 496 |
+
lr: Union[float, torch.Tensor],
|
| 497 |
+
weight_decay: float,
|
| 498 |
+
eps: float,
|
| 499 |
+
maximize: bool,
|
| 500 |
+
) -> None:
|
| 501 |
+
if not params:
|
| 502 |
+
return
|
| 503 |
+
|
| 504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
| 505 |
+
# treating it as a scalar.
|
| 506 |
+
lr_dict: Optional[DeviceDict] = ({
|
| 507 |
+
lr.device: lr
|
| 508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
| 509 |
+
None)
|
| 510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
| 511 |
+
[
|
| 512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
| 513 |
+
state_steps
|
| 514 |
+
] # type: ignore[list-item]
|
| 515 |
+
)
|
| 516 |
+
for (device, _), (
|
| 517 |
+
(
|
| 518 |
+
device_params_,
|
| 519 |
+
device_grads_,
|
| 520 |
+
device_exp_avgs_,
|
| 521 |
+
device_exp_avg_sqs_,
|
| 522 |
+
device_max_exp_avg_sqs,
|
| 523 |
+
device_state_steps_,
|
| 524 |
+
),
|
| 525 |
+
_,
|
| 526 |
+
) in grouped_tensors.items():
|
| 527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
| 528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
| 529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
| 530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
| 531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
| 532 |
+
|
| 533 |
+
if lr_dict is not None and device not in lr_dict:
|
| 534 |
+
lr_dict[device] = lr.to(
|
| 535 |
+
device=device,
|
| 536 |
+
non_blocking=True) # type: ignore[union-attr]
|
| 537 |
+
lr = lr_dict[device]
|
| 538 |
+
torch._foreach_add_(device_state_steps, 1)
|
| 539 |
+
func = torch._fused_adamw_
|
| 540 |
+
func(
|
| 541 |
+
device_params,
|
| 542 |
+
device_grads,
|
| 543 |
+
device_exp_avgs,
|
| 544 |
+
device_exp_avg_sqs,
|
| 545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
| 546 |
+
device_state_steps,
|
| 547 |
+
amsgrad=amsgrad,
|
| 548 |
+
lr=lr, # type: ignore[arg-type]
|
| 549 |
+
beta1=beta1,
|
| 550 |
+
beta2=beta2,
|
| 551 |
+
weight_decay=weight_decay,
|
| 552 |
+
eps=eps,
|
| 553 |
+
maximize=maximize,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
def step(self, closure=None):
|
| 557 |
"""Perform a single optimization step.
|
| 558 |
|
|
|
|
| 629 |
# AdamW backup #
|
| 630 |
############################
|
| 631 |
|
| 632 |
+
params_with_grads = []
|
| 633 |
+
grads = []
|
| 634 |
+
moment1 = []
|
| 635 |
+
moment2 = []
|
| 636 |
+
max_exp_avg_sqs = []
|
| 637 |
+
state_steps = []
|
| 638 |
lr = group["lr"]
|
| 639 |
beta1, beta2 = group["adamw_betas"]
|
| 640 |
eps = group["adamw_eps"]
|
|
|
|
| 645 |
if g is None:
|
| 646 |
continue
|
| 647 |
state = self.state[p]
|
| 648 |
+
params_with_grads.append(p)
|
| 649 |
+
grads.append(g)
|
| 650 |
if "step" not in state:
|
| 651 |
+
state["step"] = (torch.zeros((),
|
| 652 |
+
dtype=torch.float32,
|
| 653 |
+
device=p.device))
|
| 654 |
state["moment1"] = torch.zeros_like(g)
|
| 655 |
state["moment2"] = torch.zeros_like(g)
|
| 656 |
+
moment1.append(state["moment1"])
|
| 657 |
+
moment2.append(state["moment2"])
|
| 658 |
+
if not isinstance(state["step"], torch.Tensor):
|
| 659 |
+
step_tensor = torch.tensor(state["step"],
|
| 660 |
+
dtype=torch.float32,
|
| 661 |
+
device=p.device)
|
| 662 |
+
else:
|
| 663 |
+
step_tensor = state["step"]
|
| 664 |
+
state_steps.append(step_tensor)
|
| 665 |
+
|
| 666 |
+
self._fused_adamw(
|
| 667 |
+
params_with_grads,
|
| 668 |
+
grads,
|
| 669 |
+
moment1,
|
| 670 |
+
moment2,
|
| 671 |
+
max_exp_avg_sqs,
|
| 672 |
+
state_steps,
|
| 673 |
+
amsgrad=False,
|
| 674 |
+
beta1=beta1,
|
| 675 |
+
beta2=beta2,
|
| 676 |
+
lr=lr,
|
| 677 |
+
weight_decay=weight_decay,
|
| 678 |
+
eps=eps,
|
| 679 |
+
maximize=False,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
return loss
|