File size: 14,899 Bytes
bec1e88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
from collections import namedtuple
from datetime import datetime
from typing import Any
import torch
from torch.utils.tensorboard import SummaryWriter
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.tools import utils
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import Color, device_module, device_type
# Immutable record used to hand peak device-memory statistics to the loggers.
# Field order: peak active memory (GiB, % of capacity), peak reserved memory
# (GiB, % of capacity), allocation-retry count, OOM count.
DeviceMemStats = namedtuple(
    "DeviceMemStats",
    "max_active_gib max_active_pct "
    "max_reserved_gib max_reserved_pct "
    "num_alloc_retries num_ooms",
)
class DeviceMemoryMonitor:
    """Reports peak memory usage of an accelerator device.

    On construction the monitor records the device's name and total memory
    (via ``device_module``), then resets the peak-memory counters and empties
    the allocator cache so later readings start from a clean slate.

    Args:
        device: Device string passed to ``torch.device``. Defaults to index 0
            of the module-level ``device_type``.
    """

    def __init__(self, device: str = f"{device_type}:0"):
        self.device = torch.device(device)  # device object
        self.device_name = device_module.get_device_name(self.device)
        # Index of whichever device is currently selected in device_module.
        self.device_index = device_module.current_device()

        properties = device_module.get_device_properties(self.device)
        self.device_capacity = properties.total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)

        device_module.reset_peak_memory_stats()
        device_module.empty_cache()

    def _to_gib(self, memory_in_bytes):
        """Convert a byte count to GiB (1 GiB = 1024**3 bytes, not 1000**3)."""
        return memory_in_bytes / (1024 * 1024 * 1024)

    def _to_pct(self, memory):
        """Express ``memory`` (bytes) as a percentage of the device capacity."""
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        """Read the peak memory counters and return them as ``DeviceMemStats``.

        Any counter missing from ``device_module.memory_stats`` is treated as
        ``-1``. A warning is logged when allocation retries or OOM errors have
        been recorded.

        Returns:
            DeviceMemStats: peak active/reserved memory in GiB and as a
            percentage of capacity, plus the retry and OOM counts.
        """
        stats = device_module.memory_stats(self.device)

        peak_active = stats.get("active_bytes.all.peak", -1)
        peak_active_gib = self._to_gib(peak_active)
        peak_active_pct = self._to_pct(peak_active)

        peak_reserved = stats.get("reserved_bytes.all.peak", -1)
        peak_reserved_gib = self._to_gib(peak_reserved)
        peak_reserved_pct = self._to_pct(peak_reserved)

        alloc_retries = stats.get("num_alloc_retries", -1)
        oom_count = stats.get("num_ooms", -1)

        if alloc_retries > 0:
            logger.warning(
                f"{alloc_retries} {device_type.upper()} memory allocation retries."
            )
        if oom_count > 0:
            logger.warning(f"{oom_count} {device_type.upper()} OOM errors thrown.")

        return DeviceMemStats(
            max_active_gib=peak_active_gib,
            max_active_pct=peak_active_pct,
            max_reserved_gib=peak_reserved_gib,
            max_reserved_pct=peak_reserved_pct,
            num_alloc_retries=alloc_retries,
            num_ooms=oom_count,
        )

    def reset_peak_stats(self):
        """Reset the peak-memory counters tracked by ``device_module``."""
        device_module.reset_peak_memory_stats()
def build_device_memory_monitor():
    """Create a ``DeviceMemoryMonitor`` for ``device_type`` and log its capacity.

    Returns:
        DeviceMemoryMonitor: the newly created monitor.
    """
    monitor = DeviceMemoryMonitor(device_type)
    capacity_message = (
        f"{device_type.upper()} capacity: {monitor.device_name} "
        f"with {monitor.device_capacity_gib:.2f}GiB memory"
    )
    logger.info(capacity_message)
    return monitor
class BaseLogger:
    """No-op metric logger, used when metric logging is disabled.

    It also serves as the base class of the concrete loggers in this module,
    which override ``log`` and ``close``.
    """

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Accept ``metrics`` for ``step`` and discard them."""
        return None

    def close(self) -> None:
        """Do nothing; there are no resources to release."""
        return None
class TensorBoardLogger(BaseLogger):
    """Metric logger that writes scalars through a TensorBoard ``SummaryWriter``.

    Args:
        log_dir: Directory handed to ``SummaryWriter``.
        tag: Optional prefix; when set, every metric is written as
            ``"<tag>/<metric name>"``.
    """

    def __init__(self, log_dir: str, tag: str | None = None):
        self.tag = tag
        self.writer = SummaryWriter(log_dir, max_queue=1000)
        logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Write each entry of ``metrics`` as a scalar at ``step``."""
        for name, value in metrics.items():
            if self.tag is None:
                scalar_name = name
            else:
                scalar_name = f"{self.tag}/{name}"
            self.writer.add_scalar(scalar_name, value, step)

    def close(self) -> None:
        """Close the underlying ``SummaryWriter``."""
        self.writer.close()
class WandBLogger(BaseLogger):
    """Metric logger that reports to Weights & Biases.

    Args:
        log_dir: Directory created if missing and passed to ``wandb.init``
            as ``dir``.
        tag: Optional prefix; when set, every metric is logged as
            ``"<tag>/<metric name>"``.
    """

    def __init__(self, log_dir: str, tag: str | None = None):
        # Imported lazily so wandb is not loaded when this module is imported.
        import wandb

        self.wandb = wandb
        self.tag = tag

        # The logging directory must exist before wandb is initialised with it.
        os.makedirs(log_dir, exist_ok=True)

        # Project name comes from WANDB_PROJECT, falling back to "torchtitan".
        project_name = os.getenv("WANDB_PROJECT", "torchtitan")
        self.wandb.init(project=project_name, dir=log_dir)
        logger.info("WandB logging enabled")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Send ``metrics`` to wandb at ``step``, applying the tag prefix if set."""
        payload = {}
        for name, value in metrics.items():
            key = name if self.tag is None else f"{self.tag}/{name}"
            payload[key] = value
        self.wandb.log(payload, step=step)

    def close(self) -> None:
        """Finish the wandb run if one is active."""
        if self.wandb.run is None:
            return
        self.wandb.finish()
def ensure_pp_loss_visible(
    parallel_dims: ParallelDims, job_config: JobConfig, color: Color
) -> None:
    """Warn when the rank that reports the pipeline-parallel loss is not logged.

    For pipeline-parallel training the loss is only visible on the last
    pipeline stage. This function checks whether the first rank of that stage
    is listed in the comma-separated ``LOG_RANK`` environment variable and logs
    a warning if it is not. The ``ZBVZeroBubble`` schedule is skipped because
    V-block schedules return the loss on rank 0.

    Whitespace around ``LOG_RANK`` entries is ignored, so ``"0, 4"`` is treated
    the same as ``"0,4"``; empty entries are dropped.

    Args:
        parallel_dims: Provides ``world_size`` and the pipeline-parallel
            degree ``pp``.
        job_config: Provides ``parallelism.pipeline_parallel_schedule``.
        color: Color codes (``red``, ``yellow``, ``reset``) used in the
            warning message.
    """
    # V Block Schedules return loss on rank 0
    if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
        return

    # Calculate the rank where loss is visible (first rank of the last pipeline stage)
    world_size = parallel_dims.world_size
    pp_size = parallel_dims.pp
    loss_visible_rank = (world_size // pp_size) * (pp_size - 1)

    # Normalise LOG_RANK: trim each entry and drop blanks so that values such
    # as "0, 4" or a trailing comma do not cause a spurious warning.
    raw_entries = os.environ.get("LOG_RANK", "").split(",")
    env_logged_ranks = {entry.strip() for entry in raw_entries if entry.strip()}

    if str(loss_visible_rank) not in env_logged_ranks:
        logger.warning(
            f"{color.red}Pipeline Parallel loss is not visible. "
            f"Please add {color.yellow}rank {loss_visible_rank}{color.red} "
            f"to LOG_RANK environment variable in run_train.sh.{color.reset}"
        )
def _get_metrics_rank(
    parallel_dims: ParallelDims,
    job_config: JobConfig,
) -> int:
    """Return the rank that is responsible for logging metrics.

    Args:
        parallel_dims: Provides ``pp_enabled``, ``world_size`` and ``pp``.
        job_config: Provides ``parallelism.pipeline_parallel_schedule``.

    Returns:
        int: ``0`` when pipeline parallelism is disabled or the schedule is
        ``ZBVZeroBubble`` (V-block schedules return the loss on rank 0);
        otherwise the first rank of the last pipeline stage.
    """
    logs_on_rank_zero = (
        not parallel_dims.pp_enabled
        or job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble"
    )
    if logs_on_rank_zero:
        return 0

    # First rank of the last pipeline stage.
    total_ranks = parallel_dims.world_size
    num_stages = parallel_dims.pp
    ranks_per_stage = total_ranks // num_stages
    return ranks_per_stage * (num_stages - 1)
def _build_metric_logger(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> BaseLogger:
    """Pick and construct the metric logger for this rank.

    A no-op ``BaseLogger`` is returned when no backend is enabled, or when
    ``save_for_all_ranks`` is off and this rank is not the metrics rank.
    Otherwise WandB is tried first (if enabled); if it cannot be created the
    error is logged and TensorBoard is used when enabled.

    Args:
        job_config: Job configuration; ``metrics`` and ``job.dump_folder``
            are read.
        parallel_dims: Used to determine the metrics rank.
        tag: Optional prefix forwarded to the created logger.

    Returns:
        BaseLogger: a ``WandBLogger``, a ``TensorBoardLogger`` or a plain
        ``BaseLogger``.
    """
    cfg = job_config.metrics

    logger.debug(
        f"Building logger with config: wandb={cfg.enable_wandb}, "
        f"tensorboard={cfg.enable_tensorboard}"
    )

    has_logging_enabled = cfg.enable_tensorboard or cfg.enable_wandb

    # Unless every rank saves metrics, only the designated metrics rank logs.
    should_log = has_logging_enabled
    if (not cfg.save_for_all_ranks) and should_log:
        target_rank = _get_metrics_rank(parallel_dims, job_config)
        should_log = torch.distributed.get_rank() == target_rank

    logger.debug(
        f"Logging decision: has_logging_enabled={has_logging_enabled}, should_log={should_log}"
    )

    if not should_log:
        logger.debug("Returning BaseLogger due to should_log=False")
        return BaseLogger()

    # <dump_folder>/<save_tb_folder>/<timestamp>[/rank_<rank>]
    log_dir = os.path.join(
        job_config.job.dump_folder,
        cfg.save_tb_folder,
        datetime.now().strftime("%Y%m%d-%H%M"),
    )
    if cfg.save_for_all_ranks:
        rank_subdir = f"rank_{torch.distributed.get_rank()}"
        log_dir = os.path.join(log_dir, rank_subdir)

    # WandB takes priority; a failure falls through to TensorBoard below.
    if cfg.enable_wandb:
        logger.debug("Attempting to create WandB logger")
        try:
            return WandBLogger(log_dir, tag)
        except Exception as err:
            if "No module named 'wandb'" in str(err):
                message = "Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'."
            else:
                message = f"Failed to create WandB logger: {err}"
            logger.error(message)

    if cfg.enable_tensorboard:
        logger.debug("Creating TensorBoard logger")
        return TensorBoardLogger(log_dir, tag)

    logger.debug("No loggers enabled, returning BaseLogger")
    return BaseLogger()
class MetricsProcessor:
    """Compute training metrics and report them.

    Each call to ``log`` sends throughput, MFU, timing, loss and device-memory
    metrics to the configured metric logger (TensorBoard, WandB or a no-op
    logger) and writes a condensed summary line to the console logger.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.
    """

    logger: BaseLogger
    parallel_dims: ParallelDims
    job_config: JobConfig
    device_memory_monitor: DeviceMemoryMonitor
    color: utils.NoColor | utils.Color

    gpu_peak_flops: int
    # Accumulators that ``log`` resets after every report; presumably filled
    # in by the training loop between calls.
    ntokens_since_last_log: int
    data_loading_times: list[float]
    time_last_log: float

    num_flops_per_token: int
    optimizers: OptimizersContainer | None
    lr_schedulers: LRSchedulersContainer | None

    def __init__(
        self,
        job_config: JobConfig,
        parallel_dims: ParallelDims,
        tag: str | None = None,
    ):
        self.logger = _build_metric_logger(job_config, parallel_dims, tag)
        self.parallel_dims = parallel_dims
        self.job_config = job_config
        self.device_memory_monitor = build_device_memory_monitor()
        # used for colorful printing
        self.color = (
            utils.NoColor()
            if job_config.metrics.disable_color_printing
            else utils.Color()
        )

        self.gpu_peak_flops = utils.get_peak_flops(
            self.device_memory_monitor.device_name
        )
        self.ntokens_since_last_log = 0
        self.data_loading_times = []
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

        # These variables have to be set later as they depend on other components or model.
        self.num_flops_per_token = -1
        self.optimizers = None
        self.lr_schedulers = None

    def should_log(self, step: int) -> bool:
        """Return True on step 1 and on every multiple of ``metrics.log_freq``."""
        return step == 1 or step % self.job_config.metrics.log_freq == 0

    def log(
        self,
        step: int,
        global_avg_loss: float,
        global_max_loss: float,
        extra_metrics: dict[str, Any] | None = None,
    ):
        """Report metrics for the interval since the previous call.

        Args:
            step: Current training step.
            global_avg_loss: Average loss, reported and shown on the console.
            global_max_loss: Maximum loss, reported to the metric logger.
            extra_metrics: Optional additional metrics merged into the report.
                Entries whose key contains ``"loss"`` are also appended to the
                console line, with a leading ``"loss_metrics/"`` prefix removed.

        Raises:
            AssertionError: If ``num_flops_per_token`` has not been set to a
                positive value.
        """
        assert self.num_flops_per_token > 0, "num_flops_per_token must be set"

        time_delta = time.perf_counter() - self.time_last_log

        # tokens per second per device, abbreviated as tps
        tps = self.ntokens_since_last_log / (
            time_delta * self.parallel_dims.non_data_parallel_size
        )
        # model FLOPS utilization
        # For its definition and calculation, please refer to the PaLM paper:
        # https://arxiv.org/abs/2204.02311
        mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
        tflops = self.num_flops_per_token * tps / 1e12

        time_end_to_end = time_delta / self.job_config.metrics.log_freq
        total_data_loading = sum(self.data_loading_times)
        num_data_loads = len(self.data_loading_times)
        # Report 0.0 rather than dividing by zero when nothing was recorded.
        time_data_loading = (
            total_data_loading / num_data_loads if num_data_loads else 0.0
        )
        time_data_loading_pct = 100 * total_data_loading / time_delta

        device_mem_stats = self.device_memory_monitor.get_peak_stats()

        metrics = {
            "loss_metrics/global_avg_loss": global_avg_loss,
            "loss_metrics/global_max_loss": global_max_loss,
            "throughput(tps)": tps,
            "tflops": tflops,
            "mfu(%)": mfu,
            "time_metrics/end_to_end(s)": time_end_to_end,
            "time_metrics/data_loading(s)": time_data_loading,
            "time_metrics/data_loading(%)": time_data_loading_pct,
            "memory/max_active(GiB)": device_mem_stats.max_active_gib,
            "memory/max_active(%)": device_mem_stats.max_active_pct,
            "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
            "memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
            "memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
            "memory/num_ooms": device_mem_stats.num_ooms,
        }
        if extra_metrics:
            metrics.update(extra_metrics)

        self.logger.log(metrics, step)

        color = self.color
        construct_string = (
            f"{color.red}step: {step:2}  "
            f"{color.green}loss: {global_avg_loss:7.4f}  "
            f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%)  "
            f"{color.blue}tps: {round(tps):,}  "
            f"{color.cyan}tflops: {tflops:,.2f}  "
            f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
        )

        if extra_metrics:
            for k, v in extra_metrics.items():
                if "loss" in k:
                    # removeprefix drops only the exact prefix; lstrip would
                    # strip any leading characters from that set and could
                    # eat the metric name itself.
                    short_name = k.removeprefix("loss_metrics/")
                    construct_string += f"  {color.white}{short_name}: {v:7.4f}"

        logger.info(construct_string)

        # Start a fresh measurement interval.
        self.ntokens_since_last_log = 0
        self.data_loading_times.clear()
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

    def close(self):
        """Close the underlying metric logger."""
        self.logger.close()
def build_metrics_processor(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> MetricsProcessor:
    """Construct a ``MetricsProcessor`` from the given configuration.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.

    Returns:
        MetricsProcessor: The newly created metrics processor.
    """
    processor = MetricsProcessor(job_config, parallel_dims, tag)
    return processor
|