import math from dataclasses import dataclass import torch import torch.distributed as dist from torch.distributed._tensor import DTensor, Replicate, Shard # This code snippet is a modified version adapted from the following GitHub repositories: # https://github.com/KellerJordan/Muon/blob/master/muon.py @torch.no_grad() def _zeropower_via_newtonschulz5(G, steps): """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing the slope at zero even beyond the point where the iteration no longer converges all the way to one everywhere on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model performance at all relative to UV^T, where USV^T = G is the SVD. """ assert len(G.shape) == 2 a, b, c = (3.4445, -4.7750, 2.0315) X = G # no manual typecast if G.size(0) > G.size(1): X = X.T # Ensure spectral norm is at most 1 X = X / (X.norm() + 1e-7) X = X.bfloat16() # Perform the NS iterations for _ in range(steps): A = X @ X.T # B = ( # b * A + c * A @ A # ) B = torch.addmm(A, A, A, alpha=c, beta=b) # X = a * X + B @ X X = torch.addmm(X, B, X, alpha=1.0, beta=a) if G.size(0) > G.size(1): X = X.T return X.to(G.dtype) @dataclass class _muon_state: # TODO: use Optional worker_rank: int | None = None gathered_grad: torch.Tensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None process_group = None @torch.no_grad() def _gather(p, state, rank, comm_stream, none_grad): g = p.grad if rank == state.worker_rank: num_ranks = dist.get_world_size(group=state.process_group) gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)] else: gather_list = None with torch.cuda.stream(comm_stream): torch.distributed.gather( g.to_local(), dst=state.worker_rank, gather_list=gather_list, group=state.process_group, ) if rank == state.worker_rank: if state.gathered_grad is not None: raise RuntimeError( "Gather event already exists, which should not happen." ) state.gathered_grad = torch.cat(gather_list, dim=0) state.gather_event = torch.cuda.Event() state.gather_event.record() else: state.gathered_grad = None state.gather_event = None if none_grad: p.grad = None @torch.no_grad() def _compute_u(state, steps, rank, compute_stream): with torch.cuda.stream(compute_stream): if rank == state.worker_rank: if state.gather_event is None: raise RuntimeError("Gather event must be set before compute.") compute_stream.wait_event(state.gather_event) u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) state.computed_u = u state.compute_event = torch.cuda.Event() state.compute_event.record() # Clear the gathered gradient to free memory state.gathered_grad = None else: state.computed_u = None state.compute_event = None @torch.no_grad() def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream): u = state.computed_u with torch.cuda.stream(comm_stream): if rank == state.worker_rank: num_ranks = dist.get_world_size(group=state.process_group) if state.compute_event is None: raise RuntimeError("Compute event must be set before scatter.") comm_stream.wait_event(state.compute_event) scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0)) else: scatter_list = None u = torch.empty_like(p.to_local()) torch.distributed.scatter( u, scatter_list=scatter_list, src=state.worker_rank, group=state.process_group, ) if rank == state.worker_rank: # Clear u to free memory state.computed_u = None u = DTensor.from_local( u, placements=p.placements, device_mesh=p.device_mesh, ) p.data.mul_(1 - lr * weight_decay) p.data.add_(u, alpha=-adjusted_lr) def default_is_muon(x, name): return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- processing step, in which each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU. Some warnings: - We believe this optimizer is unlikely to work well for training with small batch size. - We believe it may not work well for finetuning pretrained models, but we haven't tested this. Arguments: muon_params: The parameters to be optimized by Muon. lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) momentum: The momentum used by the internal SGD. (0.95 is a good default) nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. adamw_lr: The learning rate for the internal AdamW. adamw_betas: The betas for the internal AdamW. adamw_eps: The epsilon for the internal AdamW. adamw_weight_decay: The weight decay for the internal AdamW. """ def __init__( self, model, is_muon_func=default_is_muon, lr=1e-3, momentum=0.95, nesterov=True, ns_steps=5, weight_decay=0.1, adamw_betas=(0.9, 0.95), adamw_eps=1e-8, none_grad=True, debug=False, ): defaults = dict( lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, adamw_betas=adamw_betas, adamw_eps=adamw_eps, none_grad=none_grad, ) super().__init__(model.parameters(), defaults) self.is_muon_func = is_muon_func self.model = model if dist.is_initialized(): self.rank = dist.get_rank() else: self.rank = None self.comm_stream = torch.cuda.Stream() self.compute_stream = torch.cuda.Stream() self.debug = debug def __setstate__(self, state): # Sort parameters into those for which we will use Muon, and those for which we will not super().__setstate__(state) self._init_state() def _init_state(self): for name, p in self.model.named_parameters(): if self.is_muon_func(p, name): # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer assert p.ndim == 2, p.ndim self.state[p]["use_muon"] = True else: # Do not use Muon for parameters in adamw_params self.state[p]["use_muon"] = False def _calc_flops(self, G, steps): assert len(G.shape) == 2 M, N = G.shape if M > N: M, N = N, M return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) def adjust_lr_for_muon(self, lr, param_shape): A, B = param_shape[:2] # We adjust the learning rate and weight decay based on the size of the parameter matrix # as describted in the paper adjusted_ratio = 0.2 * math.sqrt(max(A, B)) adjusted_lr = lr * adjusted_ratio return adjusted_lr def get_shard_mesh(self, p, rank): """ Get the shard mesh for a parameter p on the given rank. """ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters." if p.placements == (Shard(dim=0),): # Case for FSDP return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) elif p.placements == (Replicate(), Shard(dim=0)): # Case for HSDP for i, shard_mesh in enumerate(p.device_mesh.mesh): if rank in shard_mesh: return shard_mesh, p.device_mesh.get_group(mesh_dim=1) else: raise ValueError(f"Unsupported placements ({p.placements}).") def init_state_and_assign_params(self, params, group): param_to_state = {} param_to_flops = {} total_flops = 0 for p in params: g = p.grad if g is None: continue assert g.ndim == 2, "Muon only supports 2D parameters." flops = self._calc_flops(g, group["ns_steps"]) param_to_flops[id(p)] = flops total_flops += flops if self.debug: print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) ordered_params = sorted( params, key=lambda p: param_to_flops[id(p)], reverse=True ) round_robin = 0 mesh = None shard_mesh = None process_group = None for p in ordered_params: if mesh is None: mesh = p.device_mesh shard_mesh, process_group = self.get_shard_mesh(p, self.rank) elif mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") param_to_state[id(p)] = _muon_state() param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item() param_to_state[id(p)].process_group = process_group round_robin = (round_robin + 1) % len(shard_mesh) return param_to_state, ordered_params def base(self, params, group, lr, weight_decay, momentum): # generate weight updates in distributed fashion for p in params: g = p.grad if g is None: continue if g.ndim > 2: g = g.view(g.size(0), -1) assert g is not None # calc update state = self.state[p] if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(g) buf = state["momentum_buffer"] buf.mul_(momentum).add_(g) if group["nesterov"]: g = g.add(buf, alpha=momentum) else: g = buf u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) # scale update adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) # apply weight decay p.data.mul_(1 - lr * weight_decay) # apply update p.data.add_(u, alpha=-adjusted_lr) def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(g) buf = state["momentum_buffer"] buf.mul_(momentum).add_(g) if group["nesterov"]: g = g.add(buf, alpha=momentum) else: g = buf return g def _update_p(self, p, u, lr, weight_decay): # scale update adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) # apply weight decay p.data.mul_(1 - lr * weight_decay) # apply update p.data.add_(u, alpha=-adjusted_lr) def parallel(self, params, group, lr, weight_decay, momentum): """ Perform a parallel optimization step using Muon. """ for p in params: g = p.grad if g is None: continue if g.ndim > 2: g = g.view(g.size(0), -1) # Update g in the local rank g = self._update_g( p, g, group, momentum=momentum, ) p.grad = g param_to_state, ordered_params = self.init_state_and_assign_params( params, group ) def enqueue_gathers(start_idx, chunk_size): for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] _gather(p, state, self.rank, self.comm_stream, group["none_grad"]) def enqueue_computes(start_idx, chunk_size): for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) def enqueue_scatters(start_idx, chunk_size): for p in ordered_params[start_idx : start_idx + chunk_size]: state = param_to_state[id(p)] adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) _scatter( p, state, lr, adjusted_lr, weight_decay, self.rank, self.comm_stream ) chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group) # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) enqueue_gathers(0, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_computes(i, chunk_size) enqueue_gathers(i + chunk_size, chunk_size) enqueue_scatters(i, chunk_size) torch.cuda.current_stream().wait_stream(self.comm_stream) def step(self, closure=None): """Perform a single optimization step. Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: ############################ # Muon # ############################ if "use_muon" not in self.state[group["params"][0]]: self._init_state() params = [p for p in group["params"] if self.state[p]["use_muon"]] lr = group["lr"] weight_decay = group["weight_decay"] momentum = group["momentum"] param_dtensors = [] param_tensors = [] for p in params: if p is None or p.grad is None: continue if isinstance(p.data, DTensor): if all( isinstance(placement, Replicate) for placement in p.placements ): param_tensors.append(p) else: param_dtensors.append(p) elif isinstance(p.data, torch.Tensor): param_tensors.append(p) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") if self.debug: print( f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", flush=True, ) if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( "Parallel Muon requires torch.distributed to be initialized." ) self.parallel( param_dtensors, group, lr=lr, weight_decay=weight_decay, momentum=momentum, ) if len(param_tensors) > 0: self.base( param_tensors, group, lr=lr, weight_decay=weight_decay, momentum=momentum, ) ############################ # AdamW backup # ############################ params = [p for p in group["params"] if not self.state[p]["use_muon"]] lr = group["lr"] beta1, beta2 = group["adamw_betas"] eps = group["adamw_eps"] weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = self.state[p] if "step" not in state: state["step"] = 0 state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) state["step"] += 1 step = state["step"] buf1 = state["moment1"] buf2 = state["moment2"] buf1.lerp_(g, 1 - beta1) buf2.lerp_(g.square(), 1 - beta2) g = buf1 / (eps + buf2.sqrt()) bias_correction1 = 1 - beta1**step bias_correction2 = 1 - beta2**step scale = bias_correction1 / bias_correction2**0.5 p.data.mul_(1 - lr * weight_decay) p.data.add_(g, alpha=-lr / scale) return loss