Feat: add checkpoint loading mechanism#146
Conversation
e8c5dd5 to
0a3deb2
Compare
| @@ -0,0 +1,1060 @@ | |||
| #include "example/common/checkpoint_loader.h" | |||
There was a problem hiding this comment.
这个文件太重了,有一千多行,checkpoint相关的基建和 llama / gpt 的 save / load 都混在一起了。要不要拆分一个example/common/checkpoint_utils.h/.cc,然后保留 gpt2 和 llama3 各自的特化调用?这个可以再讨论一下
There was a problem hiding this comment.
还是按模型拆分吧,通用的公共函数放这里,gpt2/llama3 的特化部分放 example 下模型各自文件夹里。
|
|
||
| named_shard_params_.clear(); | ||
| for (size_t i = 0; i < shard_params_.size(); ++i) { | ||
| named_shard_params_.emplace_back(shard_param_names_[i], shard_params_[i]); |
There was a problem hiding this comment.
这里named_shard_params_有没有可能出现多个相同 name
There was a problem hiding this comment.
这个name不会重复,name对应各层的参数名(模块名加index
0a089ae to
4981cd4
Compare
| @@ -0,0 +1,1060 @@ | |||
| #include "example/common/checkpoint_loader.h" | |||
There was a problem hiding this comment.
还是按模型拆分吧,通用的公共函数放这里,gpt2/llama3 的特化部分放 example 下模型各自文件夹里。
| } // namespace llama3 | ||
|
|
||
| struct ResumeFromCheckpointArgs { | ||
| fLS::clstring resume_root; |
There was a problem hiding this comment.
这里能用 std::filesystem 里的相关类型或者直接 string 吗?
| #include "infini_train/include/nn/modules/transformer/transformer_config.h" | ||
|
|
||
| namespace nn = infini_train::nn; | ||
| namespace infini_train { |
There was a problem hiding this comment.
example 下的内容不属于框架代码,不需要放到 infinitrain namespace 里。
不过看了下目前确实存在一些通用实现(tokenizer 等)应当放到框架代码里,历史遗留的部分我之后提 pr 一起挪。
|
|
||
| // precision | ||
| DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); | ||
| DEFINE_uint32(save_steps, 0, "save checkpoint every N steps; 0 disables saving"); |
There was a problem hiding this comment.
https://github.com/NVIDIA/Megatron-LM/blob/main/examples/llama/train_llama3_8b_h100_fp8.sh
这个参数在 megatron 里应该是叫 --save-interval,其余参数也都确认下,
命名和 megatron 对齐吧。
以及现在 main 里的参数有点多了,感觉之后可以考虑类似 megatron 那样把参数按不同类型分个组,先加个 TODO 记一下。
There was a problem hiding this comment.
做了如下替换
save_interval ----- save_steps
load ---- resume_from
save ---- checkpoint_dir
no_save_optim ---- save_optimizer_state // 语义是反的,默认值为false
checkpoint_format 已删除
| float best_loss = std::numeric_limits<float>::infinity(); | ||
| double last_lr = 0.0; | ||
| std::string optimizer_type = "unknown"; | ||
| std::string checkpoint_format = "bin"; |
|
|
||
| class Checkpoint { | ||
| public: | ||
| static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer &optimizer, |
There was a problem hiding this comment.
optimizer 是不是改成指针更合适,当 save_optimizer_state 设置为 false 时,理论上不应该还必须传 optimizer。
There was a problem hiding this comment.
改为指针,当存optimizer state判断指针不为空
| void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module *model, Optimizer *optimizer, | ||
| TrainerState *state, const CheckpointLoadOptions &options) { | ||
| CHECK(model != nullptr); | ||
| CHECK(state != nullptr); |
There was a problem hiding this comment.
既然不允许为空指针,那就签名直接要求传引用吧。
|
|
||
| // Create optimizer - use GetLoRAParameters if LoRA is enabled | ||
| std::vector<std::shared_ptr<Tensor>> params_to_optimize; | ||
| std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> named_params_to_optimize; |
There was a problem hiding this comment.
保留 optimizer 的 parameters 构造方法后,这段应该不需要修改了,llama3 同理。
There was a problem hiding this comment.
这里传入带name的参数就是为了存 optimizer 参数的时候保留name作为key,
|
建议补一个文档(飞书文档就行,不用放仓库里),介绍一下目前 pth 的格式与 torch/megatron ckpt 格式,解释一下现有格式与主流框架的区别、未来可能的兼容方式。 |
| const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict); | ||
|
|
||
| static std::unordered_map<std::string, std::shared_ptr<Tensor>> | ||
| LoadStateDictBinary(const std::filesystem::path &path); |
There was a problem hiding this comment.
binary 相关的 save/load 保留原有的功能函数形式就行,没必要做到 ckpt 里,ckpt 就只做 pth 的结构化存取就行。
There was a problem hiding this comment.
checkpoint 里不再支持binary格式,相关为了做format适配的flag 以及 options结构体都删除
There was a problem hiding this comment.
SaveStateDictBinary 和 LoadStateDictBinary 这两个函数目前还在 Checkpoint 里没有删除,LoadStateDictBinary 我理解就是对标的 example 目录下的 LoadFromLLMC 函数,所以没必要保留;SaveStateDictBinary 有必要的话,在 example 目录下加一个 SaveLLMC 函数用来支持 llmc 格式的保存就行。
There was a problem hiding this comment.
现在新增的格式后缀记成了 .ckpt, SaveStateDictBinary 和 LoadStateDictBinary 针对的就是 .ckpt。上次修改把LoadFromLLMC (对旧格式的读取)和 SaveLLMC 移出了 checkpoint机制,并且不再有 SaveLLMC 的调用,也在修改中删了。
| int pp_size = 1; | ||
| }; | ||
|
|
||
| struct CheckpointOptions { |
There was a problem hiding this comment.
这个可以直接删了,没必要这样兼容,ckpt 就只存取 pth。
|
|
||
| static void SaveTrainerState(const std::filesystem::path &path, const TrainerState &state); | ||
| static TrainerState LoadTrainerState(const std::filesystem::path &path); | ||
| static std::string InferFormat(const std::filesystem::path &checkpoint_dir); |
0b4857d to
8ebe12f
Compare
之前的文档有介绍现在的格式排布,待补充torch/megatron 用到的格式介绍。 |
e11cab7 to
ce292c7
Compare
c499288 to
ebaeadf
Compare
format: use clang-format-16 instead
remove redundent arguments
format files
- Use name-based optimizer state keys instead of index-based to
prevent state corruption from unordered_map traversal order
- Warn on unexpected keys when loading model state dict
- Validate parallel topology (TP/PP/SP) consistency on resume
- Add batch_idx alignment check for distributed data loader
- Default best_loss to infinity instead of zero
| const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict); | ||
|
|
||
| static std::unordered_map<std::string, std::shared_ptr<Tensor>> | ||
| LoadStateDictBinary(const std::filesystem::path &path); |
There was a problem hiding this comment.
SaveStateDictBinary 和 LoadStateDictBinary 这两个函数目前还在 Checkpoint 里没有删除,LoadStateDictBinary 我理解就是对标的 example 目录下的 LoadFromLLMC 函数,所以没必要保留;SaveStateDictBinary 有必要的话,在 example 目录下加一个 SaveLLMC 函数用来支持 llmc 格式的保存就行。
| CHECK(c.activation_type == nn::MLPType::kGELU) << "GPT-2 requires GELU activation"; | ||
| CHECK(c.norm_type == nn::NormType::kLayerNorm) << "GPT-2 requires LayerNorm"; | ||
| } | ||
|
|
| } | ||
| }; | ||
|
|
||
| collect("", this); |
There was a problem hiding this comment.
https://docs.pytorch.org/docs/2.12/generated/torch.nn.Module.html#torch.nn.Module.named_parameters
既然 collect 已经支持 prefix 了,干脆 NamedParameters 完全对齐 torch 接口吧,recurse 可以先不支持,在最开始 CHECK 一下只允许传 true。
| auto param = std::make_shared<Tensor>(std::vector<int64_t>{2, 3}, DataType::kFLOAT32, GetDevice()); | ||
| param->set_requires_grad(true); | ||
| params.push_back(param); | ||
| params.emplace_back(param); |
| auto param = std::make_shared<Tensor>(std::vector<int64_t>{2, 3}, DataType::kFLOAT32, GetDevice()); | ||
| param->set_requires_grad(true); | ||
| params.push_back(param); | ||
| params.emplace_back(param); |
There was a problem hiding this comment.
这里为什么都要改成 empace_back?
| model = std::make_shared<nn::TransformerModel>(model_config); | ||
| } | ||
| auto llmc_model = std::dynamic_pointer_cast<nn::TransformerModel>(model); | ||
| CHECK(llmc_model != nullptr) << "Failed to cast model to GPT2 for LLMC checkpoint I/O."; |
There was a problem hiding this comment.
这里的检查有必要吗?上面构造出来的就默认都是 TransformerModel 了吧,什么情况下会不是吗?
如果没必要就删了吧,llama3 同理。
8dc64b3 to
23a5e32
Compare
68b7c77 to
702a9bd
Compare
…, with plans to unify into one later.
702a9bd to
e8b1e37
Compare
1. checkpoint机制
From ArcaLunar:
Checkpoint 读取工具主要参数:
--save训练过程中的保存目录--save_interval每 N 次保存一次,设置为 0 则不保存--max_checkpoint_keep最多保留 K 个 checkpoint--no_save_optim是否保存优化器的状态--load从指定 checkpoint 目录恢复训练Checkpoint 文件可以通过从
/data/shared/....../llmc/gpt2(orllama3) 的原始模型参数训练而来,例子可见仓库中的 REPORT.md(Experiment 实际上也测试了llama3,但是命令只记录了 GPT2 训练),model.bin, optimizer.bin, trainer_state.json都可以从训练中获取.因此不在附件中提供Experiment
CUDA_VISIBLE_DEVICES=5,6,7 ./gpt2 --input_bin ../../data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin --llmc_filepath ../../data/llmc/gpt2/gpt2_124M.bin --save ../ckpt2/gpt2-noresume/ --num_iteration 100 --save_interval 20 --no_save_optim false --max_checkpoint_keep 10(以上两条训练命令同样用 llama3 也运行了)
运行 compare_loss.py,对于 llama3 模型,由于从 step 40 恢复训练,所以 step 1~40 数据缺失,而其余 60 步的 loss 在 FP32, BF16 下均吻合
对于 GPT2,模型保存的逻辑有误:训练中 lm_head 与 wte 并非真共享,而 LLMC 存取又按“共享”假设处理,resume 后 lm_head 很容易和 no resume 不一致。解决方法是把训练用 checkpoint 从 LLMC 回调路径切到原生 StateDict 二进制路径,并在加载后显式重建权重绑定语义 (
example/gpt2/main.cc).经过修复后,也可以通过.2. 训练对比
精度对比:

性能对比:
