diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c4cb545ce9c9bbbff6607aa2f92b4a57a0851402 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
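The hunk above adds `*.json` to the patterns stored via Git LFS (it is the line that `git lfs track "*.json"` would append), so every JSON file in this commit appears below only as a three-line LFS pointer: the spec version, a sha256 oid, and the blob size in bytes. A minimal sketch of parsing such a pointer — the helper name is illustrative, not part of this repository ::

    def read_lfs_pointer(path):
        # parse the three 'key value' lines of a Git LFS pointer file,
        # i.e. version / oid / size, as in the diffs below
        fields = {}
        with open(path) as fh:
            for line in fh:
                key, _, value = line.strip().partition(' ')
                fields[key] = value
        return fields  # e.g. {'version': '...', 'oid': 'sha256:...', 'size': '729'}
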
diff --git a/both/checkpoint-1000/config.json b/both/checkpoint-1000/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..da6f9ba3622f70e966728e9939da080ce8df3df3
--- /dev/null
+++ b/both/checkpoint-1000/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddcfa3c8f7da780e23fc2990aba55eef954aa0eb9e934a58c5536953e79a2b1c
+size 729
diff --git a/both/checkpoint-1000/generation_config.json b/both/checkpoint-1000/generation_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..e0796b3dbfe28ebffe364cefdfc7f7116df2fae4
--- /dev/null
+++ b/both/checkpoint-1000/generation_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f312a9b22d84af6aa4883b93869f2f480a4cdc78a54cf6e7d5b2069f019ff4c4
+size 194
diff --git a/both/checkpoint-1000/latest b/both/checkpoint-1000/latest
new file mode 100644
index 0000000000000000000000000000000000000000..e2d3435fb1acf8913e6bd6c51b01adfc69b11ac6
--- /dev/null
+++ b/both/checkpoint-1000/latest
@@ -0,0 +1 @@
+global_step1000
\ No newline at end of file
diff --git a/both/checkpoint-1000/model-00001-of-00004.safetensors b/both/checkpoint-1000/model-00001-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..84ae69a3576316d2db64275e793ffe261649525c
--- /dev/null
+++ b/both/checkpoint-1000/model-00001-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c88653d5f23979af6885177d157364dd12bc73ed4b31b60b9e26c9c0e5b5e19e
+size 4976698672
diff --git a/both/checkpoint-1000/model-00002-of-00004.safetensors b/both/checkpoint-1000/model-00002-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..2bf208ac3d07fb8782ecc18b137da1d9d04df072
--- /dev/null
+++ b/both/checkpoint-1000/model-00002-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1da3b07882bad2e10a869443d6d986ead40981e2b6e410225bd1e0766300ea41
+size 4999802720
diff --git a/both/checkpoint-1000/model-00003-of-00004.safetensors b/both/checkpoint-1000/model-00003-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..e272e2f06075f312e6c99e8f22d0b1cb1528f490
--- /dev/null
+++ b/both/checkpoint-1000/model-00003-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:419a70b4f4b877ad856405128d1fe6646bf0d0116e5fec61f4b350d65497dc0c
+size 4915916176
diff --git a/both/checkpoint-1000/model-00004-of-00004.safetensors b/both/checkpoint-1000/model-00004-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..f8ef3c2b66818ff2f5328438fdcfb2a676cfa80c
--- /dev/null
+++ b/both/checkpoint-1000/model-00004-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c68b5fd92e6138cb15f7dec4aeeb952076a0a2ee94d33d7bb7f71f0f89861ef
+size 1168138808
diff --git a/both/checkpoint-1000/model.safetensors.index.json b/both/checkpoint-1000/model.safetensors.index.json
new file mode 100644
index 0000000000000000000000000000000000000000..a054aad8cf0cb5671e2e10b3bf817725a4bec031
--- /dev/null
+++ b/both/checkpoint-1000/model.safetensors.index.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:146776fce3f6db1103aa6f249e65ee5544c5923ce6f971b092eee79aa6e5d37b
+size 23950
diff --git a/both/checkpoint-1000/rng_state_0.pth b/both/checkpoint-1000/rng_state_0.pth
new file mode 100644
index 0000000000000000000000000000000000000000..5a7c482c30381cd512ccc35fe322d8a34fbf5207
--- /dev/null
+++ b/both/checkpoint-1000/rng_state_0.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:308f94f9a5c24e1bad5c393d56ae7af7782600f4e791d9c6ac35b22fff2105b6
+size 15024
diff --git a/both/checkpoint-1000/rng_state_1.pth b/both/checkpoint-1000/rng_state_1.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7b862c21b28bbd89ce6b4fb681d41be05f175599
--- /dev/null
+++ b/both/checkpoint-1000/rng_state_1.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b056f3c23cb32dc77a2ec9e7651e0b64e4440e21f0fdf969b86bfc56a1cbdf06
+size 15024
diff --git a/both/checkpoint-1000/rng_state_2.pth b/both/checkpoint-1000/rng_state_2.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d86ce886844e0298f058d67065e5eeb27ffe7e48
--- /dev/null
+++ b/both/checkpoint-1000/rng_state_2.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3f8a05714bc528f4885a2816181652f2303b3e8150f89b56aaee6bec56aa520
+size 15024
diff --git a/both/checkpoint-1000/rng_state_3.pth b/both/checkpoint-1000/rng_state_3.pth
new file mode 100644
index 0000000000000000000000000000000000000000..10733f5da657367adf3f67760028644c0839660f
--- /dev/null
+++ b/both/checkpoint-1000/rng_state_3.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f755bd3c330281961e5c03af9d10ce8c1e1678619d384f6f1fd5fd7dce2ff50
+size 15024
diff --git a/both/checkpoint-1000/scheduler.pt b/both/checkpoint-1000/scheduler.pt
new file mode 100644
index 0000000000000000000000000000000000000000..6555bc5d2a1c9607353e1bce02fc505e67a4cc5b
--- /dev/null
+++ b/both/checkpoint-1000/scheduler.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:558d047f891b2c7bbd6966127f8da9724f087e389774ace2762134bcf3f682af
+size 1064
diff --git a/both/checkpoint-1000/special_tokens_map.json b/both/checkpoint-1000/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..92d7fab6729d723f45689803b99006fec945cc47
--- /dev/null
+++ b/both/checkpoint-1000/special_tokens_map.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b1835caa5b4d70acaa210fa222b0036f1882f9525c4660fd4810fb3e1e40ff8
+size 325
diff --git a/both/checkpoint-1000/tokenizer.json b/both/checkpoint-1000/tokenizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..9a62752e448c5a24e4fd09229f9378fa6fd092b3
--- /dev/null
+++ b/both/checkpoint-1000/tokenizer.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e134af98b985517b4f068e3755ae90d4e9cd2d45d328325dc503f1c6b2d06cc7
+size 9085698
diff --git a/both/checkpoint-1000/tokenizer_config.json b/both/checkpoint-1000/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..135d44d142a6294788d8f0f65143051adb44978c
--- /dev/null
+++ b/both/checkpoint-1000/tokenizer_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f90c3ed3cf014cbefa253ee894b17ee9e3a439014b369c8dc4f3eedcf3dc774
+size 51248
diff --git a/both/checkpoint-1000/trainer_state.json b/both/checkpoint-1000/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..cf3d94a3dcf95f0190fe6846db148b84e6236cf7
--- /dev/null
+++ b/both/checkpoint-1000/trainer_state.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4620aee32c6c0f3a8de6b16b2f4534206cb3a5226be69386680dac15e7c2a62b
+size 2763
diff --git a/both/checkpoint-1000/training_args.bin b/both/checkpoint-1000/training_args.bin
new file mode 100644
index 0000000000000000000000000000000000000000..fd07e9b4e8af227f0b4be5f0a73185d97439e4a6
--- /dev/null
+++ b/both/checkpoint-1000/training_args.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2dedee63b0174fb530781d6756adeec455b4d8453401cf9b524e3de9c01851c3
+size 7032
diff --git a/both/checkpoint-1000/zero_to_fp32.py b/both/checkpoint-1000/zero_to_fp32.py
new file mode 100755
index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8
--- /dev/null
+++ b/both/checkpoint-1000/zero_to_fp32.py
@@ -0,0 +1,604 @@
+#!/usr/bin/env python
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application.
+#
+# example: python zero_to_fp32.py . pytorch_model.bin
+
+import argparse
+import torch
+import glob
+import math
+import os
+import re
+from collections import OrderedDict
+from dataclasses import dataclass
+
+# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
+# DeepSpeed data structures it has to be available in the current python environment.
+from deepspeed.utils import logger
+from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
+                                            FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
+                                            FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+
+
+@dataclass
+class zero_model_state:
+    buffers: dict()
+    param_shapes: dict()
+    shared_params: list
+    ds_version: int
+    frozen_param_shapes: dict()
+    frozen_param_fragments: dict()
+
+
+debug = 0
+
+# load to cpu
+device = torch.device('cpu')
+
+
+def atoi(text):
+    return int(text) if text.isdigit() else text
+
+
+def natural_keys(text):
+    '''
+    alist.sort(key=natural_keys) sorts in human order
+    http://nedbatchelder.com/blog/200712/human_sorting.html
+    (See Toothy's implementation in the comments)
+    '''
+    return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+def get_model_state_file(checkpoint_dir, zero_stage):
+    if not os.path.isdir(checkpoint_dir):
+        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
+
+    # there should be only one file
+    if zero_stage <= 2:
+        file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
+    elif zero_stage == 3:
+        file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
+
+    if not os.path.exists(file):
+        raise FileNotFoundError(f"can't find model states file at '{file}'")
+
+    return file
+
+
+def get_checkpoint_files(checkpoint_dir, glob_pattern):
+    # XXX: need to test that this simple glob rule works for multi-node setup too
+    ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
+
+    if len(ckpt_files) == 0:
+        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
+
+    return ckpt_files
+
+
+def get_optim_files(checkpoint_dir):
+    return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
+
+
+def get_model_state_files(checkpoint_dir):
+    return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
+
+
+def parse_model_states(files):
+    zero_model_states = []
+    for file in files:
+        state_dict = torch.load(file, map_location=device)
+
+        if BUFFER_NAMES not in state_dict:
+            raise ValueError(f"{file} is not a model state checkpoint")
+        buffer_names = state_dict[BUFFER_NAMES]
+        if debug:
+            print("Found buffers:", buffer_names)
+
+        # recover just the buffers while restoring them to fp32 if they were saved in fp16
+        buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
+        param_shapes = state_dict[PARAM_SHAPES]
+
+        # collect parameters that are included in param_shapes
+        param_names = []
+        for s in param_shapes:
+            for name in s.keys():
+                param_names.append(name)
+
+        # update with frozen parameters
+        frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
+        if frozen_param_shapes is not None:
+            if debug:
+                print(f"Found frozen_param_shapes: {frozen_param_shapes}")
+            param_names += list(frozen_param_shapes.keys())
+
+        # handle shared params
+        shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
+
+        ds_version = state_dict.get(DS_VERSION, None)
+
+        frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
+
+        z_model_state = zero_model_state(buffers=buffers,
+                                         param_shapes=param_shapes,
+                                         shared_params=shared_params,
+                                         ds_version=ds_version,
+                                         frozen_param_shapes=frozen_param_shapes,
+                                         frozen_param_fragments=frozen_param_fragments)
+        zero_model_states.append(z_model_state)
+
+    return zero_model_states
+
+
+def parse_optim_states(files, ds_checkpoint_dir):
+
+    total_files = len(files)
+    state_dicts = []
+    for f in files:
+        state_dict = torch.load(f, map_location=device)
+        # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
+        # and also handle the case where it was already removed by another helper script
+        state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
+        state_dicts.append(state_dict)
+
+    if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
+        raise ValueError(f"{files[0]} is not a zero checkpoint")
+    zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
+    world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
+
+    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
+    # parameters can be different from data parallelism for non-expert parameters. So we can just
+    # use the max of the partition_count to get the dp world_size.
+
+    if type(world_size) is list:
+        world_size = max(world_size)
+
+    if world_size != total_files:
+        raise ValueError(
+            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
+            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+        )
+
+    # the groups are named differently in each stage
+    if zero_stage <= 2:
+        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
+    elif zero_stage == 3:
+        fp32_groups_key = FP32_FLAT_GROUPS
+    else:
+        raise ValueError(f"unknown zero stage {zero_stage}")
+
+    if zero_stage <= 2:
+        fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
+    elif zero_stage == 3:
+        # if there is more than one param group, there will be multiple flattened tensors - one
+        # flattened tensor per group - for simplicity merge them into a single tensor
+        #
+        # XXX: could make the script more memory efficient for when there are multiple groups - it
+        # will require matching the sub-lists of param_shapes for each param group flattened tensor
+
+        fp32_flat_groups = [
+            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
+        ]
+
+    return zero_stage, world_size, fp32_flat_groups
+
+
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
+    """
+    Returns fp32 state_dict reconstructed from ds checkpoint
+
+    Args:
+        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+
+    """
+    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+    optim_files = get_optim_files(ds_checkpoint_dir)
+    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+    model_files = get_model_state_files(ds_checkpoint_dir)
+
+    zero_model_states = parse_model_states(model_files)
+    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+
+    if zero_stage <= 2:
+        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                          exclude_frozen_parameters)
+    elif zero_stage == 3:
+        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                          exclude_frozen_parameters)
+
+
+def _zero2_merge_frozen_params(state_dict, zero_model_states):
+    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+        return
+
+    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+    frozen_param_fragments = zero_model_states[0].frozen_param_fragments
+
+    if debug:
+        num_elem = sum(s.numel() for s in frozen_param_shapes.values())
+        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+    wanted_params = len(frozen_param_shapes)
+    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+    avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
+    print(f'Frozen params: Have {avail_numel} numels to process.')
+    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+    total_params = 0
+    total_numel = 0
+    for name, shape in frozen_param_shapes.items():
+        total_params += 1
+        unpartitioned_numel = shape.numel()
+        total_numel += unpartitioned_numel
+
+        state_dict[name] = frozen_param_fragments[name]
+
+        if debug:
+            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+
+    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _has_callable(obj, fn):
+    attr = getattr(obj, fn, None)
+    return callable(attr)
+
+
+def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+    param_shapes = zero_model_states[0].param_shapes
+
+    # Reconstruction protocol:
+    #
+    # XXX: document this
+
+    if debug:
+        for i in range(world_size):
+            for j in range(len(fp32_flat_groups[0])):
+                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+    # XXX: memory usage doubles here (zero2)
+    num_param_groups = len(fp32_flat_groups[0])
+    merged_single_partition_of_fp32_groups = []
+    for i in range(num_param_groups):
+        merged_partitions = [sd[i] for sd in fp32_flat_groups]
+        full_single_fp32_vector = torch.cat(merged_partitions, 0)
+        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+    avail_numel = sum(
+        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+
+    if debug:
+        wanted_params = sum([len(shapes) for shapes in param_shapes])
+        wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+        # not asserting if there is a mismatch due to possible padding
+        print(f"Have {avail_numel} numels to process.")
+        print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+    # params
+    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+    # out-of-core computing solution
+    total_numel = 0
+    total_params = 0
+    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+        offset = 0
+        avail_numel = full_single_fp32_vector.numel()
+        for name, shape in shapes.items():
+
+            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+            total_numel += unpartitioned_numel
+            total_params += 1
+
+            if debug:
+                print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+            offset += unpartitioned_numel
+
+        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+        # live optimizer object, so we are checking that the numbers are within the right range
+        align_to = 2 * world_size
+
+        def zero2_align(x):
+            return align_to * math.ceil(x / align_to)
+
+        if debug:
+            print(f"original offset={offset}, avail_numel={avail_numel}")
+
+        offset = zero2_align(offset)
+        avail_numel = zero2_align(avail_numel)
+
+        if debug:
+            print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+        # Sanity check
+        if offset != avail_numel:
+            raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                               exclude_frozen_parameters):
+    state_dict = OrderedDict()
+
+    # buffers
+    buffers = zero_model_states[0].buffers
+    state_dict.update(buffers)
+    if debug:
+        print(f"added {len(buffers)} buffers")
+
+    if not exclude_frozen_parameters:
+        _zero2_merge_frozen_params(state_dict, zero_model_states)
+
+    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+    # recover shared parameters
+    for pair in zero_model_states[0].shared_params:
+        if pair[1] in state_dict:
+            state_dict[pair[0]] = state_dict[pair[1]]
+
+    return state_dict
+
+
+def zero3_partitioned_param_info(unpartitioned_numel, world_size):
+    remainder = unpartitioned_numel % world_size
+    padding_numel = (world_size - remainder) if remainder else 0
+    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
+    return partitioned_numel, padding_numel
+
+
+def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
+    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+        return
+
+    if debug:
+        for i in range(world_size):
+            num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
+            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+    wanted_params = len(frozen_param_shapes)
+    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+    avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
+    print(f'Frozen params: Have {avail_numel} numels to process.')
+    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+    total_params = 0
+    total_numel = 0
+    for name, shape in zero_model_states[0].frozen_param_shapes.items():
+        total_params += 1
+        unpartitioned_numel = shape.numel()
+        total_numel += unpartitioned_numel
+
+        param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
+        state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
+
+        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+        if debug:
+            print(
+                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+            )
+
+    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+    param_shapes = zero_model_states[0].param_shapes
+    avail_numel = fp32_flat_groups[0].numel() * world_size
+    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
+    # param, re-consolidating each param, while dealing with padding if any
+
+    # merge list of dicts, preserving order
+    param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+
+    if debug:
+        for i in range(world_size):
+            print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+    wanted_params = len(param_shapes)
+    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+    # not asserting if there is a mismatch due to possible padding
+    avail_numel = fp32_flat_groups[0].numel() * world_size
+    print(f"Trainable params: Have {avail_numel} numels to process.")
+    print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
+
+    # params
+    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+    # out-of-core computing solution
+    offset = 0
+    total_numel = 0
+    total_params = 0
+    for name, shape in param_shapes.items():
+
+        unpartitioned_numel = shape.numel()
+        total_numel += unpartitioned_numel
+        total_params += 1
+
+        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+        if debug:
+            print(
+                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+            )
+
+        # XXX: memory usage doubles here
+        state_dict[name] = torch.cat(
+            tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
+            0).narrow(0, 0, unpartitioned_numel).view(shape)
+        offset += partitioned_numel
+
+    offset *= world_size
+
+    # Sanity check
+    if offset != avail_numel:
+        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                               exclude_frozen_parameters):
+    state_dict = OrderedDict()
+
+    # buffers
+    buffers = zero_model_states[0].buffers
+    state_dict.update(buffers)
+    if debug:
+        print(f"added {len(buffers)} buffers")
+
+    if not exclude_frozen_parameters:
+        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+
+    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+    # recover shared parameters
+    for pair in zero_model_states[0].shared_params:
+        if pair[1] in state_dict:
+            state_dict[pair[0]] = state_dict[pair[1]]
+
+    return state_dict
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
+    """
+    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
+    via a model hub.
+
+    Args:
+        - ``checkpoint_dir``: path to the desired checkpoint folder
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+        - ``exclude_frozen_parameters``: exclude frozen parameters
+
+    Returns:
+        - pytorch ``state_dict``
+
+    Note: this approach may not work if your application doesn't have sufficient free CPU memory and
+    you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+    the checkpoint.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+        # do the training and checkpoint saving
+        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+        model = model.cpu() # move to cpu
+        model.load_state_dict(state_dict)
+        # submit to model hub or save the model to share with others
+
+    In this example the ``model`` will no longer be usable in the deepspeed context of the same
+    application. i.e. you will need to re-initialize the deepspeed engine, since
+    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+    """
+    if tag is None:
+        latest_path = os.path.join(checkpoint_dir, 'latest')
+        if os.path.isfile(latest_path):
+            with open(latest_path, 'r') as fd:
+                tag = fd.read().strip()
+        else:
+            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+    if not os.path.isdir(ds_checkpoint_dir):
+        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
+    """
+    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+    Args:
+        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+        - ``exclude_frozen_parameters``: exclude frozen parameters
+    """
+
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
+    print(f"Saving fp32 state dict to {output_file}")
+    torch.save(state_dict, output_file)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+    """
+    1. Put the provided model to cpu
+    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+    3. Load it into the provided model
+
+    Args:
+        - ``model``: the model object to update
+        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+    Returns:
+        - ``model`: modified model
+
+    Make sure you have plenty of CPU memory available before you call this function. If you don't
+    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+    conveniently placed for you in the checkpoint folder.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+        # submit to model hub or save the model to share with others
+
+    Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
+    of the same application. i.e. you will need to re-initialize the deepspeed engine, since
+    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    """
+    logger.info(f"Extracting fp32 weights")
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+    logger.info(f"Overwriting model with fp32 weights")
+    model = model.cpu()
+    model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("checkpoint_dir",
+                        type=str,
+                        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+    parser.add_argument(
+        "output_file",
+        type=str,
+        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+    parser.add_argument("-t",
+                        "--tag",
+                        type=str,
+                        default=None,
+                        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
+    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+    args = parser.parse_args()
+
+    debug = args.debug
+
+    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
+                                               args.output_file,
+                                               tag=args.tag,
+                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
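The file above is DeepSpeed's stock zero_to_fp32.py, copied into the checkpoint folder at save time. A minimal usage sketch against the checkpoint added in this commit — assuming the ZeRO shard folder named by `latest` ("global_step1000", with its *_optim_states.pt and *_model_states.pt files, which do not appear in this diff) is present locally ::

    # command-line form, as given in the script header:
    #   python both/checkpoint-1000/zero_to_fp32.py both/checkpoint-1000 pytorch_model.bin
    # programmatic form, using the function defined above:
    from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    # reads both/checkpoint-1000/latest to find the tag ("global_step1000"),
    # merges the ZeRO partitions on CPU, and writes a single fp32 state_dict
    convert_zero_checkpoint_to_fp32_state_dict("both/checkpoint-1000",
                                               "both/checkpoint-1000/pytorch_model.bin")
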
diff --git a/generate_strategy/checkpoint-19208/config.json b/generate_strategy/checkpoint-19208/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..4b5dd770a9299bba795ab05ced2568c99126d388
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce500a6a428ad30b64a39e37283ed3798433f50f1ce69076d582548a9ee9a34b
+size 748
diff --git a/generate_strategy/checkpoint-19208/generation_config.json b/generate_strategy/checkpoint-19208/generation_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..808ce3668b201a139d5052bc24546a1c4fffa4ac
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/generation_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8621b054f7ed1a0197b27d579874114724c840fa0efe0cdc6ca2fd1ac07df6bb
+size 194
diff --git a/generate_strategy/checkpoint-19208/latest b/generate_strategy/checkpoint-19208/latest
new file mode 100644
index 0000000000000000000000000000000000000000..6ccd66a2151500380927f3754aedd0bfc9be23ee
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/latest
@@ -0,0 +1 @@
+global_step19204
\ No newline at end of file
diff --git a/generate_strategy/checkpoint-19208/model-00001-of-00004.safetensors b/generate_strategy/checkpoint-19208/model-00001-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..e6779b0b4c6cdf5945b4f462103c5220b6ac2cfb
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/model-00001-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:679dfe090c52caf92b5838be67f7183765351855fac89b7c43861328dbb24e80
+size 4976698672
diff --git a/generate_strategy/checkpoint-19208/model-00002-of-00004.safetensors b/generate_strategy/checkpoint-19208/model-00002-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..167403a3580aa4c4039bc3ef5f10347d215e7d2c
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/model-00002-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbc5434cdd2428cb25c44c86b9f6b0a603b7b8253f35f965276d5af380c6b81b
+size 4999802720
diff --git a/generate_strategy/checkpoint-19208/model-00003-of-00004.safetensors b/generate_strategy/checkpoint-19208/model-00003-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..e89483a008b8f804c04379cb7729349759043a80
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/model-00003-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51e43bcda95abe59d86336aa9bfd3fbdd9e118d8930bd6a9664953c053c9a9bc
+size 4915916176
diff --git a/generate_strategy/checkpoint-19208/model-00004-of-00004.safetensors b/generate_strategy/checkpoint-19208/model-00004-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..10a29f0101a61aac22de12c8945c0ea036c66bab
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/model-00004-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2f9a9755629abc894bba719c6852dfbf3311d8bf2cbabc726d3117682ba2844
+size 1168138808
diff --git a/generate_strategy/checkpoint-19208/model.safetensors.index.json b/generate_strategy/checkpoint-19208/model.safetensors.index.json
new file mode 100644
index 0000000000000000000000000000000000000000..a054aad8cf0cb5671e2e10b3bf817725a4bec031
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/model.safetensors.index.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:146776fce3f6db1103aa6f249e65ee5544c5923ce6f971b092eee79aa6e5d37b
+size 23950
diff --git a/generate_strategy/checkpoint-19208/rng_state_0.pth b/generate_strategy/checkpoint-19208/rng_state_0.pth
new file mode 100644
index 0000000000000000000000000000000000000000..331a542ef30cc221562b6a988bba872aca28732e
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/rng_state_0.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb7c3bc1248de8b4739437317b988d953fd64a5de9736606d74f9c8277f1b485
+size 15984
diff --git a/generate_strategy/checkpoint-19208/rng_state_1.pth b/generate_strategy/checkpoint-19208/rng_state_1.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d445f1a845bda18b54837a3234302870193ebea4
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/rng_state_1.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8e571d57a85eb2cdabf3f46c86e446bdb7d26aba8b1467b5e4b5bbe29ad42a7
+size 15984
diff --git a/generate_strategy/checkpoint-19208/rng_state_2.pth b/generate_strategy/checkpoint-19208/rng_state_2.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3a1a5fda176cefd8a1f05e423f2c82ed9f2333bf
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/rng_state_2.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:489e5542988617525a395c45dc83ec6bf25b473812e139122f0a3f3d92f031d0
+size 15984
diff --git a/generate_strategy/checkpoint-19208/rng_state_3.pth b/generate_strategy/checkpoint-19208/rng_state_3.pth
new file mode 100644
index 0000000000000000000000000000000000000000..a7495a1bc89c5532615f548b4a177c4b6de82a0a
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/rng_state_3.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0cd77682efb711872c5be25e87e87a2726a2e7105422cddd00f04da7be35ca20
+size 15984
diff --git a/generate_strategy/checkpoint-19208/rng_state_4.pth b/generate_strategy/checkpoint-19208/rng_state_4.pth
new file mode 100644
index 0000000000000000000000000000000000000000..a0dd539c338038495aec8fdc04c5e6d165086b28
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/rng_state_4.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e44d9e7d535f5fbcd7cfef16ba22d32d5f445aacceba782a05df1f97d47a608a
+size 15984
diff --git a/generate_strategy/checkpoint-19208/rng_state_5.pth b/generate_strategy/checkpoint-19208/rng_state_5.pth
new file mode 100644
index 0000000000000000000000000000000000000000..bd7cb309d087786d365a3ca391edef06504b3bb4
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/rng_state_5.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a107290a0d9898930bc6abe369ee246ef7322541985fc2a5320e7775f5ea5c88
+size 15984
diff --git a/generate_strategy/checkpoint-19208/rng_state_6.pth b/generate_strategy/checkpoint-19208/rng_state_6.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3c760c81b8bffb4ba6cb4dcda4460911ef5e78df
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/rng_state_6.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88ab49d56ee4079c2a208376064f825918f070addc8f0c58c5c594265f9e8a78
+size 15984
diff --git a/generate_strategy/checkpoint-19208/rng_state_7.pth b/generate_strategy/checkpoint-19208/rng_state_7.pth
new file mode 100644
index 0000000000000000000000000000000000000000..62523a33304462480531f2f10d91dcdd14562719
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/rng_state_7.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d15033d06420b17d80db45c89544170faa67833d5a0d9c30a51a38a1102b073
+size 15984
diff --git a/generate_strategy/checkpoint-19208/scheduler.pt b/generate_strategy/checkpoint-19208/scheduler.pt
new file mode 100644
index 0000000000000000000000000000000000000000..200e917e590f65c4d828c2dc613bcd34c26dcdfd
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/scheduler.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c563ec26d3e6a26a391a7a14609370350d4e8af185e7a822f099a4fa127f834f
+size 1064
diff --git a/generate_strategy/checkpoint-19208/special_tokens_map.json b/generate_strategy/checkpoint-19208/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..7eeb1c57604498702a809e216ea13069a45f277e
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/special_tokens_map.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b82a7fb78938eefe5da710322cf3cdb591d87c175c20d7bf852c27a36de2e472
+size 650
diff --git a/generate_strategy/checkpoint-19208/tokenizer.json b/generate_strategy/checkpoint-19208/tokenizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..172311123ab62378f1f6d90f3068a676b7d939ed
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/tokenizer.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c1dcab308e7cf5970ea38815e0a62887d705c5b436f869ca27a5dcdd40c36a6
+size 17210148
diff --git a/generate_strategy/checkpoint-19208/tokenizer_config.json b/generate_strategy/checkpoint-19208/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..51ac8674d8345b4b713da5c226f1b6d05e2fae38
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/tokenizer_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6fb0b3cee38fd6f0d4042c7730ffab3839b838c66ee64ea3b01811d3d4473b3
+size 51319
diff --git a/generate_strategy/checkpoint-19208/trainer_state.json b/generate_strategy/checkpoint-19208/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..c01bfff052d57d992f551d52968860f8c63d9404
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/trainer_state.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:794ad2e756818f226fac8e3882b47276ca30746dc932b8f1384bf9fc9292a279
+size 35878
diff --git a/generate_strategy/checkpoint-19208/training_args.bin b/generate_strategy/checkpoint-19208/training_args.bin
new file mode 100644
index 0000000000000000000000000000000000000000..773ce6240443c97c5f4f17d0c292e93b3f620d6d
--- /dev/null
+++ b/generate_strategy/checkpoint-19208/training_args.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c4854ece4d5ef51aa764407c1d839019391947a59f720f2fc5ec761b53b0838
+size 7416
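With the shards above pulled from LFS, this checkpoint can also be loaded back directly, without DeepSpeed. A hedged sketch, assuming a causal-LM architecture (the four-shard safetensors layout suggests one, but config.json is authoritative) ::

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # load straight from the consolidated safetensors shards and tokenizer
    # files committed above; no ZeRO conversion needed for inference
    ckpt = "generate_strategy/checkpoint-19208"
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(ckpt)

The next file is the same zero_to_fp32.py helper again (note the identical blob hash 24cc342e…), saved by the trainer into this second checkpoint directory as well.
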
0000000000000000000000000000000000000000..c01bfff052d57d992f551d52968860f8c63d9404 --- /dev/null +++ b/generate_strategy/checkpoint-19208/trainer_state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:794ad2e756818f226fac8e3882b47276ca30746dc932b8f1384bf9fc9292a279 +size 35878 diff --git a/generate_strategy/checkpoint-19208/training_args.bin b/generate_strategy/checkpoint-19208/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..773ce6240443c97c5f4f17d0c292e93b3f620d6d --- /dev/null +++ b/generate_strategy/checkpoint-19208/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c4854ece4d5ef51aa764407c1d839019391947a59f720f2fc5ec761b53b0838 +size 7416 diff --git a/generate_strategy/checkpoint-19208/zero_to_fp32.py b/generate_strategy/checkpoint-19208/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/generate_strategy/checkpoint-19208/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + 
state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + 
frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. 
Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * 
world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. 
+ + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application, i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): + """ + Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model on cpu + 2. Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model``: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context + of the same application,
i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/no_explain/checkpoint-4000/config.json b/no_explain/checkpoint-4000/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4b5dd770a9299bba795ab05ced2568c99126d388 --- /dev/null +++ b/no_explain/checkpoint-4000/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce500a6a428ad30b64a39e37283ed3798433f50f1ce69076d582548a9ee9a34b +size 748 diff --git a/no_explain/checkpoint-4000/generation_config.json b/no_explain/checkpoint-4000/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..808ce3668b201a139d5052bc24546a1c4fffa4ac --- /dev/null +++ b/no_explain/checkpoint-4000/generation_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8621b054f7ed1a0197b27d579874114724c840fa0efe0cdc6ca2fd1ac07df6bb +size 194 diff --git a/no_explain/checkpoint-4000/latest b/no_explain/checkpoint-4000/latest new file mode 100644 index 0000000000000000000000000000000000000000..bde045b84f7344a502489c347cb8527c3cce2ef5 --- /dev/null +++ b/no_explain/checkpoint-4000/latest @@ -0,0 +1 @@ +global_step3998 \ No newline at end of file diff --git a/no_explain/checkpoint-4000/model-00001-of-00004.safetensors b/no_explain/checkpoint-4000/model-00001-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..3c671cd9c9bcfc161ae9ac606058f417b25843ee --- /dev/null +++ b/no_explain/checkpoint-4000/model-00001-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47dde464d2aefd37eda2815b5ba61b57d361c3bce52e792d3c1ab69102a55110 +size 4976698672 diff --git a/no_explain/checkpoint-4000/model-00002-of-00004.safetensors b/no_explain/checkpoint-4000/model-00002-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..cfc74a973715e43594e2678567e12aab1c679e92 --- /dev/null +++ b/no_explain/checkpoint-4000/model-00002-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bfde70000ea34a2629d392c10f2bc415d8d6fde186a2d31d992fc0c70386f154 +size 4999802720 diff --git a/no_explain/checkpoint-4000/model-00003-of-00004.safetensors 
b/no_explain/checkpoint-4000/model-00003-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..704a9ff77226551034fd2040957b56800db344f3 --- /dev/null +++ b/no_explain/checkpoint-4000/model-00003-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:366054a4bee6871b7747951631c6f684e622cadc61ffb309bb794f49ca09d355 +size 4915916176 diff --git a/no_explain/checkpoint-4000/model-00004-of-00004.safetensors b/no_explain/checkpoint-4000/model-00004-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..224d9585cfa423b5bfa4d3849e745e8f82d57429 --- /dev/null +++ b/no_explain/checkpoint-4000/model-00004-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3fd155657116ece6520b0be23250ac218e093f5e3fb69a33fcb0227a3a62ce3 +size 1168138808 diff --git a/no_explain/checkpoint-4000/model.safetensors.index.json b/no_explain/checkpoint-4000/model.safetensors.index.json new file mode 100644 index 0000000000000000000000000000000000000000..a054aad8cf0cb5671e2e10b3bf817725a4bec031 --- /dev/null +++ b/no_explain/checkpoint-4000/model.safetensors.index.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:146776fce3f6db1103aa6f249e65ee5544c5923ce6f971b092eee79aa6e5d37b +size 23950 diff --git a/no_explain/checkpoint-4000/rng_state_0.pth b/no_explain/checkpoint-4000/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..6a74f25da28f01a2e6b66587824ee5f5cc9be737 --- /dev/null +++ b/no_explain/checkpoint-4000/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ee195ebde9bf012f945f068f133e7fe22fef5450c496607e3ef11cc2034a186 +size 15984 diff --git a/no_explain/checkpoint-4000/rng_state_1.pth b/no_explain/checkpoint-4000/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..f44ddc47315653477728c971b4ea191a3df8b92c --- /dev/null +++ b/no_explain/checkpoint-4000/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf0fe1a3315d60b197207c5cb249d0ce4f9ce6d7585e696276d9ffbcb5379893 +size 15984 diff --git a/no_explain/checkpoint-4000/rng_state_2.pth b/no_explain/checkpoint-4000/rng_state_2.pth new file mode 100644 index 0000000000000000000000000000000000000000..04636b9eca6484a4339eaa1e3acdf15d42d493b3 --- /dev/null +++ b/no_explain/checkpoint-4000/rng_state_2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01c5bd6eae04542162b3e94245555bd81312524066bc01d0ebbfc4fd8554240e +size 15984 diff --git a/no_explain/checkpoint-4000/rng_state_3.pth b/no_explain/checkpoint-4000/rng_state_3.pth new file mode 100644 index 0000000000000000000000000000000000000000..05435e407541728c3159054a4beb6705039a8ddf --- /dev/null +++ b/no_explain/checkpoint-4000/rng_state_3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45b74942c68b00d657cfce186b0eeb4aa8f52efa04b114803b605fee8de45972 +size 15984 diff --git a/no_explain/checkpoint-4000/rng_state_4.pth b/no_explain/checkpoint-4000/rng_state_4.pth new file mode 100644 index 0000000000000000000000000000000000000000..94fdf5f2c3e5df27424e6482bf52255531147a23 --- /dev/null +++ b/no_explain/checkpoint-4000/rng_state_4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0cd66dd2ba958fc9929441817d8154abbd929c0aa9cd66ff3171965bdaaf5d78 +size 15984 diff --git a/no_explain/checkpoint-4000/rng_state_5.pth 
b/no_explain/checkpoint-4000/rng_state_5.pth new file mode 100644 index 0000000000000000000000000000000000000000..da6e37fc011d97a1512e1e746bdd410a738c018a --- /dev/null +++ b/no_explain/checkpoint-4000/rng_state_5.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89eeedefdd62514d0130acc330a5c08e9774c95d38c60997905cfd65fc54b710 +size 15984 diff --git a/no_explain/checkpoint-4000/rng_state_6.pth b/no_explain/checkpoint-4000/rng_state_6.pth new file mode 100644 index 0000000000000000000000000000000000000000..751fd85c617e15dee9713bc0f0c533af5bd18c8e --- /dev/null +++ b/no_explain/checkpoint-4000/rng_state_6.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f43ced939100082608f57561a10e1888e69210c80675068db530c5815889910e +size 15984 diff --git a/no_explain/checkpoint-4000/rng_state_7.pth b/no_explain/checkpoint-4000/rng_state_7.pth new file mode 100644 index 0000000000000000000000000000000000000000..4aacf54fa8285b7e199a7cd62f1ee3d8b9beb5e5 --- /dev/null +++ b/no_explain/checkpoint-4000/rng_state_7.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d8d6ee244d99525e7004ae3f02d44ae63082d81fbbab7306f641ac6aeeb736f +size 15984 diff --git a/no_explain/checkpoint-4000/scheduler.pt b/no_explain/checkpoint-4000/scheduler.pt new file mode 100644 index 0000000000000000000000000000000000000000..f3b39f0ea2d9092bec75f0da2301e46c33457524 --- /dev/null +++ b/no_explain/checkpoint-4000/scheduler.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:624e83d8bb1c2a780d4eb55df92ee2a5d78647169f3aca1e61c43c0cf57e3359 +size 1064 diff --git a/no_explain/checkpoint-4000/special_tokens_map.json b/no_explain/checkpoint-4000/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..7eeb1c57604498702a809e216ea13069a45f277e --- /dev/null +++ b/no_explain/checkpoint-4000/special_tokens_map.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b82a7fb78938eefe5da710322cf3cdb591d87c175c20d7bf852c27a36de2e472 +size 650 diff --git a/no_explain/checkpoint-4000/tokenizer.json b/no_explain/checkpoint-4000/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..172311123ab62378f1f6d90f3068a676b7d939ed --- /dev/null +++ b/no_explain/checkpoint-4000/tokenizer.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c1dcab308e7cf5970ea38815e0a62887d705c5b436f869ca27a5dcdd40c36a6 +size 17210148 diff --git a/no_explain/checkpoint-4000/tokenizer_config.json b/no_explain/checkpoint-4000/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..51ac8674d8345b4b713da5c226f1b6d05e2fae38 --- /dev/null +++ b/no_explain/checkpoint-4000/tokenizer_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6fb0b3cee38fd6f0d4042c7730ffab3839b838c66ee64ea3b01811d3d4473b3 +size 51319 diff --git a/no_explain/checkpoint-4000/trainer_state.json b/no_explain/checkpoint-4000/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..a1b369a69c0c1d3f408e3f1bcf683ca3da41449a --- /dev/null +++ b/no_explain/checkpoint-4000/trainer_state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:203d86aba5c0edca2a683fe98f6460936b6ab47b5413a6193569eaa0799e9afc +size 8652 diff --git a/no_explain/checkpoint-4000/training_args.bin b/no_explain/checkpoint-4000/training_args.bin new file mode 100644 index 
0000000000000000000000000000000000000000..307671da66dbb1b7ea5eba31217babf26111f74f --- /dev/null +++ b/no_explain/checkpoint-4000/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44b434e8cddfefc8be8a47c2feee91ad07bcea3a34db2ebd66e1953e926e28aa +size 7416 diff --git a/no_explain/checkpoint-4000/zero_to_fp32.py b/no_explain/checkpoint-4000/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/no_explain/checkpoint-4000/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from ZeRO 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, the checkpoints are pickled with +# DeepSpeed data structures, so deepspeed has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict + param_shapes: dict + shared_params: list + ds_version: int + frozen_param_shapes: dict + frozen_param_fragments: dict + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not
in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the two potentially huge optimizer states, as we only care about the fp32 master weights, + # and also handle the case where they were already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have a different partition_count, as data parallelism for expert + # parameters can differ from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} '*_optim_states.pt' files under '{ds_checkpoint_dir}' but found {total_files}. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+ ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + 
# + # Each rank saved a single flat fp32 partition per param group. First concatenate each + # group's partitions across ranks into one full flat vector, then carve each param out of + # that vector sequentially by its numel (alignment padding is accounted for below). + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # an out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size.
Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * 
world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at the boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # an out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the 'latest' file, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory, in + which case you may need to use the offline approach via the ``zero_to_fp32.py`` script that is + saved with the checkpoint.
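+ + As a rough estimate, the consolidated fp32 tensors alone take 4 bytes per parameter, so e.g. an + 8B-parameter model needs on the order of 32GB of free CPU RAM, plus a comparable transient peak + while the flat partitions are concatenated (cf. the memory-doubling notes in the code above).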
+ + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application, i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): + """ + Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model on cpu + 2. Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model``: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context + of the same application,
i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/strategy/checkpoint-5000/config.json b/strategy/checkpoint-5000/config.json new file mode 100644 index 0000000000000000000000000000000000000000..da6f9ba3622f70e966728e9939da080ce8df3df3 --- /dev/null +++ b/strategy/checkpoint-5000/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddcfa3c8f7da780e23fc2990aba55eef954aa0eb9e934a58c5536953e79a2b1c +size 729 diff --git a/strategy/checkpoint-5000/generation_config.json b/strategy/checkpoint-5000/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..e0796b3dbfe28ebffe364cefdfc7f7116df2fae4 --- /dev/null +++ b/strategy/checkpoint-5000/generation_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f312a9b22d84af6aa4883b93869f2f480a4cdc78a54cf6e7d5b2069f019ff4c4 +size 194 diff --git a/strategy/checkpoint-5000/latest b/strategy/checkpoint-5000/latest new file mode 100644 index 0000000000000000000000000000000000000000..f805186fa43374540c3fa51dfd3cca9ac06e56a5 --- /dev/null +++ b/strategy/checkpoint-5000/latest @@ -0,0 +1 @@ +global_step5000 \ No newline at end of file diff --git a/strategy/checkpoint-5000/model-00001-of-00004.safetensors b/strategy/checkpoint-5000/model-00001-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..c978242533759a2edb4d3b98c1bfac5a5a2dd4bc --- /dev/null +++ b/strategy/checkpoint-5000/model-00001-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51dfd7a50d92b1eb1993ce6e6988abeb5e2809b747741dbe5b16eb95067d51e2 +size 4976698672 diff --git a/strategy/checkpoint-5000/model-00002-of-00004.safetensors b/strategy/checkpoint-5000/model-00002-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..24113d4dfcd22f658bd59aca5dfd32cfbdbb3c6b --- /dev/null +++ b/strategy/checkpoint-5000/model-00002-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58b937f40a52f3c202847a9f2bd058e35629a151bc059a1968677ff1edbccccc +size 4999802720 diff --git a/strategy/checkpoint-5000/model-00003-of-00004.safetensors b/strategy/checkpoint-5000/model-00003-of-00004.safetensors new 
file mode 100644 index 0000000000000000000000000000000000000000..cb4950f36e31a82b2ba28e2308724ee464c565c0 --- /dev/null +++ b/strategy/checkpoint-5000/model-00003-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d31967b827d1004761c9e8a19a0bc16dee55f48a3122733818c47ad616c75ab1 +size 4915916176 diff --git a/strategy/checkpoint-5000/model-00004-of-00004.safetensors b/strategy/checkpoint-5000/model-00004-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..f55598331d129f923f6b2a59f1e1a6a90f90bd61 --- /dev/null +++ b/strategy/checkpoint-5000/model-00004-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58ef87fa18407bc789a30b899db8cfe72bb1a72fa8763845c0ae3361211d94b0 +size 1168138808 diff --git a/strategy/checkpoint-5000/model.safetensors.index.json b/strategy/checkpoint-5000/model.safetensors.index.json new file mode 100644 index 0000000000000000000000000000000000000000..a054aad8cf0cb5671e2e10b3bf817725a4bec031 --- /dev/null +++ b/strategy/checkpoint-5000/model.safetensors.index.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:146776fce3f6db1103aa6f249e65ee5544c5923ce6f971b092eee79aa6e5d37b +size 23950 diff --git a/strategy/checkpoint-5000/rng_state_0.pth b/strategy/checkpoint-5000/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..d4ade713ef57d0535c32a9251c786bc57de03d06 --- /dev/null +++ b/strategy/checkpoint-5000/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb1165242405b17b3d6a8186ae61b13dcb1faa5a54320bebd74ef8d71b964bf7 +size 15984 diff --git a/strategy/checkpoint-5000/rng_state_1.pth b/strategy/checkpoint-5000/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..d91c511b147b4dd17988903c57adcefb6c1f20b0 --- /dev/null +++ b/strategy/checkpoint-5000/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:562c262916c9997ec644c42fed9655ab28706b74fca20290ca921c4761d6a4b0 +size 15984 diff --git a/strategy/checkpoint-5000/rng_state_2.pth b/strategy/checkpoint-5000/rng_state_2.pth new file mode 100644 index 0000000000000000000000000000000000000000..f71e829b3e3570a540263d07783c4e906a78a803 --- /dev/null +++ b/strategy/checkpoint-5000/rng_state_2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8d40f8118f513299624ded0a9bcf09778b961635615090409394d4f96f928f6 +size 15984 diff --git a/strategy/checkpoint-5000/rng_state_3.pth b/strategy/checkpoint-5000/rng_state_3.pth new file mode 100644 index 0000000000000000000000000000000000000000..be7f0176676a7c526bb10cbb336b2afa89d8841c --- /dev/null +++ b/strategy/checkpoint-5000/rng_state_3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4391f924238a4cb855c4cbdc6d1a14954f785431c75997d05c7a4ee6615dae7 +size 15984 diff --git a/strategy/checkpoint-5000/rng_state_4.pth b/strategy/checkpoint-5000/rng_state_4.pth new file mode 100644 index 0000000000000000000000000000000000000000..8dd1a877dd1f03799067fd08739e82b9f2cd2ad3 --- /dev/null +++ b/strategy/checkpoint-5000/rng_state_4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be7b19bb9543a16bf9f4cd96466ac581436f63070f5815f3a7ba57980608994f +size 15984 diff --git a/strategy/checkpoint-5000/rng_state_5.pth b/strategy/checkpoint-5000/rng_state_5.pth new file mode 100644 index 
0000000000000000000000000000000000000000..dcf1b720014f72a27a09ab9ef8570430a8e3c96d --- /dev/null +++ b/strategy/checkpoint-5000/rng_state_5.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97da4a1ede0a3e0f96411cacd5bfdf84d9355198f7aadc9bcb8be41122043f63 +size 15984 diff --git a/strategy/checkpoint-5000/rng_state_6.pth b/strategy/checkpoint-5000/rng_state_6.pth new file mode 100644 index 0000000000000000000000000000000000000000..2b58cbeed7b25ef61c6439aced60df473cbaf6d4 --- /dev/null +++ b/strategy/checkpoint-5000/rng_state_6.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:544cb6421b975bd5d2b2360a4e666003794e6197ae654d2ad963cd6572a86ede +size 15984 diff --git a/strategy/checkpoint-5000/rng_state_7.pth b/strategy/checkpoint-5000/rng_state_7.pth new file mode 100644 index 0000000000000000000000000000000000000000..36a7dcefe0e0264868d40586546699306878a454 --- /dev/null +++ b/strategy/checkpoint-5000/rng_state_7.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8d6eb32a23f3bef6262bbcb2eda724b2fd6f5e579969aa27c71a5971331722b +size 15984 diff --git a/strategy/checkpoint-5000/scheduler.pt b/strategy/checkpoint-5000/scheduler.pt new file mode 100644 index 0000000000000000000000000000000000000000..4f9af7d1258bd81383960c0f30257b21293f27f4 --- /dev/null +++ b/strategy/checkpoint-5000/scheduler.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:211c0a4c9e8e9be4029e22dafd27f163573fae7d751781353b1f2e1bba8992c3 +size 1064 diff --git a/strategy/checkpoint-5000/special_tokens_map.json b/strategy/checkpoint-5000/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..92d7fab6729d723f45689803b99006fec945cc47 --- /dev/null +++ b/strategy/checkpoint-5000/special_tokens_map.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b1835caa5b4d70acaa210fa222b0036f1882f9525c4660fd4810fb3e1e40ff8 +size 325 diff --git a/strategy/checkpoint-5000/tokenizer.json b/strategy/checkpoint-5000/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..9a62752e448c5a24e4fd09229f9378fa6fd092b3 --- /dev/null +++ b/strategy/checkpoint-5000/tokenizer.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e134af98b985517b4f068e3755ae90d4e9cd2d45d328325dc503f1c6b2d06cc7 +size 9085698 diff --git a/strategy/checkpoint-5000/tokenizer_config.json b/strategy/checkpoint-5000/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..135d44d142a6294788d8f0f65143051adb44978c --- /dev/null +++ b/strategy/checkpoint-5000/tokenizer_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f90c3ed3cf014cbefa253ee894b17ee9e3a439014b369c8dc4f3eedcf3dc774 +size 51248 diff --git a/strategy/checkpoint-5000/trainer_state.json b/strategy/checkpoint-5000/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..7be587be06dc6e624aaa84365a6468df11f46d1a --- /dev/null +++ b/strategy/checkpoint-5000/trainer_state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:808e701315411302bbcb4cd1e45cde1f82a06628ad3878a53e7336eb00085429 +size 10541 diff --git a/strategy/checkpoint-5000/training_args.bin b/strategy/checkpoint-5000/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..8cac593e146173e7b3cb37977797bca7b931318b --- /dev/null +++ b/strategy/checkpoint-5000/training_args.bin @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:ac3b244331aca2542caad5f3a39e4b69e8873ff4657e9822355e53be4684d2c9 +size 7032 diff --git a/strategy/checkpoint-5000/zero_to_fp32.py b/strategy/checkpoint-5000/zero_to_fp32.py new file mode 100755 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/strategy/checkpoint-5000/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from ZeRO 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, the checkpoints are pickled with +# DeepSpeed data structures, so deepspeed has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict + param_shapes: dict + shared_params: list + ds_version: int + frozen_param_shapes: dict + frozen_param_fragments: dict + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + #
recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the two potentially huge optimizer states, as we only care about the fp32 master weights, + # and also handle the case where they were already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have a different partition_count, as data parallelism for expert + # parameters can differ from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} '*_optim_states.pt' files under '{ds_checkpoint_dir}' but found {total_files}. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+ ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + 
# + # Each rank saved a single flat fp32 partition per param group. First concatenate each + # group's partitions across ranks into one full flat vector, then carve each param out of + # that vector sequentially by its numel (alignment padding is accounted for below). + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # an out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size.
Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * 
world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. 
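+
+    As a rough rule of thumb (an estimate, not a guarantee): plan on about 4 bytes of free
+    CPU RAM per model parameter just to hold the consolidated fp32 weights, and transiently
+    up to about twice that while the partitions are being merged.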
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+        # do the training and checkpoint saving
+        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+        model = model.cpu() # move to cpu
+        model.load_state_dict(state_dict)
+        # submit to model hub or save the model to share with others
+
+    In this example the ``model`` will no longer be usable in the deepspeed context of the same
+    application, i.e. you will need to re-initialize the deepspeed engine, since
+    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+    """
+    if tag is None:
+        latest_path = os.path.join(checkpoint_dir, 'latest')
+        if os.path.isfile(latest_path):
+            with open(latest_path, 'r') as fd:
+                tag = fd.read().strip()
+        else:
+            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+    if not os.path.isdir(ds_checkpoint_dir):
+        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
+    """
+    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+    Args:
+        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to read the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+        - ``exclude_frozen_parameters``: exclude frozen parameters
+    """
+
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
+    print(f"Saving fp32 state dict to {output_file}")
+    torch.save(state_dict, output_file)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+    """
+    1. Move the provided model to the cpu
+    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+    3. Load it into the provided model
+
+    Args:
+        - ``model``: the model object to update
+        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to read the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+    Returns:
+        - ``model``: the modified model
+
+    Make sure you have plenty of CPU memory available before you call this function. If you don't
+    have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+    conveniently placed for you in the checkpoint folder.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+        # submit to model hub or save the model to share with others
+
+    Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context
+    of the same application.
i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/tactic/checkpoint-3075/config.json b/tactic/checkpoint-3075/config.json new file mode 100644 index 0000000000000000000000000000000000000000..da6f9ba3622f70e966728e9939da080ce8df3df3 --- /dev/null +++ b/tactic/checkpoint-3075/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddcfa3c8f7da780e23fc2990aba55eef954aa0eb9e934a58c5536953e79a2b1c +size 729 diff --git a/tactic/checkpoint-3075/generation_config.json b/tactic/checkpoint-3075/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..e0796b3dbfe28ebffe364cefdfc7f7116df2fae4 --- /dev/null +++ b/tactic/checkpoint-3075/generation_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f312a9b22d84af6aa4883b93869f2f480a4cdc78a54cf6e7d5b2069f019ff4c4 +size 194 diff --git a/tactic/checkpoint-3075/latest b/tactic/checkpoint-3075/latest new file mode 100644 index 0000000000000000000000000000000000000000..972b1c424a128381c5fe3bc9edc6617b37340666 --- /dev/null +++ b/tactic/checkpoint-3075/latest @@ -0,0 +1 @@ +global_step3075 \ No newline at end of file diff --git a/tactic/checkpoint-3075/model-00001-of-00004.safetensors b/tactic/checkpoint-3075/model-00001-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..68256cacc9f62980c9666f2290d98275e90455f4 --- /dev/null +++ b/tactic/checkpoint-3075/model-00001-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0df2fe6a5dc54b765fd8523a9dad5fe48d5ce1dda83c61a2926d2db2344de077 +size 4976698672 diff --git a/tactic/checkpoint-3075/model-00002-of-00004.safetensors b/tactic/checkpoint-3075/model-00002-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..f3de357eb1d7eabc6bb8f284d129608a16ac4549 --- /dev/null +++ b/tactic/checkpoint-3075/model-00002-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:036976117de02fea03c1b2db7dd88402472fa3718de410b951f33313a618e220 +size 4999802720 diff --git a/tactic/checkpoint-3075/model-00003-of-00004.safetensors b/tactic/checkpoint-3075/model-00003-of-00004.safetensors new file mode 100644 index 
0000000000000000000000000000000000000000..e902612e261384b2ed8c78ca3a1ecbf6814429b4 --- /dev/null +++ b/tactic/checkpoint-3075/model-00003-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ae8245dabdcfd29fcca60adbe234ff294f82f5fa4397d2a5972c9f6acf7b908 +size 4915916176 diff --git a/tactic/checkpoint-3075/model-00004-of-00004.safetensors b/tactic/checkpoint-3075/model-00004-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..9145443c670ce7a3f3bc2ede08c0c653f5374811 --- /dev/null +++ b/tactic/checkpoint-3075/model-00004-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9b1a22c904f85c283bee9524b8ca4cbf60d5d1c45f633e4fea4858634422177 +size 1168138808 diff --git a/tactic/checkpoint-3075/model.safetensors.index.json b/tactic/checkpoint-3075/model.safetensors.index.json new file mode 100644 index 0000000000000000000000000000000000000000..a054aad8cf0cb5671e2e10b3bf817725a4bec031 --- /dev/null +++ b/tactic/checkpoint-3075/model.safetensors.index.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:146776fce3f6db1103aa6f249e65ee5544c5923ce6f971b092eee79aa6e5d37b +size 23950 diff --git a/tactic/checkpoint-3075/rng_state_0.pth b/tactic/checkpoint-3075/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..b4f7aff8787e77abdd3de7299719c4c21fc26258 --- /dev/null +++ b/tactic/checkpoint-3075/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee97cd82dba4d425fdd8dfdb88d4a43d0d4b1979b5c81ab4a24914fb00d4f332 +size 15024 diff --git a/tactic/checkpoint-3075/rng_state_1.pth b/tactic/checkpoint-3075/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..60e171edb0868d2d1932468dd935beea673dfb02 --- /dev/null +++ b/tactic/checkpoint-3075/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91dad95440fb85dc4a31745642117165c1a72173b2e389679ea8c0b2b6fcd7e2 +size 15024 diff --git a/tactic/checkpoint-3075/rng_state_2.pth b/tactic/checkpoint-3075/rng_state_2.pth new file mode 100644 index 0000000000000000000000000000000000000000..719d1d591f4eba9f3f0ae8eb275150361dde6d12 --- /dev/null +++ b/tactic/checkpoint-3075/rng_state_2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98698326b023c2af02c94f18726ce52c7f7a6fe290734dd7edbe99bc807fcfa0 +size 15024 diff --git a/tactic/checkpoint-3075/rng_state_3.pth b/tactic/checkpoint-3075/rng_state_3.pth new file mode 100644 index 0000000000000000000000000000000000000000..45dc07c6b18b85ced4b0a4155cac795581cc18a5 --- /dev/null +++ b/tactic/checkpoint-3075/rng_state_3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:708e7c6b5bf8a327e688779ebc08830ce249928bcb1ff5c82b1b1d0bf6d2660b +size 15024 diff --git a/tactic/checkpoint-3075/scheduler.pt b/tactic/checkpoint-3075/scheduler.pt new file mode 100644 index 0000000000000000000000000000000000000000..2f054193629d972af03d91af9dc6d83250807e87 --- /dev/null +++ b/tactic/checkpoint-3075/scheduler.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a41fb1a121ff4fa3c097da2b55a0026e1528481e35768e3cefad00910bb40a26 +size 1064 diff --git a/tactic/checkpoint-3075/special_tokens_map.json b/tactic/checkpoint-3075/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..92d7fab6729d723f45689803b99006fec945cc47 --- /dev/null +++ 
b/tactic/checkpoint-3075/special_tokens_map.json @@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b1835caa5b4d70acaa210fa222b0036f1882f9525c4660fd4810fb3e1e40ff8
+size 325
diff --git a/tactic/checkpoint-3075/tokenizer.json b/tactic/checkpoint-3075/tokenizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..9a62752e448c5a24e4fd09229f9378fa6fd092b3
--- /dev/null
+++ b/tactic/checkpoint-3075/tokenizer.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e134af98b985517b4f068e3755ae90d4e9cd2d45d328325dc503f1c6b2d06cc7
+size 9085698
diff --git a/tactic/checkpoint-3075/tokenizer_config.json b/tactic/checkpoint-3075/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..135d44d142a6294788d8f0f65143051adb44978c
--- /dev/null
+++ b/tactic/checkpoint-3075/tokenizer_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f90c3ed3cf014cbefa253ee894b17ee9e3a439014b369c8dc4f3eedcf3dc774
+size 51248
diff --git a/tactic/checkpoint-3075/trainer_state.json b/tactic/checkpoint-3075/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..db68e40fa62bcc5b56f375793536b8deb4b2e86c
--- /dev/null
+++ b/tactic/checkpoint-3075/trainer_state.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43da695933d2635aa68757e5e5327a202dc0caf3435560308295869e5964b324
+size 6661
diff --git a/tactic/checkpoint-3075/training_args.bin b/tactic/checkpoint-3075/training_args.bin
new file mode 100644
index 0000000000000000000000000000000000000000..716d22315af75ad8a5eff062fc02b96e4dbf7e0d
--- /dev/null
+++ b/tactic/checkpoint-3075/training_args.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a38c035661fde1c2b8ab0a69249e58aff2920239aa6f6fdd30935768c60f5bf
+size 7032
diff --git a/tactic/checkpoint-3075/zero_to_fp32.py b/tactic/checkpoint-3075/zero_to_fp32.py
new file mode 100755
index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8
--- /dev/null
+++ b/tactic/checkpoint-3075/zero_to_fp32.py
@@ -0,0 +1,604 @@
+#!/usr/bin/env python
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# This script extracts fp32 consolidated weights from ZeRO 1, 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application.
+#
+# example: python zero_to_fp32.py . pytorch_model.bin
+
+import argparse
+import torch
+import glob
+import math
+import os
+import re
+from collections import OrderedDict
+from dataclasses import dataclass
+
+# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
+# DeepSpeed data structures, it has to be available in the current python environment.
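+#
+# the conversion can also target a specific tag and skip frozen parameters; the flags are
+# defined by this script's own argparse options at the bottom of the file, while the paths
+# and the tag here are purely illustrative:
+#
+#   python zero_to_fp32.py path/checkpoint-12 pytorch_model.bin -t global_step1 --exclude_frozen_parameters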
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + 
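+    # load each rank's optimizer shard onto the CPU; together the shards hold the full set
+    # of fp32 master weights that the code below extracts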
+    state_dicts = []
+    for f in files:
+        state_dict = torch.load(f, map_location=device)
+        # immediately discard the two potentially huge optimizer states, as we only care about the fp32 master weights
+        # and also handle the case where it was already removed by another helper script
+        state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
+        state_dicts.append(state_dict)
+
+    if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
+        raise ValueError(f"{files[0]} is not a zero checkpoint")
+    zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
+    world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
+
+    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
+    # parameters can be different from data parallelism for non-expert parameters. So we can just
+    # use the max of the partition_count to get the dp world_size.
+
+    if type(world_size) is list:
+        world_size = max(world_size)
+
+    if world_size != total_files:
+        raise ValueError(
+            f"Expected {world_size} '*_optim_states.pt' files under '{ds_checkpoint_dir}' but found {total_files}. "
+            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+        )
+
+    # the groups are named differently in each stage
+    if zero_stage <= 2:
+        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
+    elif zero_stage == 3:
+        fp32_groups_key = FP32_FLAT_GROUPS
+    else:
+        raise ValueError(f"unknown zero stage {zero_stage}")
+
+    if zero_stage <= 2:
+        fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
+    elif zero_stage == 3:
+        # if there is more than one param group, there will be multiple flattened tensors - one
+        # flattened tensor per group - for simplicity merge them into a single tensor
+        #
+        # XXX: could make the script more memory efficient for when there are multiple groups - it
+        # will require matching the sub-lists of param_shapes for each param group flattened tensor
+
+        fp32_flat_groups = [
+            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
+        ]
+
+    return zero_stage, world_size, fp32_flat_groups
+
+
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
+    """
+    Returns fp32 state_dict reconstructed from ds checkpoint
+
+    Args:
+        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+
+    """
+    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+    optim_files = get_optim_files(ds_checkpoint_dir)
+    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+    model_files = get_model_state_files(ds_checkpoint_dir)
+
+    zero_model_states = parse_model_states(model_files)
+    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+
+    if zero_stage <= 2:
+        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                          exclude_frozen_parameters)
+    elif zero_stage == 3:
+        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                          exclude_frozen_parameters)
+
+
+def _zero2_merge_frozen_params(state_dict, zero_model_states):
+    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+        return
+
+    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
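+    # under ZeRO-2 frozen parameters are not partitioned across ranks, so the fragments
+    # saved by rank 0 are already whole tensors and are copied into the state_dict as-is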
+    frozen_param_fragments = zero_model_states[0].frozen_param_fragments
+
+    if debug:
+        num_elem = sum(s.numel() for s in frozen_param_shapes.values())
+        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+    wanted_params = len(frozen_param_shapes)
+    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+    avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
+    print(f'Frozen params: Have {avail_numel} numels to process.')
+    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+    total_params = 0
+    total_numel = 0
+    for name, shape in frozen_param_shapes.items():
+        total_params += 1
+        unpartitioned_numel = shape.numel()
+        total_numel += unpartitioned_numel
+
+        state_dict[name] = frozen_param_fragments[name]
+
+        if debug:
+            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+
+    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _has_callable(obj, fn):
+    attr = getattr(obj, fn, None)
+    return callable(attr)
+
+
+def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+    param_shapes = zero_model_states[0].param_shapes
+
+    # Reconstruction protocol:
+    #
+    #   Under ZeRO-2 each rank saves a single 1D flattened fp32 partition per
+    #   param group. To reconstruct: for every param group, concatenate the
+    #   ranks' partitions in rank order into one vector, then carve each param
+    #   out of that vector by its numel and reshape it back to its original shape.
+
+    if debug:
+        for i in range(world_size):
+            for j in range(len(fp32_flat_groups[0])):
+                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+    # XXX: memory usage doubles here (zero2)
+    num_param_groups = len(fp32_flat_groups[0])
+    merged_single_partition_of_fp32_groups = []
+    for i in range(num_param_groups):
+        merged_partitions = [sd[i] for sd in fp32_flat_groups]
+        full_single_fp32_vector = torch.cat(merged_partitions, 0)
+        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+    avail_numel = sum(
+        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+
+    if debug:
+        wanted_params = sum([len(shapes) for shapes in param_shapes])
+        wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+        # not asserting if there is a mismatch due to possible padding
+        print(f"Have {avail_numel} numels to process.")
+        print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+    # params
+    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+    # an out-of-core computing solution
+    total_numel = 0
+    total_params = 0
+    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+        offset = 0
+        avail_numel = full_single_fp32_vector.numel()
+        for name, shape in shapes.items():
+
+            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+            total_numel += unpartitioned_numel
+            total_params += 1
+
+            if debug:
+                print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+            offset += unpartitioned_numel
+
+        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+        # avail_numel can differ by anywhere between 0..2*world_size.
Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * 
world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. 
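+
+    As a rough rule of thumb (an estimate, not a guarantee): plan on about 4 bytes of free
+    CPU RAM per model parameter just to hold the consolidated fp32 weights, and transiently
+    up to about twice that while the partitions are being merged.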
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+        # do the training and checkpoint saving
+        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+        model = model.cpu() # move to cpu
+        model.load_state_dict(state_dict)
+        # submit to model hub or save the model to share with others
+
+    In this example the ``model`` will no longer be usable in the deepspeed context of the same
+    application, i.e. you will need to re-initialize the deepspeed engine, since
+    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+    """
+    if tag is None:
+        latest_path = os.path.join(checkpoint_dir, 'latest')
+        if os.path.isfile(latest_path):
+            with open(latest_path, 'r') as fd:
+                tag = fd.read().strip()
+        else:
+            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+    if not os.path.isdir(ds_checkpoint_dir):
+        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
+    """
+    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+    Args:
+        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to read the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+        - ``exclude_frozen_parameters``: exclude frozen parameters
+    """
+
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
+    print(f"Saving fp32 state dict to {output_file}")
+    torch.save(state_dict, output_file)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+    """
+    1. Move the provided model to the cpu
+    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+    3. Load it into the provided model
+
+    Args:
+        - ``model``: the model object to update
+        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to read the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+    Returns:
+        - ``model``: the modified model
+
+    Make sure you have plenty of CPU memory available before you call this function. If you don't
+    have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+    conveniently placed for you in the checkpoint folder.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+        # submit to model hub or save the model to share with others
+
+    Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context
+    of the same application.
i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters)