Skip to content
6 changes: 4 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,19 @@ add_executable(gpt2
example/gpt2/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/gpt2/checkpoint_loader.cc
example/common/checkpoint_loader.cc
example/common/tokenizer.cc
example/gpt2/checkpoint_loader.cc
)
link_infini_train_exe(gpt2)

add_executable(llama3
example/llama3/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/llama3/checkpoint_loader.cc
example/common/checkpoint_loader.cc
example/common/tokenizer.cc
example/llama3/checkpoint_loader.cc
)
link_infini_train_exe(llama3)

Expand Down
119 changes: 119 additions & 0 deletions example/common/checkpoint_loader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include "example/common/checkpoint_loader.h"

@chen2021673 chen2021673 May 8, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件太重了,有一千多行,checkpoint相关的基建和 llama / gpt 的 save / load 都混在一起了。要不要拆分一个example/common/checkpoint_utils.h/.cc,然后保留 gpt2 和 llama3 各自的特化调用?这个可以再讨论一下

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还是按模型拆分吧,通用的公共函数放这里,gpt2/llama3 的特化部分放 example 下模型各自文件夹里。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


#include <cmath>
#include <cstdlib>
#include <filesystem>
#include <memory>
#include <string>
#include <vector>

#include "glog/logging.h"

#include "infini_train/include/nn/modules/transformer/transformer_config.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/tensor.h"

using namespace infini_train;
namespace nn = infini_train::nn;

// TODO(jym): ckpt is a new checkpoint format; bin is the legacy format. Keeping both as an interim solution; plan to
// consolidate into one later.
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) {
ResumeFromCheckpointResult result;
if (args.resume_root.empty()) {
LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch.";
return result;
}

int ddp_world_size = nn::parallel::global::GetDataParallelSize();
int tp_world_size = nn::parallel::global::GetTensorParallelSize();
int sp_world_size = nn::parallel::global::GetSequenceParallelEnabled() ? tp_world_size : 1;
int pp_world_size = nn::parallel::global::GetPipelineParallelSize();

std::filesystem::path resume_dir = args.resume_root;
if (args.rank.IsParallel()) {
const auto rank_dir = resume_dir / std::format("rank_{:06d}", args.rank.GlobalRank());
if (std::filesystem::exists(rank_dir)) {
resume_dir = rank_dir;
}
}

Checkpoint::Load(resume_dir, *args.model, args.optimizer.get(), args.state, true);

result.global_step = static_cast<int>(args.state.global_step);

CHECK_EQ(args.state.n_layer, args.model_config.n_layer)
<< "n_layer mismatch: ckpt=" << args.state.n_layer << ", config=" << args.model_config.n_layer;
CHECK_EQ(args.state.n_head, args.model_config.n_head)
<< "n_head mismatch: ckpt=" << args.state.n_head << ", config=" << args.model_config.n_head;
CHECK_EQ(args.state.n_kv_head, args.model_config.n_kv_head)
<< "n_kv_head mismatch: ckpt=" << args.state.n_kv_head << ", config=" << args.model_config.n_kv_head;
CHECK_EQ(args.state.n_embd, args.model_config.n_embd)
<< "n_embd mismatch: ckpt=" << args.state.n_embd << ", config=" << args.model_config.n_embd;
CHECK_EQ(args.state.vocab_size, args.model_config.vocab_size)
<< "vocab_size mismatch: ckpt=" << args.state.vocab_size << ", config=" << args.model_config.vocab_size;

CHECK_EQ(args.state.ddp_size, ddp_world_size) << "DDP size mismatch: checkpoint has DDP=" << args.state.ddp_size
<< ", but current run has DDP=" << ddp_world_size;
CHECK_EQ(args.state.tp_size, tp_world_size)
<< "TP size mismatch: checkpoint has TP=" << args.state.tp_size << ", but current run has TP=" << tp_world_size;
CHECK_EQ(args.state.sp_size, sp_world_size)
<< "SP size mismatch: checkpoint has SP=" << args.state.sp_size << ", but current run has SP=" << sp_world_size;
CHECK_EQ(args.state.pp_size, pp_world_size)
<< "PP size mismatch: checkpoint has PP=" << args.state.pp_size << ", but current run has PP=" << pp_world_size;

result.consumed_batches = static_cast<size_t>(std::max<int64_t>(args.state.consumed_batches, 0));
if (args.rank.IsMainRank()) {
LOG(INFO) << std::format("Resume training from step {}, last_lr {:.3e}, consumed_batches {}",
args.state.global_step, args.state.last_lr, args.state.consumed_batches);
}

return result;
}

void SaveCheckpoint(const SaveCheckpointArgs &args) {
const auto ckpt_start = std::chrono::high_resolution_clock::now();

TrainerState state;
state.global_step = args.global_step;
state.consumed_batches = static_cast<int64_t>(args.consumed_batches);
state.last_lr = args.last_lr;
state.n_layer = args.n_layer;
state.n_head = args.n_head;
state.n_kv_head = args.n_kv_head;
state.n_embd = args.n_embd;
state.vocab_size = args.vocab_size;
state.ddp_size = args.ddp_size;
state.tp_size = args.tp_size;
state.sp_size = args.sp_size;
state.pp_size = args.pp_size;

Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state, args.no_save_optim);

const auto ckpt_end = std::chrono::high_resolution_clock::now();
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();

if (!args.rank.IsMainRank()) {
return;
}

LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms);

if (!args.prune_step_checkpoints) {
return;
}

std::vector<std::filesystem::path> ckpts;
if (std::filesystem::exists(args.checkpoint_root_dir)) {
for (const auto &entry : std::filesystem::directory_iterator(args.checkpoint_root_dir)) {
if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) {
ckpts.push_back(entry.path());
}
}
std::sort(ckpts.begin(), ckpts.end());
while (ckpts.size() > args.max_checkpoint_keep) {
std::filesystem::remove_all(ckpts.front());
ckpts.erase(ckpts.begin());
}
}
}
59 changes: 59 additions & 0 deletions example/common/checkpoint_loader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#pragma once

#include <cstdint>
#include <cstring>
#include <filesystem>

#include "infini_train/include/checkpoint.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/rank.h"
#include "infini_train/include/optimizer.h"

using namespace infini_train;
namespace nn = infini_train::nn;

namespace infini_train::nn {
class TransformerConfig;
}

struct ResumeFromCheckpointArgs {
std::filesystem::path resume_root;
const nn::parallel::Rank &rank;
std::shared_ptr<nn::Module> model;
std::shared_ptr<Optimizer> optimizer;
const nn::TransformerConfig &model_config;
TrainerState &state;
};

struct ResumeFromCheckpointResult {
int global_step = 0;
size_t consumed_batches = 0;
};

struct SaveCheckpointArgs {
std::filesystem::path save_dir;
int64_t global_step = 0;
size_t consumed_batches = 0;
double last_lr = 0.0;
int64_t n_layer = 0;
int64_t n_head = 0;
int64_t n_kv_head = 0;
int64_t n_embd = 0;
int64_t vocab_size = 0;
int ddp_size = 1;
int tp_size = 1;
int sp_size = 1;
int pp_size = 1;
bool no_save_optim = false;
bool prune_step_checkpoints = false;
std::filesystem::path checkpoint_root_dir;
size_t max_checkpoint_keep = 0;
const nn::parallel::Rank &rank;
const nn::Module &model;
const Optimizer &optimizer;
};

ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);

void SaveCheckpoint(const SaveCheckpointArgs &args);
5 changes: 3 additions & 2 deletions example/gpt2/checkpoint_loader.cc
Comment thread
kilinchange marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

#include "glog/logging.h"

#include "example/common/utils.h"
#include "example/gpt2/config.h"
#include "infini_train/include/nn/modules/normalization.h"
#include "infini_train/include/nn/modules/sparse.h"
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
Expand All @@ -24,6 +22,9 @@
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/tensor.h"

#include "example/common/utils.h"
#include "example/gpt2/config.h"

using namespace infini_train;
namespace nn = infini_train::nn;

Expand Down
75 changes: 74 additions & 1 deletion example/gpt2/main.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <format>
#include <memory>
#include <optional>
Expand All @@ -10,6 +11,7 @@
#include "glog/logging.h"

#include "infini_train/include/autocast.h"
#include "infini_train/include/checkpoint.h"
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
Expand All @@ -34,11 +36,13 @@
#include "infini_train/include/utils/precision_check_config.h"
#include "infini_train/include/utils/precision_checker.h"

#include "example/common/checkpoint_loader.h"
#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
#include "example/gpt2/checkpoint_loader.h"
#include "example/gpt2/config.h"

// TODO(jym): Reorganize CLI flags into categories for better readability and maintainability.
// I/O
DEFINE_string(input_bin, "", "input .bin to train on");
DEFINE_string(input_val_bin, "", "input .bin to eval validation loss on");
Expand Down Expand Up @@ -77,6 +81,11 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
DEFINE_uint32(save_interval, 0, "save checkpoint every N steps; 0 disables saving");
DEFINE_string(load, "", "checkpoint directory to resume from");
DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints");
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
DEFINE_bool(no_save_optim, false, "whether optimizer state is persisted in checkpoints");
// precision check
DEFINE_string(
precision_check, "",
Expand Down Expand Up @@ -315,9 +324,56 @@ void Train(const nn::parallel::Rank &rank) {

auto impl = core::GetDeviceGuardImpl(device.type());

int start_step = 0;
TrainerState state;
const auto resume_result = ResumeFromCheckpoint({.resume_root = FLAGS_load,
.rank = rank,
.model = model,
.optimizer = optimizer,
.model_config = model_config,
.state = state});
start_step = resume_result.global_step;
size_t consumed_batches = resume_result.consumed_batches;

// TODO(jym): Replace with Sampler abstraction when available.
// Skip dataloader to resume from the correct batch position.
if (consumed_batches > 0) {
size_t start = train_iter.BatchIndex();
// Each rank processes every ddp_world_size-th batch starting from its own rank.
// num_skips calculates how many ++ iterations to reach the saved batch position.
size_t num_skips = (consumed_batches - start) / ddp_world_size;
for (size_t i = 0; i < num_skips; ++i) { ++train_iter; }
}

auto save_checkpoint
= [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
SaveCheckpoint({
.save_dir = save_dir,
.global_step = global_step,
.consumed_batches = consumed_batches,
.last_lr = FLAGS_learning_rate,
.n_layer = model_config.n_layer,
.n_head = model_config.n_head,
.n_kv_head = model_config.n_kv_head,
.n_embd = model_config.n_embd,
.vocab_size = model_config.vocab_size,
.ddp_size = ddp_world_size,
.tp_size = tp_world_size,
.sp_size = sp_world_size,
.pp_size = pp_world_size,
.no_save_optim = FLAGS_no_save_optim,
.prune_step_checkpoints = prune_step_checkpoints,
.checkpoint_root_dir = FLAGS_save,
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
.rank = rank,
.model = *model,
.optimizer = *optimizer,
});
};

LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {
// Reset precision check counters at start of each iteration for file overwrite
utils::PrecisionChecker::ResetCounters();

Expand Down Expand Up @@ -367,6 +423,7 @@ void Train(const nn::parallel::Rank &rank) {
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
consumed_batches = train_iter.BatchIndex();
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

Expand Down Expand Up @@ -397,6 +454,7 @@ void Train(const nn::parallel::Rank &rank) {
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
consumed_batches = train_iter.BatchIndex();
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

Expand Down Expand Up @@ -431,6 +489,15 @@ void Train(const nn::parallel::Rank &rank) {
}
}
}

if (FLAGS_save_interval > 0 && (step + 1) % FLAGS_save_interval == 0) {
std::filesystem::path step_dir
= std::filesystem::path(FLAGS_save) / std::format("checkpoint_step_{:06d}", step + 1);
if (rank.IsParallel()) {
step_dir /= std::format("rank_{:06d}", rank.GlobalRank());
}
save_checkpoint(step_dir, step + 1, true);
}
}

// Save LoRA weights if enabled and path specified
Expand All @@ -439,6 +506,12 @@ void Train(const nn::parallel::Rank &rank) {
nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path);
}

std::filesystem::path final_dir = std::filesystem::path(FLAGS_save) / "checkpoint_final";
if (rank.IsParallel()) {
final_dir /= std::format("rank_{:06d}", rank.GlobalRank());
}
save_checkpoint(final_dir, FLAGS_num_iteration, false);

#ifdef PROFILE_MODE
Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage);
Profiler::Instance().PrintRecords("gpt2.records.log");
Expand Down
5 changes: 3 additions & 2 deletions example/llama3/checkpoint_loader.cc
Comment thread
kilinchange marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

#include "glog/logging.h"

#include "example/common/utils.h"
#include "example/llama3/config.h"
#include "infini_train/include/nn/modules/normalization.h"
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
#include "infini_train/include/nn/modules/transformer/mlp.h"
Expand All @@ -22,6 +20,9 @@
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/tensor.h"

#include "example/common/utils.h"
#include "example/llama3/config.h"

using namespace infini_train;
namespace nn = infini_train::nn;

Expand Down
Loading
Loading