Megatron-LM 的 ckpt format 探析
从一次 teacher checkpoint 加载报错出发,梳理 Megatron-LM 中 torch legacy 与 torch_dist checkpoint 格式的差异。
一、问题引入
今天在一台新机器用Megatron-lm跑一个蒸馏实验时,启动阶段遇到了一个bug。
FileNotFoundError: Legacy torch teacher checkpoint is missing the shard required
for the current tensor-parallel rank. Legacy teacher checkpoints must have
TP-compatible shards on disk (for example, tp1 legacy checkpoints cannot be
loaded directly into tp2).
张量并行度为2,但是导入的Teacher ckpt是tp1的,而Student的ckpt是tp1pp1版本的 而且是先加载Student Model再加载teacher model,两个都是tp1,为什么student可以成功加载,但是teacher加载就会报错呢?
回头check我的训练配置
mp_rank_00
- logs
- model_optim_rng.pt
确实是对应上了那个报错,Legacy teacher checkpoints must have TP-compatible shards on disk (for example, tp1 legacy checkpoints cannot be loaded directly into tp2).后续换成了tp2的teacher ckpt,就能成功跑通。
但是我还想知道是
- 为什么student model就能正常load而teacher不行?
- 在哪里设置了/表明了这个teacher checkpoints是”legacy”的?
- 与legacy相对应的格式是什么?
- 为什么legacy的save tp数必须与Load tp数相同而另一种不需要?
因此又作了进一步查看。
让Codex帮我查找具体的定位【关键是日志,从参数解析、模型加载、模型训练全流程都要留痕在LOG文件里面,这对跟踪训练与DEBUG非常有帮助】
Codex一番调查,得到了一些结论:
- 因为student与teacher的ckpt的格式不一样。teacher_ckpt是torch(也就是所谓的legacy),student是 torch_dist(分布式torch)。
- 有两处迹象
- 日志参数显示:
ckpt_format = torch_distuse_dist_ckpt = Trueexport_kd_teacher_ckpt_format = torch前面两个是student(主模型)的配置,最后是teacher的配置
- 很显然这些ckpt格式并不是作为超参数传入的,而是自动解析ckpt目录得到的,因此还可以通过查看ckpt里面的具体文件结构来判断
- 查看Student ckpt:
release/common.pt、release/metadata.json、release/__0_0.distcp、release/modelopt_run_config.yaml - 查看teacher ckpt目录结构:
mp_rank_00/model_optim_rng.pt、mp_rank_00/logs可以看到二者的结构确实有很大不同
- 查看Student ckpt:
- 所谓的legacy其实就是’torch’格式的ckpt(模型权重文件后缀为pt),而与之相对的就是’torch_dist’格式(模型权重文件后缀为distcp(distributed checkpoint))
- 通过搜索相关帖子得以解答
- 日志参数显示:
二、源码探析
2.1 传参层
源码位置:megatron/training/arguments.py
def _add_checkpointing_args(parser):
group = parser.add_argument_group(title='checkpointing')
### 省略其他参数
group.add_argument('--ckpt-format', default='torch_dist',
choices=['torch', 'torch_dist', 'zarr', 'torch_dcp', 'fsdp_dtensor'],
help='Checkpoint format to use. torch is the format used by torch.save/load.'
' torch_dist is a megatron built-in distributed checkpointing format.
' torch_dcp is the torch.distributed.checkpoint format.'
' fsdp_dtensor is a torch DCP native, Megatron FSDP training-specific checkpoint format.')
没有文档式的参数说明,但是源码库里面有。可以看到 ckpt_format参数一共有四种取值,第一种 torch就是所谓的 legacy torch了,而且默认值就是 torch_dist。
源码位置:gdn_pretrain_template.sh
### 省略其他参数
TRAINING_ARGS=(
--micro-batch-size ${BATCH_SIZE}
--global-batch-size ${GLOBAL_BATCH_SIZE}
--lr ${LR}
--train-samples ${TRAIN_SAMPLES}
--lr-warmup-samples ${LR_WARMUP_SAMPLES}
--lr-decay-samples ${LR_DECAY_SAMPLES}
--lr-decay-style ${LR_DECAY_STYLE}
--min-lr ${MIN_LR}
--dataloader-type cyclic
--num-workers ${GPUS_PER_NODE:-8}
--bf16
--ckpt-format torch_dist
--async-save
--ckpt-fully-parallel-load
--auto-detect-ckpt-format
--dist-ckpt-optim-fully-reshardable
)
这是megatron-lm启动训练的脚本模板,已经提前写死了ckpt-format为 torch_dist,甚至都没有像其他参数一样可以通过环境变量来读取。
源码位置:megatron/post_training/arguments.py
def add_modelopt_args(parser):
"""Add additional arguments for using TensorRT Model Optimizer (modelopt) features."""
group = parser.add_argument_group(title="modelopt-generic")
### 省略其他参数
group.add_argument(
'--export-kd-teacher-ckpt-format',
type=str,
default=None,
choices=['torch', 'torch_dist', 'zarr', 'torch_dcp'],
help="Checkpoint format of teacher model, if different from student's.",
)
由于蒸馏后训练是在ModelOPT里面支持的,所以涉及到teacher model的参数就集中放在 modelopt_args。我的运行脚本中有如下命令行:
export EXTRA_ARGS="${EXTRA_ARGS:---seed 42 --log-params-norm --no-initialization --rerun-mode disabled --export-kd-teacher-ckpt-format torch --dist-ckpt-strictness log_all}"
所以,Teacher model就采用了 torch的ckpt格式,与此前的 torch_dist不同。
2.2 实现层
为什么torch legacy格式的ckpt需要保持save与load时的tp数一致
save侧:
主要在 megatron/training/checkpointing.py里面 save_checkpoint函数的定义
checkpoint_name = get_checkpoint_name(save_dir, iteration, release=release, pipeline_parallel=pipeline_parallel,
tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=return_base_dir
) # 提前确定要保存路径
state_dict = generate_state_dict(
args,
model,
optimizer,
opt_param_scheduler,
rng_state,
iteration=iteration,
optim_sd_kwargs=dict(metadata=sharded_sd_metadata),
model_sd_kwargs=dict(metadata=sharded_sd_metadata),
rerun_state=rerun_state,
)
### 不同的chpt_type对应不同的保存方式,下面介绍legacy的保存方式
assert ckpt_type == CheckpointType.LEGACY
# Save
ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name) # 核心,只用了torch的原生save接口
下面可以看一下 checkpoint_name是如何确定的
源码位置:get_checkpoint_name的实现
def get_checkpoint_name(checkpoints_path, iteration, release=False,
pipeline_parallel=None,
tensor_rank=None, pipeline_rank=None,
expert_parallel=None, expert_rank=None,
return_base_dir=False, basename="model_optim_rng.pt"):
"""Determine the directory name for this rank's checkpoint."""
if release:
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
if return_base_dir:
common_path = os.path.join(checkpoints_path, directory)
return common_path
# Use both the tensor and pipeline MP rank.
### MPU是Megatron-Core中定义的一个并行状态管理器(Model Parallel Unit),在分布式训练初始化的时候完成mpu的初始化。
if pipeline_parallel is None: # 如果没有显式传餐是否启动PP,就自动从MPU里面查找PP维度,如果>1就说明开了流水线平行,否则就没开;EP同理
pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
if tensor_rank is None:
tensor_rank = mpu.get_tensor_model_parallel_rank()
if pipeline_rank is None:
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
if expert_parallel is None:
expert_parallel = (mpu.get_expert_model_parallel_world_size() > 1)
if expert_rank is None:
expert_rank = mpu.get_expert_model_parallel_rank()
# Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the
# data parallel rank.
# 正式开始欲保存的权重路径的命名
if not pipeline_parallel:
common_path = os.path.join(checkpoints_path, directory,
f'mp_rank_{tensor_rank:02d}')
else:
common_path = os.path.join(checkpoints_path, directory,
f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
if expert_parallel:
common_path = common_path + f'_{expert_rank:03d}'
return os.path.join(common_path, basename)
可以发现,切片的命名规则如下:
- mp_rank_0x
- mp_rank_0x_00x
- mp_rank_0x_00x_00x 分别对应
- TP Only
- TP and PP
- TP, PP and EP 可以看出Megatron 的 legacy 命名总是从 tensor_rank 开始,即使 TP=1 也有 mp_rank_00。PP/EP 是额外并行维度,开启后再追加对应 rank。
还可以看一下这里的state_dict是如何构造的:
源码位置:generate_state_dick的实现
def generate_state_dict(
args,
model,
optimizer,
opt_param_scheduler,
rng_state,
iteration=None,
optim_sd_kwargs=None,
model_sd_kwargs=None,
rerun_state=None,
):
"""Generate a state dict from given model, optimizer, scheduler, rng state and others. """
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0
if iteration is not None:
state_dict['iteration'] = iteration
for i in range(len(model)):
key = "model"
if len(model) > 1:
key = f"model{i}"
if args.ckpt_format == "torch_dist":
model_sd = model[i].sharded_state_dict(**(model_sd_kwargs or {}))
else: # torch, torch_dcp, fsdp_dtensor
model_sd = model[i].state_dict_for_save_checkpoint()
state_dict[key] = model_sd
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None and not optimizer.is_stub_optimizer:
optimizer_sd = None
if args.ckpt_format == "torch_dist":
optimizer_sd = optimizer.sharded_state_dict(state_dict, **(optim_sd_kwargs or {}))
elif args.ckpt_format == "fsdp_dtensor":
if optim_sd_kwargs is None:
optim_sd_kwargs = {}
if "metadata" not in optim_sd_kwargs:
optim_sd_kwargs["metadata"] = {}
optim_sd_kwargs['metadata'].update(_build_sharded_state_dict_metadata(args))
optimizer_sd = optimizer.sharded_state_dict(state_dict, **optim_sd_kwargs)
else:
optimizer_sd = optimizer.state_dict()
state_dict['optimizer'] = optimizer_sd
if opt_param_scheduler is not None:
state_dict['opt_param_scheduler'] = \
opt_param_scheduler.state_dict()
# Rerun state
if rerun_state:
state_dict['rerun_state_machine'] = rerun_state
# RNG states.
if not args.no_save_rng and rng_state:
state_dict["rng_state"] = rng_state
return state_dict
可以看到: 对于 legacy torch,保存的是:
model[i].state_dict_for_save_checkpoint()
这就是当前 rank 上模型参数的本地 state dict。它没有记录“这个 tensor 是全局 tensor 的第几片”,也没有记录“这个 tensor 沿哪个轴被 TP 切分”。
对于 torch_dist的格式,会在下文介绍
load 侧
legacy torch 格式的 load 入口仍然在 megatron/training/checkpointing.py。
- legacy load 会按“当前运行的 TP/PP rank”去找文件
核心代码在 _load_base_checkpoint():
elif ckpt_format == "torch":
ckpt_type = CheckpointType.LEGACY
# Handle global legacy checkpoint
if rank0:
checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
else:
checkpoint_name = get_checkpoint_name(
load_dir, iteration, release, return_base_dir=False
)
state_dict = torch.load(checkpoint_name, map_location='cpu')
这里的重点是:
checkpoint_name = get_checkpoint_name(load_dir, iteration, release, return_base_dir=False)
get_checkpoint_name() 会根据当前进程的:
mpu.get_tensor_model_parallel_rank()
mpu.get_pipeline_model_parallel_rank()
mpu.get_expert_model_parallel_rank()
来拼出当前 rank 应该读取的 checkpoint 文件。
所以如果保存时是 TP=1,那么目录里通常只有:
iter_0001000/
mp_rank_00/
model_optim_rng.pt
但如果加载时改成 TP=2,那么两个 TP rank 会分别尝试读:
iter_0001000/
mp_rank_00/model_optim_rng.pt
mp_rank_01/model_optim_rng.pt
这时 mp_rank_01 根本不存在。因此在文件布局层面,legacy torch 格式就已经强依赖 save 时的 TP/PP/EP 切分方式。
- 代码里也显式禁止 legacy TP/PP 不一致
加载后会检查 checkpoint 里的训练参数:
if 'args' in state_dict and not args.finetune and not args.override_opt_param_scheduler:
checkpoint_args = state_dict['args']
check_checkpoint_args(checkpoint_args)
check_checkpoint_args() 里有明确判断:
if get_checkpoint_version() >= 3.0 and not args.use_dist_ckpt:
_compare('tensor_model_parallel_size')
_compare('pipeline_model_parallel_size')
_compare函数内部进行了checkpoint参数与实际传入参数的断言,上述两个句子确保了ckpt的tp数&pp数与训练配置的tp数&pp数严格相等。
为什么 torch_dist 不强制 save/load TP 一致
torch_dist 的关键特点是:它保存的是带有全局切片描述的 ShardedTensor。
加载时,当前运行配置会重新生成一份 sharded request,告诉 checkpoint 系统当前 rank 需要全局 tensor 的哪一片。因此它可以支持 TP1 save、TP2 load 这类重分片场景。
ShardedTensor因为涉及到整块Tensor的分配与切片,有点复杂,也是Megatron-lm里面的一个比较重要的部分,计划在下一篇里面详细解读。