From a829a4fffc46097f4fa6423deb2624341830b02c Mon Sep 17 00:00:00 2001 From: kilinchange Date: Mon, 11 May 2026 11:32:03 +0000 Subject: [PATCH 1/9] feat: implement MoE infrastructure --- infini_train/include/autograd/moe.h | 26 + .../nn/modules/transformer/moe/experts.h | 25 + .../nn/modules/transformer/moe/moe_layer.h | 25 + .../nn/modules/transformer/moe/moe_utils.h | 9 + .../nn/modules/transformer/moe/router.h | 25 + .../modules/transformer/transformer_config.h | 33 + infini_train/src/autograd/moe.cc | 31 + infini_train/src/kernels/cpu/top1_mask.cc | 67 ++ infini_train/src/kernels/cuda/top1_mask.cu | 107 ++++ .../src/nn/modules/transformer/moe/experts.cc | 50 ++ .../nn/modules/transformer/moe/moe_layer.cc | 32 + .../nn/modules/transformer/moe/moe_utils.cc | 12 + .../src/nn/modules/transformer/moe/router.cc | 50 ++ .../src/nn/modules/transformer/transformer.cc | 7 +- .../test_transformer_architecture.cc | 600 ++++++++++++++++++ 15 files changed, 1098 insertions(+), 1 deletion(-) create mode 100644 infini_train/include/autograd/moe.h create mode 100644 infini_train/include/nn/modules/transformer/moe/experts.h create mode 100644 infini_train/include/nn/modules/transformer/moe/moe_layer.h create mode 100644 infini_train/include/nn/modules/transformer/moe/moe_utils.h create mode 100644 infini_train/include/nn/modules/transformer/moe/router.h create mode 100644 infini_train/src/autograd/moe.cc create mode 100644 infini_train/src/kernels/cpu/top1_mask.cc create mode 100644 infini_train/src/kernels/cuda/top1_mask.cu create mode 100644 infini_train/src/nn/modules/transformer/moe/experts.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/moe_layer.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/moe_utils.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/router.cc create mode 100644 test/transformer/test_transformer_architecture.cc diff --git a/infini_train/include/autograd/moe.h b/infini_train/include/autograd/moe.h new file mode 100644 index 00000000..5317de8e --- /dev/null +++ b/infini_train/include/autograd/moe.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class Top1Mask : public Function { +public: + static constexpr char kType[] = "Top1MaskFunction"; + + Top1Mask() : Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/transformer/moe/experts.h b/infini_train/include/nn/modules/transformer/moe/experts.h new file mode 100644 index 00000000..a3dda7f0 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/experts.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class SequentialMLP : public CloneableModule { +public: + static constexpr char kType[] = "SequentialMLP"; + static constexpr char kExpertNamePrefix[] = "expert_"; + + explicit SequentialMLP(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; + int64_t num_local_experts_ = 0; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_layer.h b/infini_train/include/nn/modules/transformer/moe/moe_layer.h new file mode 100644 index 00000000..e5fdb3ab --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_layer.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class MoELayer : public CloneableModule { +public: + static constexpr char kType[] = "MoELayer"; + static constexpr char kRouterLayerName[] = "router"; + static constexpr char kExpertsLayerName[] = "experts"; + + explicit MoELayer(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_utils.h b/infini_train/include/nn/modules/transformer/moe/moe_utils.h new file mode 100644 index 00000000..e0dd3744 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_utils.h @@ -0,0 +1,9 @@ +#pragma once + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config); + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/router.h b/infini_train/include/nn/modules/transformer/moe/router.h new file mode 100644 index 00000000..1279c217 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/router.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class TopKRouter : public CloneableModule { +public: + static constexpr char kType[] = "TopKRouter"; + static constexpr char kParamWeightName[] = "weight"; + static constexpr char kParamBiasName[] = "bias"; + + explicit TopKRouter(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 448e7b30..bf1a00fc 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -15,11 +15,42 @@ enum class MLPType { kSwiGLU // SwiGLU activation }; +enum class FFNType { + kDense, // Standard dense MLP + kMoE // Mixture-of-Experts MLP +}; + enum class NormType { kLayerNorm, // LayerNorm kRMSNorm // RMSNorm }; +enum class MoERouterType { + kTopK // Top-k router. The initial implementation supports top-1. +}; + +enum class MoEDispatcherType { + kLocal, // No cross-rank token exchange + kAllGather // Reserved for expert parallel MoE +}; + +enum class MoEExpertImpl { + kSequential // Run local experts sequentially +}; + +struct MoEConfig { + int64_t num_experts = 0; + int64_t expert_parallel_size = 1; + int64_t router_topk = 1; + float aux_loss_coeff = 0.0f; + std::optional expert_capacity_factor = std::nullopt; + bool pad_expert_input_to_capacity = false; + int64_t moe_ffn_hidden_size = 0; + MoERouterType router_type = MoERouterType::kTopK; + MoEDispatcherType dispatcher_type = MoEDispatcherType::kLocal; + MoEExpertImpl expert_impl = MoEExpertImpl::kSequential; +}; + struct TransformerConfig { int64_t block_size = 1024; // Max seq_len int64_t vocab_size = 50304; // Vocab size @@ -31,6 +62,7 @@ struct TransformerConfig { AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type MLPType activation_type = MLPType::kGELU; // MLP activation type + FFNType ffn_type = FFNType::kDense; // Feed-forward module type NormType norm_type = NormType::kLayerNorm; // Normalization type bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block, @@ -43,6 +75,7 @@ struct TransformerConfig { float ffn_expansion_ratio = 4.0f; // MLP output: n_embd * ffn_expansion_ratio std::optional ffn_dim_multiplier = 1.5f; // FFN dim multiplier int64_t multiple_of = 256; // FFN dims must be multiple of this number + std::optional moe_config = std::nullopt; // RoPE config float rope_theta = 500000.0f; // theta in RoPE diff --git a/infini_train/src/autograd/moe.cc b/infini_train/src/autograd/moe.cc new file mode 100644 index 00000000..05134e82 --- /dev/null +++ b/infini_train/src/autograd/moe.cc @@ -0,0 +1,31 @@ +#include "infini_train/include/autograd/moe.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> Top1Mask::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "Top1MaskForward"}, input)}; +} + +void Top1Mask::SetupContext(const std::vector> &, + const std::vector> &output_tensors) { + saved_tensors_ = {output_tensors[0]}; +} + +std::vector> Top1Mask::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &mask_values = saved_tensors_[0]; + auto device = grad_output->GetDevice().type(); + return { + Dispatcher::Instance().Call>({device, "Top1MaskBackward"}, grad_output, mask_values)}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/top1_mask.cc b/infini_train/src/kernels/cpu/top1_mask.cc new file mode 100644 index 00000000..d6ae91d6 --- /dev/null +++ b/infini_train/src/kernels/cpu/top1_mask.cc @@ -0,0 +1,67 @@ +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskForward currently supports float32 only"; + CHECK_GE(input->Dims().size(), 1); + + const auto &dims = input->Dims(); + const int64_t num_experts = dims.back(); + CHECK_GT(num_experts, 0); + const int64_t rows = input->NumElements() / num_experts; + + auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); + output->Fill(0.0f); + + const float *in = static_cast(input->DataPtr()); + float *out = static_cast(output->DataPtr()); + for (int64_t row = 0; row < rows; ++row) { + int64_t best_idx = 0; + float best_value = in[row * num_experts]; + for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { + const float value = in[row * num_experts + expert_idx]; + if (value > best_value) { + best_value = value; + best_idx = expert_idx; + } + } + out[row * num_experts + best_idx] = best_value; + } + + return output; +} + +std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &mask_values) { + CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskBackward currently supports float32 only"; + CHECK(mask_values->Dtype() == DataType::kFLOAT32); + CHECK(grad_output->Dims() == mask_values->Dims()); + + auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + grad_input->Fill(0.0f); + + const float *grad = static_cast(grad_output->DataPtr()); + const float *mask = static_cast(mask_values->DataPtr()); + float *out = static_cast(grad_input->DataPtr()); + for (int64_t i = 0; i < static_cast(grad_output->NumElements()); ++i) { + out[i] = mask[i] != 0.0f ? grad[i] : 0.0f; + } + + return grad_input; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_TOP1_MASK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskForward) +REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskBackward) + +#undef REGISTER_CPU_TOP1_MASK_KERNEL diff --git a/infini_train/src/kernels/cuda/top1_mask.cu b/infini_train/src/kernels/cuda/top1_mask.cu new file mode 100644 index 00000000..8fd00c91 --- /dev/null +++ b/infini_train/src/kernels/cuda/top1_mask.cu @@ -0,0 +1,107 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void Top1MaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, + int64_t num_experts) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t offset = row * num_experts; + int64_t best_idx = 0; + float best_value = static_cast(input[offset]); + for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { + const float value = static_cast(input[offset + expert_idx]); + if (value > best_value) { + best_value = value; + best_idx = expert_idx; + } + } + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + output[offset + expert_idx] = expert_idx == best_idx ? input[offset + expert_idx] : T(0.0f); + } +} + +std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { + CHECK_GE(input->Dims().size(), 1); + const auto &dims = input->Dims(); + const int64_t num_experts = dims.back(); + CHECK_GT(num_experts, 0); + const int64_t rows = input->NumElements() / num_experts; + + auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); + + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + input->Dtype(), + [=]() { + Top1MaskForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts); + }, + "CUDA Top1MaskForward"); + + return output; +} + +template +__global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, + T *__restrict__ grad_input, int64_t total_elements) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_elements) { + return; + } + grad_input[idx] = static_cast(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f); +} + +std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &mask_values) { + CHECK(grad_output->Dims() == mask_values->Dims()); + CHECK(grad_output->Dtype() == mask_values->Dtype()); + auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + + auto device = grad_output->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int64_t total_elements = grad_output->NumElements(); + const int threads = 256; + const int blocks = static_cast((total_elements + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + grad_output->Dtype(), + [=]() { + Top1MaskBackwardKernel<<>>( + static_cast(grad_output->DataPtr()), static_cast(mask_values->DataPtr()), + static_cast(grad_input->DataPtr()), total_elements); + }, + "CUDA Top1MaskBackward"); + + return grad_input; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_TOP1_MASK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskForward) +REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskBackward) + +#undef REGISTER_CUDA_TOP1_MASK_KERNEL diff --git a/infini_train/src/nn/modules/transformer/moe/experts.cc b/infini_train/src/nn/modules/transformer/moe/experts.cc new file mode 100644 index 00000000..8f3b1be8 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/experts.cc @@ -0,0 +1,50 @@ +#include "infini_train/include/nn/modules/transformer/moe/experts.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(moe_config.expert_impl == MoEExpertImpl::kSequential); + CHECK_EQ(moe_config.expert_parallel_size, 1) + << "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only"; + CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) + << "Current InfiniTrain MoE implementation supports local dispatch only"; + + num_local_experts_ = moe_config.num_experts; + CHECK_GT(num_local_experts_, 0); + + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + modules_[std::string(kExpertNamePrefix) + std::to_string(expert_idx)] = std::make_shared(config_); + } +} + +std::vector> SequentialMLP::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + auto hidden_states = input_tensors[0]; + auto routing_probs = input_tensors[1]; + CHECK_EQ(routing_probs->Dims().back(), num_local_experts_); + + std::shared_ptr output = nullptr; + const int64_t expert_dim = static_cast(routing_probs->Dims().size()) - 1; + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + auto expert_name = std::string(kExpertNamePrefix) + std::to_string(expert_idx); + auto expert_output = (*modules_.at(expert_name))({hidden_states})[0]; + auto expert_prob = routing_probs->Slice(expert_dim, expert_idx, expert_idx + 1); + auto weighted_output = expert_output * expert_prob; + output = output == nullptr ? weighted_output : output + weighted_output; + } + + return {output}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc new file mode 100644 index 00000000..8efd51c0 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc @@ -0,0 +1,32 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/moe/experts.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/moe/router.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(config_.ffn_type == FFNType::kMoE); + CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) + << "Current InfiniTrain MoE implementation supports local dispatch only"; + + modules_[kRouterLayerName] = std::make_shared(config_); + modules_[kExpertsLayerName] = std::make_shared(config_); +} + +std::vector> MoELayer::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + auto hidden_states = input_tensors[0]; + auto routing_probs = (*modules_.at(kRouterLayerName))({hidden_states})[0]; + return (*modules_.at(kExpertsLayerName))({hidden_states, routing_probs}); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc new file mode 100644 index 00000000..80ef01c1 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc @@ -0,0 +1,12 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" + +#include "glog/logging.h" + +namespace infini_train::nn::moe { + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config) { + CHECK(config.moe_config.has_value()) << "MoE layer requires TransformerConfig::moe_config"; + return config.moe_config.value(); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc new file mode 100644 index 00000000..59dec209 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -0,0 +1,50 @@ +#include "infini_train/include/nn/modules/transformer/moe/router.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/linear.h" +#include "infini_train/include/autograd/moe.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(moe_config.router_type == MoERouterType::kTopK); + CHECK_EQ(moe_config.router_topk, 1) << "Current InfiniTrain MoE implementation supports top-1 routing only"; + CHECK_GT(moe_config.num_experts, 0); + + parameters_[kParamWeightName] + = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, + device_) + ->RequiresGrad(); + init::KaimingUniform(parameters_[kParamWeightName]); + + if (config_.add_bias_linear) { + parameters_[kParamBiasName] + = std::make_shared(std::vector{moe_config.num_experts}, DataType::kFLOAT32, device_) + ->RequiresGrad(); + parameters_[kParamBiasName]->Fill(0.0f); + } +} + +std::vector> TopKRouter::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + std::vector> linear_inputs{input_tensors[0], parameters_.at(kParamWeightName)}; + if (parameters_.contains(kParamBiasName)) { + linear_inputs.push_back(parameters_.at(kParamBiasName)); + } + + auto logits = std::make_shared()->Apply(linear_inputs)[0]; + auto scores = function::Softmax(logits, -1); + auto routing_probs = std::make_shared()->Apply({scores})[0]; + return {routing_probs}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..bdcde449 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" #include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" @@ -86,7 +87,11 @@ TransformerLayer::TransformerLayer(const nn::TransformerConfig &config) : Clonea } modules_[kAttnLayerName] = std::make_shared(config); - modules_[kMlpLayerName] = std::make_shared(config); + if (config.ffn_type == FFNType::kMoE) { + modules_[kMlpLayerName] = std::make_shared(config); + } else { + modules_[kMlpLayerName] = std::make_shared(config); + } } std::vector> TransformerLayer::Forward(const std::vector> &x) { diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc new file mode 100644 index 00000000..da3dd70e --- /dev/null +++ b/test/transformer/test_transformer_architecture.cc @@ -0,0 +1,600 @@ +#include +#include +#include + +#include "glog/logging.h" + +#include "example/gpt2/config.h" +#include "example/llama3/config.h" +#include "infini_train/include/nn/modules/activations.h" +#include "infini_train/include/nn/modules/normalization.h" +#include "infini_train/include/nn/modules/sparse.h" +#include "infini_train/include/nn/modules/transformer/causal_self_attention.h" +#include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" +#include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/transformer/utils.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/tensor.h" + +using namespace infini_train; +namespace nn = infini_train::nn; + +// ============================================================================ +// Test 1: TransformerConfig Validation +// ============================================================================ +void TestConfigValidation() { + std::cout << "\n=== Test 1: TransformerConfig Validation ===" << std::endl; + + bool all_passed = true; + + // Test GPT2 config + auto gpt2_config = gpt2::GPT2Config(); + if (gpt2_config.attention_type != nn::AttentionType::kStandard) { + std::cout << "FAIL: GPT2 config should use Standard attention" << std::endl; + all_passed = false; + } + if (gpt2_config.activation_type != nn::MLPType::kGELU) { + std::cout << "FAIL: GPT2 config should use GELU activation" << std::endl; + all_passed = false; + } + if (gpt2_config.norm_type != nn::NormType::kLayerNorm) { + std::cout << "FAIL: GPT2 config should use LayerNorm" << std::endl; + all_passed = false; + } + if (!gpt2_config.add_bias_linear) { + std::cout << "FAIL: GPT2 config should have bias enabled" << std::endl; + all_passed = false; + } + if (!gpt2_config.tie_weights) { + std::cout << "FAIL: GPT2 config should have tied weights" << std::endl; + all_passed = false; + } + + // Test LLaMA3 config + auto llama3_config = llama3::LLaMA3Config(); + if (llama3_config.attention_type != nn::AttentionType::kRoPE) { + std::cout << "FAIL: LLaMA3 config should use RoPE attention" << std::endl; + all_passed = false; + } + if (llama3_config.activation_type != nn::MLPType::kSwiGLU) { + std::cout << "FAIL: LLaMA3 config should use SwiGLU activation" << std::endl; + all_passed = false; + } + if (llama3_config.norm_type != nn::NormType::kRMSNorm) { + std::cout << "FAIL: LLaMA3 config should use RMSNorm" << std::endl; + all_passed = false; + } + if (llama3_config.add_bias_linear) { + std::cout << "FAIL: LLaMA3 config should have bias disabled" << std::endl; + all_passed = false; + } + if (llama3_config.tie_weights) { + std::cout << "FAIL: LLaMA3 config should not have tied weights" << std::endl; + all_passed = false; + } + + // Test GQA detection + if (!llama3_config.UseGQA()) { + std::cout << "FAIL: LLaMA3 config should detect GQA (n_kv_head < n_head)" << std::endl; + all_passed = false; + } + if (gpt2_config.UseGQA()) { + std::cout << "FAIL: GPT2 config should not detect GQA (n_kv_head == n_head)" << std::endl; + all_passed = false; + } + + if (all_passed) { + std::cout << "SUCCESS: All config validations passed!" << std::endl; + } +} + +// ============================================================================ +// Test 2: Embedding Layer +// ============================================================================ +void TestEmbedding() { + std::cout << "\n=== Test 2: Embedding Layer ===" << std::endl; + + const int64_t vocab_size = 1000; + const int64_t embedding_dim = 128; + const int64_t batch_size = 2; + const int64_t seq_len = 16; + + try { + auto embedding = std::make_shared(vocab_size, embedding_dim); + + // Check parameters + auto params = embedding->Parameters(); + if (params.size() != 1) { + std::cout << "FAIL: Embedding should have 1 parameter, got " << params.size() << std::endl; + return; + } + + // Check weight shape + auto weight = embedding->parameter(nn::Embedding::kParamWeightName); + if (weight->Dims() != std::vector{vocab_size, embedding_dim}) { + std::cout << "FAIL: Embedding weight shape mismatch" << std::endl; + return; + } + + // Forward pass + auto input = std::make_shared(std::vector{batch_size, seq_len}, DataType::kINT64); + auto output = (*embedding)({input}); + + if (output.size() != 1) { + std::cout << "FAIL: Embedding forward should return 1 tensor" << std::endl; + return; + } + + const auto &out_dims = output[0]->Dims(); + if (out_dims != std::vector{batch_size, seq_len, embedding_dim}) { + std::cout << "FAIL: Embedding output shape mismatch. Expected [" << batch_size << ", " << seq_len << ", " + << embedding_dim << "], got [" << out_dims[0] << ", " << out_dims[1] << ", " << out_dims[2] << "]" + << std::endl; + return; + } + + std::cout << "SUCCESS: Embedding layer works correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 3: Normalization Layers (LayerNorm vs RMSNorm) +// ============================================================================ +void TestNormalization() { + std::cout << "\n=== Test 3: Normalization Layers ===" << std::endl; + + const int64_t hidden_size = 64; + const int64_t batch_size = 2; + const int64_t seq_len = 8; + + try { + // Test LayerNorm + auto layernorm = std::make_shared(std::vector{hidden_size}); + auto ln_params = layernorm->Parameters(); + if (ln_params.size() != 2) { + std::cout << "FAIL: LayerNorm should have 2 parameters (weight, bias), got " << ln_params.size() + << std::endl; + return; + } + + // Test RMSNorm + auto rmsnorm = std::make_shared(hidden_size); + auto rms_params = rmsnorm->Parameters(); + if (rms_params.size() != 1) { + std::cout << "FAIL: RMSNorm should have 1 parameter (weight), got " << rms_params.size() << std::endl; + return; + } + + // Forward pass for both + auto input + = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); + + auto ln_output = (*layernorm)({input}); + auto rms_output = (*rmsnorm)({input}); + + if (ln_output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: LayerNorm output shape mismatch" << std::endl; + return; + } + + if (rms_output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: RMSNorm output shape mismatch" << std::endl; + return; + } + + std::cout << "SUCCESS: Normalization layers work correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 4: MLP Layer (GELU vs SwiGLU) +// ============================================================================ +void TestMlp() { + std::cout << "\n=== Test 4: MLP Layer ===" << std::endl; + + const int64_t hidden_size = 64; + const int64_t batch_size = 2; + const int64_t seq_len = 8; + + try { + // Test GPT2-style MLP (GELU) + nn::TransformerConfig gpt2_mlp_config; + gpt2_mlp_config.n_embd = hidden_size; + gpt2_mlp_config.activation_type = nn::MLPType::kGELU; + gpt2_mlp_config.ffn_expansion_ratio = 4.0f; + gpt2_mlp_config.add_bias_linear = true; + + auto gpt2_mlp = std::make_shared(gpt2_mlp_config); + auto gpt2_params = gpt2_mlp->Parameters(); + + // GPT2 MLP should have: c_fc.weight, c_fc.bias, c_proj.weight, c_proj.bias + if (gpt2_params.size() != 4) { + std::cout << "FAIL: GPT2 MLP should have 4 parameters, got " << gpt2_params.size() << std::endl; + return; + } + + // Test LLaMA3-style MLP (SwiGLU) + nn::TransformerConfig llama3_mlp_config; + llama3_mlp_config.n_embd = hidden_size; + llama3_mlp_config.activation_type = nn::MLPType::kSwiGLU; + llama3_mlp_config.ffn_expansion_ratio = 4.0f; + llama3_mlp_config.add_bias_linear = false; + llama3_mlp_config.ffn_dim_multiplier = 1.5f; + llama3_mlp_config.multiple_of = 256; + + auto llama3_mlp = std::make_shared(llama3_mlp_config); + auto llama3_params = llama3_mlp->Parameters(); + + // LLaMA3 MLP should have: c_fc.weight, c_fc2.weight, c_proj.weight (no bias) + if (llama3_params.size() != 3) { + std::cout << "FAIL: LLaMA3 MLP should have 3 parameters, got " << llama3_params.size() << std::endl; + return; + } + + // Forward pass + auto input + = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); + + auto gpt2_output = (*gpt2_mlp)({input}); + auto llama3_output = (*llama3_mlp)({input}); + + // Output should have same hidden dimension + if (gpt2_output[0]->Dims()[2] != hidden_size) { + std::cout << "FAIL: GPT2 MLP output hidden dim mismatch" << std::endl; + return; + } + + if (llama3_output[0]->Dims()[2] != hidden_size) { + std::cout << "FAIL: LLaMA3 MLP output hidden dim mismatch" << std::endl; + return; + } + + std::cout << "SUCCESS: MLP layers work correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 5: CausalSelfAttention +// ============================================================================ +void TestAttention() { + std::cout << "\n=== Test 5: CausalSelfAttention ===" << std::endl; + + const int64_t hidden_size = 64; + const int64_t batch_size = 2; + const int64_t seq_len = 8; + const int64_t n_head = 4; + + try { + // Test standard attention (GPT2-style) + nn::TransformerConfig standard_config; + standard_config.n_embd = hidden_size; + standard_config.n_head = n_head; + standard_config.n_kv_head = n_head; + standard_config.attention_type = nn::AttentionType::kStandard; + standard_config.add_bias_linear = true; + + auto standard_attn = std::make_shared(standard_config); + auto standard_params = standard_attn->Parameters(); + + // Should have c_attn (QKV combined) and c_proj with biases + if (standard_params.size() != 4) { + std::cout << "FAIL: Standard attention should have 4 parameters, got " << standard_params.size() + << std::endl; + return; + } + + // Test RoPE attention with GQA (LLaMA3-style) + nn::TransformerConfig rope_config; + rope_config.n_embd = hidden_size; + rope_config.n_head = n_head; + rope_config.n_kv_head = 2; // GQA: fewer KV heads + rope_config.attention_type = nn::AttentionType::kRoPE; + rope_config.add_bias_linear = false; + + auto rope_attn = std::make_shared(rope_config); + auto rope_params = rope_attn->Parameters(); + + // RoPE attention without bias should have fewer params + if (rope_params.empty()) { + std::cout << "FAIL: RoPE attention should have parameters" << std::endl; + return; + } + + // Forward pass + auto input + = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); + + auto standard_output = (*standard_attn)({input}); + if (standard_output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: Standard attention output shape mismatch" << std::endl; + return; + } + + std::cout << "SUCCESS: CausalSelfAttention works correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 6: TransformerLayer +// ============================================================================ +void TestTransformerLayer() { + std::cout << "\n=== Test 6: TransformerLayer ===" << std::endl; + + const int64_t hidden_size = 64; + const int64_t batch_size = 2; + const int64_t seq_len = 8; + + try { + // Test GPT2-style layer + auto gpt2_config = gpt2::GPT2Config(); + gpt2_config.n_embd = hidden_size; + gpt2_config.n_head = 4; + gpt2_config.n_layer = 1; + + auto gpt2_layer = std::make_shared(gpt2_config); + auto gpt2_params = gpt2_layer->Parameters(); + + if (gpt2_params.empty()) { + std::cout << "FAIL: GPT2 TransformerLayer should have parameters" << std::endl; + return; + } + + // Forward pass + auto input + = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); + + auto output = (*gpt2_layer)({input}); + if (output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: TransformerLayer output shape mismatch" << std::endl; + return; + } + + // Test LLaMA3-style layer + auto llama3_config = llama3::LLaMA3Config(); + llama3_config.n_embd = hidden_size; + llama3_config.n_head = 4; + llama3_config.n_kv_head = 2; + llama3_config.n_layer = 1; + + auto llama3_layer = std::make_shared(llama3_config); + auto llama3_params = llama3_layer->Parameters(); + + if (llama3_params.empty()) { + std::cout << "FAIL: LLaMA3 TransformerLayer should have parameters" << std::endl; + return; + } + + std::cout << "SUCCESS: TransformerLayer works correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 7: TransformerModel Instantiation (GPT2) +// ============================================================================ +void TestGpt2Model() { + std::cout << "\n=== Test 7: GPT2 Model Instantiation ===" << std::endl; + + auto config = gpt2::GPT2Config(); + // Use smaller config for faster testing + config.n_layer = 2; + config.n_head = 4; + config.n_embd = 64; + + try { + auto model = std::make_shared(config); + + if (model == nullptr) { + std::cout << "FAIL: Failed to create GPT2 model" << std::endl; + return; + } + + auto params = model->Parameters(); + if (params.empty()) { + std::cout << "FAIL: GPT2 model has no parameters" << std::endl; + return; + } + + std::cout << "SUCCESS: GPT2 model created with " << params.size() << " parameters!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 8: TransformerModel Instantiation (LLaMA3) +// ============================================================================ +void TestLlama3Model() { + std::cout << "\n=== Test 8: LLaMA3 Model Instantiation ===" << std::endl; + + auto config = llama3::LLaMA3Config(); + // Use smaller config for faster testing + config.n_layer = 2; + config.n_head = 4; + config.n_kv_head = 2; + config.n_embd = 64; + + try { + auto model = std::make_shared(config); + + if (model == nullptr) { + std::cout << "FAIL: Failed to create LLaMA3 model" << std::endl; + return; + } + + auto params = model->Parameters(); + if (params.empty()) { + std::cout << "FAIL: LLaMA3 model has no parameters" << std::endl; + return; + } + + std::cout << "SUCCESS: LLaMA3 model created with " << params.size() << " parameters!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 9: RoPE Utilities +// ============================================================================ +void TestRopeUtils() { + std::cout << "\n=== Test 9: RoPE Utilities ===" << std::endl; + + const int64_t head_dim = 64; + const int64_t seq_len = 128; + + try { + // Test precompute freqs_cis + auto freqs_cis = PrecomputeFreqsCis(head_dim, seq_len); + + // freqs_cis shape: [seq_len, head_dim/2, 2] (cos and sin stacked on last dim) + const auto &dims = freqs_cis->Dims(); + if (dims.size() != 3) { + std::cout << "FAIL: freqs_cis should be 3D, got " << dims.size() << "D" << std::endl; + return; + } + if (dims[0] != seq_len) { + std::cout << "FAIL: freqs_cis seq_len mismatch. Expected " << seq_len << ", got " << dims[0] << std::endl; + return; + } + if (dims[1] != head_dim / 2) { + std::cout << "FAIL: freqs_cis head_dim/2 mismatch. Expected " << head_dim / 2 << ", got " << dims[1] + << std::endl; + return; + } + if (dims[2] != 2) { + std::cout << "FAIL: freqs_cis last dim should be 2 (cos, sin), got " << dims[2] << std::endl; + return; + } + + std::cout << "SUCCESS: RoPE utilities work correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 10: Model StateDict +// ============================================================================ +void TestStateDict() { + std::cout << "\n=== Test 10: Model StateDict ===" << std::endl; + + nn::TransformerConfig config; + config.n_layer = 1; + config.n_head = 2; + config.n_kv_head = 2; // Must set explicitly + config.n_embd = 32; + config.vocab_size = 1000; + config.attention_type = nn::AttentionType::kStandard; + config.activation_type = nn::MLPType::kGELU; + config.norm_type = nn::NormType::kLayerNorm; + config.add_bias_linear = true; + + try { + auto model = std::make_shared(config); + auto state_dict = model->StateDict(); + + if (state_dict.empty()) { + std::cout << "FAIL: StateDict should not be empty" << std::endl; + return; + } + + // StateDict includes both parameters and buffers, so it should have >= parameters count + auto params = model->Parameters(); + auto buffers = model->Buffers(); + + if (state_dict.size() < params.size()) { + std::cout << "FAIL: StateDict size (" << state_dict.size() << ") should be >= parameter count (" + << params.size() << ")" << std::endl; + return; + } + + // Expected: state_dict.size() == params.size() + buffers.size() + size_t expected_size = params.size() + buffers.size(); + if (state_dict.size() != expected_size) { + std::cout << "FAIL: StateDict size (" << state_dict.size() << ") should equal params (" << params.size() + << ") + buffers (" << buffers.size() << ") = " << expected_size << std::endl; + return; + } + + std::cout << "SUCCESS: StateDict works correctly with " << state_dict.size() << " entries (" << params.size() + << " params + " << buffers.size() << " buffers)!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 11: MoE Layer MVP +// ============================================================================ +void TestMoELayer() { + std::cout << "\n=== Test 11: MoE Layer MVP ===" << std::endl; + + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kGELU; + config.add_bias_linear = true; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 2; + config.moe_config->router_topk = 1; + + try { + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); + + auto output = (*moe)({input}); + if (output.size() != 1) { + std::cout << "FAIL: MoELayer forward should return 1 tensor" << std::endl; + return; + } + if (output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: MoELayer output shape mismatch" << std::endl; + return; + } + + auto params = moe->Parameters(); + if (params.empty()) { + std::cout << "FAIL: MoELayer should own router and expert parameters" << std::endl; + return; + } + + std::cout << "SUCCESS: MoE layer MVP forward works correctly!" << std::endl; + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Main +// ============================================================================ +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + + nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, 1); + + std::cout << "========================================" << std::endl; + std::cout << " Transformer architecture Tests" << std::endl; + std::cout << "========================================" << std::endl; + + TestConfigValidation(); + TestEmbedding(); + TestNormalization(); + TestMlp(); + TestAttention(); + TestTransformerLayer(); + TestGpt2Model(); + TestLlama3Model(); + TestRopeUtils(); + TestStateDict(); + TestMoELayer(); + + std::cout << "\n========================================" << std::endl; + std::cout << " All Tests Completed" << std::endl; + std::cout << "========================================" << std::endl; + + return 0; +} From 599853ecf432c87fd37a64581176a43c562793aa Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 13 May 2026 03:01:31 +0000 Subject: [PATCH 2/9] feat: support topk_router --- .../include/autograd/{moe.h => topk_mask.h} | 9 ++- .../modules/transformer/transformer_config.h | 2 +- .../src/autograd/{moe.cc => topk_mask.cc} | 13 ++-- .../cpu/{top1_mask.cc => topk_mask.cc} | 53 ++++++++++++----- .../cuda/{top1_mask.cu => topk_mask.cu} | 55 ++++++++++------- .../src/nn/modules/transformer/moe/router.cc | 8 ++- .../test_transformer_architecture.cc | 59 ++++++++++++------- 7 files changed, 126 insertions(+), 73 deletions(-) rename infini_train/include/autograd/{moe.h => topk_mask.h} (76%) rename infini_train/src/autograd/{moe.cc => topk_mask.cc} (70%) rename infini_train/src/kernels/cpu/{top1_mask.cc => topk_mask.cc} (50%) rename infini_train/src/kernels/cuda/{top1_mask.cu => topk_mask.cu} (66%) diff --git a/infini_train/include/autograd/moe.h b/infini_train/include/autograd/topk_mask.h similarity index 76% rename from infini_train/include/autograd/moe.h rename to infini_train/include/autograd/topk_mask.h index 5317de8e..355ef400 100644 --- a/infini_train/include/autograd/moe.h +++ b/infini_train/include/autograd/topk_mask.h @@ -11,16 +11,19 @@ class Tensor; namespace infini_train::autograd { -class Top1Mask : public Function { +class TopKMask : public Function { public: - static constexpr char kType[] = "Top1MaskFunction"; + static constexpr char kType[] = "TopKMaskFunction"; - Top1Mask() : Function(kType) {} + explicit TopKMask(int64_t topk) : Function(kType), topk_(topk) {} std::vector> Forward(const std::vector> &input_tensors) override; void SetupContext(const std::vector> &input_tensors, const std::vector> &output_tensors) override; std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + int64_t topk_ = 1; }; } // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index bf1a00fc..2072acb6 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -26,7 +26,7 @@ enum class NormType { }; enum class MoERouterType { - kTopK // Top-k router. The initial implementation supports top-1. + kTopK // Top-k router. }; enum class MoEDispatcherType { diff --git a/infini_train/src/autograd/moe.cc b/infini_train/src/autograd/topk_mask.cc similarity index 70% rename from infini_train/src/autograd/moe.cc rename to infini_train/src/autograd/topk_mask.cc index 05134e82..16dc6629 100644 --- a/infini_train/src/autograd/moe.cc +++ b/infini_train/src/autograd/topk_mask.cc @@ -1,4 +1,4 @@ -#include "infini_train/include/autograd/moe.h" +#include "infini_train/include/autograd/topk_mask.h" #include "glog/logging.h" @@ -7,25 +7,26 @@ namespace infini_train::autograd { -std::vector> Top1Mask::Forward(const std::vector> &input_tensors) { +std::vector> TopKMask::Forward(const std::vector> &input_tensors) { CHECK_EQ(input_tensors.size(), 1); + CHECK_GT(topk_, 0); const auto &input = input_tensors[0]; auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "Top1MaskForward"}, input)}; + return {Dispatcher::Instance().Call>({device, "TopKMaskForward"}, input, topk_)}; } -void Top1Mask::SetupContext(const std::vector> &, +void TopKMask::SetupContext(const std::vector> &, const std::vector> &output_tensors) { saved_tensors_ = {output_tensors[0]}; } -std::vector> Top1Mask::Backward(const std::vector> &grad_outputs) { +std::vector> TopKMask::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; const auto &mask_values = saved_tensors_[0]; auto device = grad_output->GetDevice().type(); return { - Dispatcher::Instance().Call>({device, "Top1MaskBackward"}, grad_output, mask_values)}; + Dispatcher::Instance().Call>({device, "TopKMaskBackward"}, grad_output, mask_values)}; } } // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/top1_mask.cc b/infini_train/src/kernels/cpu/topk_mask.cc similarity index 50% rename from infini_train/src/kernels/cpu/top1_mask.cc rename to infini_train/src/kernels/cpu/topk_mask.cc index d6ae91d6..6a7191b9 100644 --- a/infini_train/src/kernels/cpu/top1_mask.cc +++ b/infini_train/src/kernels/cpu/topk_mask.cc @@ -1,4 +1,6 @@ +#include #include +#include #include "glog/logging.h" @@ -7,13 +9,15 @@ namespace infini_train::kernels::cpu { -std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { - CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskForward currently supports float32 only"; +std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskForward currently supports float32 only"; CHECK_GE(input->Dims().size(), 1); const auto &dims = input->Dims(); const int64_t num_experts = dims.back(); CHECK_GT(num_experts, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, num_experts); const int64_t rows = input->NumElements() / num_experts; auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); @@ -22,24 +26,41 @@ std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { const float *in = static_cast(input->DataPtr()); float *out = static_cast(output->DataPtr()); for (int64_t row = 0; row < rows; ++row) { - int64_t best_idx = 0; - float best_value = in[row * num_experts]; - for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { - const float value = in[row * num_experts + expert_idx]; - if (value > best_value) { - best_value = value; - best_idx = expert_idx; + const int64_t row_offset = row * num_experts; + std::vector selected_experts(num_experts, false); + float selected_sum = 0.0f; + for (int64_t selected = 0; selected < topk; ++selected) { + int64_t best_idx = -1; + float best_value = -std::numeric_limits::infinity(); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (selected_experts[expert_idx]) { + continue; + } + const float value = in[row_offset + expert_idx]; + if (value > best_value) { + best_value = value; + best_idx = expert_idx; + } + } + CHECK_GE(best_idx, 0); + selected_experts[best_idx] = true; + out[row_offset + best_idx] = best_value; + selected_sum += best_value; + } + if (topk > 1 && selected_sum != 0.0f) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + out[row_offset + expert_idx] + = out[row_offset + expert_idx] == 0.0f ? 0.0f : out[row_offset + expert_idx] / selected_sum; } } - out[row * num_experts + best_idx] = best_value; } return output; } -std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, +std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, const std::shared_ptr &mask_values) { - CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskBackward currently supports float32 only"; + CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskBackward currently supports float32 only"; CHECK(mask_values->Dtype() == DataType::kFLOAT32); CHECK(grad_output->Dims() == mask_values->Dims()); @@ -58,10 +79,10 @@ std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_out } // namespace infini_train::kernels::cpu -#define REGISTER_CPU_TOP1_MASK_KERNEL(kernel_name) \ +#define REGISTER_CPU_TOPK_MASK_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) -REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskForward) -REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskBackward) +REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskForward) +REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskBackward) -#undef REGISTER_CPU_TOP1_MASK_KERNEL +#undef REGISTER_CPU_TOPK_MASK_KERNEL diff --git a/infini_train/src/kernels/cuda/top1_mask.cu b/infini_train/src/kernels/cuda/topk_mask.cu similarity index 66% rename from infini_train/src/kernels/cuda/top1_mask.cu rename to infini_train/src/kernels/cuda/topk_mask.cu index 8fd00c91..e38c793e 100644 --- a/infini_train/src/kernels/cuda/top1_mask.cu +++ b/infini_train/src/kernels/cuda/topk_mask.cu @@ -11,33 +11,44 @@ namespace infini_train::kernels::cuda { template -__global__ void Top1MaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, - int64_t num_experts) { +__global__ void TopKMaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, + int64_t num_experts, int64_t topk) { int64_t row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= rows) { return; } const int64_t offset = row * num_experts; - int64_t best_idx = 0; - float best_value = static_cast(input[offset]); - for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { + float selected_sum = 0.0f; + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const float value = static_cast(input[offset + expert_idx]); - if (value > best_value) { - best_value = value; - best_idx = expert_idx; + int64_t rank = 0; + for (int64_t other_idx = 0; other_idx < num_experts; ++other_idx) { + const float other_value = static_cast(input[offset + other_idx]); + if (other_value > value || (other_value == value && other_idx < expert_idx)) { + ++rank; + } } + const bool selected = rank < topk; + output[offset + expert_idx] = selected ? input[offset + expert_idx] : T(0.0f); + selected_sum += selected ? value : 0.0f; } - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - output[offset + expert_idx] = expert_idx == best_idx ? input[offset + expert_idx] : T(0.0f); + if (topk > 1 && selected_sum != 0.0f) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (static_cast(output[offset + expert_idx]) != 0.0f) { + output[offset + expert_idx] = T(static_cast(output[offset + expert_idx]) / selected_sum); + } + } } } -std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { +std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { CHECK_GE(input->Dims().size(), 1); const auto &dims = input->Dims(); const int64_t num_experts = dims.back(); CHECK_GT(num_experts, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, num_experts); const int64_t rows = input->NumElements() / num_experts; auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); @@ -52,16 +63,16 @@ std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { - Top1MaskForwardKernel<<>>( - static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts); + TopKMaskForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts, topk); }, - "CUDA Top1MaskForward"); + "CUDA TopKMaskForward"); return output; } template -__global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, +__global__ void TopKMaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, T *__restrict__ grad_input, int64_t total_elements) { int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_elements) { @@ -70,7 +81,7 @@ __global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const grad_input[idx] = static_cast(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f); } -std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, +std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, const std::shared_ptr &mask_values) { CHECK(grad_output->Dims() == mask_values->Dims()); CHECK(grad_output->Dtype() == mask_values->Dtype()); @@ -87,21 +98,21 @@ std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_out core::cuda::DispatchCudaFunc( grad_output->Dtype(), [=]() { - Top1MaskBackwardKernel<<>>( + TopKMaskBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(mask_values->DataPtr()), static_cast(grad_input->DataPtr()), total_elements); }, - "CUDA Top1MaskBackward"); + "CUDA TopKMaskBackward"); return grad_input; } } // namespace infini_train::kernels::cuda -#define REGISTER_CUDA_TOP1_MASK_KERNEL(kernel_name) \ +#define REGISTER_CUDA_TOPK_MASK_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) -REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskForward) -REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskBackward) +REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskForward) +REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskBackward) -#undef REGISTER_CUDA_TOP1_MASK_KERNEL +#undef REGISTER_CUDA_TOPK_MASK_KERNEL diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc index 59dec209..851c57be 100644 --- a/infini_train/src/nn/modules/transformer/moe/router.cc +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -6,7 +6,7 @@ #include "glog/logging.h" #include "infini_train/include/autograd/linear.h" -#include "infini_train/include/autograd/moe.h" +#include "infini_train/include/autograd/topk_mask.h" #include "infini_train/include/nn/functional.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" @@ -17,8 +17,9 @@ namespace infini_train::nn::moe { TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); CHECK(moe_config.router_type == MoERouterType::kTopK); - CHECK_EQ(moe_config.router_topk, 1) << "Current InfiniTrain MoE implementation supports top-1 routing only"; CHECK_GT(moe_config.num_experts, 0); + CHECK_GT(moe_config.router_topk, 0); + CHECK_LE(moe_config.router_topk, moe_config.num_experts); parameters_[kParamWeightName] = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, @@ -43,7 +44,8 @@ std::vector> TopKRouter::Forward(const std::vector()->Apply(linear_inputs)[0]; auto scores = function::Softmax(logits, -1); - auto routing_probs = std::make_shared()->Apply({scores})[0]; + const auto &moe_config = RequireMoEConfig(config_); + auto routing_probs = std::make_shared(moe_config.router_topk)->Apply({scores})[0]; return {routing_probs}; } diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc index da3dd70e..469ff386 100644 --- a/test/transformer/test_transformer_architecture.cc +++ b/test/transformer/test_transformer_architecture.cc @@ -527,10 +527,10 @@ void TestStateDict() { } // ============================================================================ -// Test 11: MoE Layer MVP +// Test 11: MoE Layer // ============================================================================ void TestMoELayer() { - std::cout << "\n=== Test 11: MoE Layer MVP ===" << std::endl; + std::cout << "\n=== Test 11: MoE Layer ===" << std::endl; nn::TransformerConfig config; config.n_embd = 32; @@ -543,29 +543,43 @@ void TestMoELayer() { config.moe_config->num_experts = 2; config.moe_config->router_topk = 1; - try { - auto moe = std::make_shared(config); - auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); - input->Uniform(); + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); - auto output = (*moe)({input}); - if (output.size() != 1) { - std::cout << "FAIL: MoELayer forward should return 1 tensor" << std::endl; - return; - } - if (output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: MoELayer output shape mismatch" << std::endl; - return; - } + auto output = (*moe)({input}); + CHECK_EQ(output.size(), 1); + CHECK(output[0]->Dims() == input->Dims()); - auto params = moe->Parameters(); - if (params.empty()) { - std::cout << "FAIL: MoELayer should own router and expert parameters" << std::endl; - return; - } + auto params = moe->Parameters(); + CHECK(!params.empty()); - std::cout << "SUCCESS: MoE layer MVP forward works correctly!" << std::endl; - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } + std::cout << "SUCCESS: MoE layer forward works correctly!" << std::endl; +} + +void TestMoELayerTop2() { + std::cout << "\n=== Test 12: MoE Layer Top-2 ===" << std::endl; + + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kGELU; + config.add_bias_linear = true; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); + + auto output = (*moe)({input}); + CHECK_EQ(output.size(), 1); + CHECK(output[0]->Dims() == input->Dims()); + + std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl; } // ============================================================================ @@ -591,6 +605,7 @@ int main(int argc, char *argv[]) { TestRopeUtils(); TestStateDict(); TestMoELayer(); + TestMoELayerTop2(); std::cout << "\n========================================" << std::endl; std::cout << " All Tests Completed" << std::endl; From 64f61d55c35d2c1f02d699479a3dbd55c1d8710b Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 13 May 2026 07:36:52 +0000 Subject: [PATCH 3/9] feat: support moe_ffn_hidden_size config --- infini_train/src/nn/modules/transformer/mlp.cc | 6 ++++++ test/transformer/test_transformer_architecture.cc | 13 +++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/infini_train/src/nn/modules/transformer/mlp.cc b/infini_train/src/nn/modules/transformer/mlp.cc index 9f1f488c..ac35d144 100644 --- a/infini_train/src/nn/modules/transformer/mlp.cc +++ b/infini_train/src/nn/modules/transformer/mlp.cc @@ -37,6 +37,12 @@ MLP::MLP(const TransformerConfig &config) : CloneableModule(kType) { // Round up to multiple_of ffn_hidden = (ffn_hidden + config.multiple_of - 1) / config.multiple_of * config.multiple_of; + if (config.ffn_type == FFNType::kMoE && config.moe_config.has_value() + && config.moe_config->moe_ffn_hidden_size > 0) { + ffn_hidden = config.moe_config->moe_ffn_hidden_size; + } + CHECK_GT(ffn_hidden, 0); + // c_fc: ColumnParallel (input full, output parallel) modules_[kCFcLayerName] = std::make_shared( /*in_features=*/config.n_embd, /*out_features=*/ffn_hidden, diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc index 469ff386..42efda2d 100644 --- a/test/transformer/test_transformer_architecture.cc +++ b/test/transformer/test_transformer_architecture.cc @@ -564,12 +564,13 @@ void TestMoELayerTop2() { config.n_embd = 32; config.n_head = 2; config.n_kv_head = 2; - config.activation_type = nn::MLPType::kGELU; - config.add_bias_linear = true; + config.activation_type = nn::MLPType::kSwiGLU; + config.add_bias_linear = false; config.ffn_type = nn::FFNType::kMoE; config.moe_config = nn::MoEConfig{}; config.moe_config->num_experts = 4; config.moe_config->router_topk = 2; + config.moe_config->moe_ffn_hidden_size = 48; auto moe = std::make_shared(config); auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); @@ -579,6 +580,14 @@ void TestMoELayerTop2() { CHECK_EQ(output.size(), 1); CHECK(output[0]->Dims() == input->Dims()); + auto state = moe->StateDict(); + CHECK(state.contains("experts.expert_0.c_fc.weight")); + CHECK(state.contains("experts.expert_0.c_fc2.weight")); + CHECK(state.contains("experts.expert_0.c_proj.weight")); + CHECK(state.at("experts.expert_0.c_fc.weight")->Dims() == std::vector({48, config.n_embd})); + CHECK(state.at("experts.expert_0.c_fc2.weight")->Dims() == std::vector({48, config.n_embd})); + CHECK(state.at("experts.expert_0.c_proj.weight")->Dims() == std::vector({config.n_embd, 48})); + std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl; } From 50cb5fd7d8c509ff9af18ca87de2b3f5d90e5a7b Mon Sep 17 00:00:00 2001 From: kilinchange Date: Tue, 26 May 2026 12:24:26 +0000 Subject: [PATCH 4/9] test: migrate test_transformer_architecture to ctest framework --- .../test_transformer_architecture.cc | 624 ------------------ .../test_transformer_architecture.cc | 159 +++++ 2 files changed, 159 insertions(+), 624 deletions(-) delete mode 100644 test/transformer/test_transformer_architecture.cc diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc deleted file mode 100644 index 42efda2d..00000000 --- a/test/transformer/test_transformer_architecture.cc +++ /dev/null @@ -1,624 +0,0 @@ -#include -#include -#include - -#include "glog/logging.h" - -#include "example/gpt2/config.h" -#include "example/llama3/config.h" -#include "infini_train/include/nn/modules/activations.h" -#include "infini_train/include/nn/modules/normalization.h" -#include "infini_train/include/nn/modules/sparse.h" -#include "infini_train/include/nn/modules/transformer/causal_self_attention.h" -#include "infini_train/include/nn/modules/transformer/mlp.h" -#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" -#include "infini_train/include/nn/modules/transformer/transformer.h" -#include "infini_train/include/nn/modules/transformer/transformer_config.h" -#include "infini_train/include/nn/modules/transformer/utils.h" -#include "infini_train/include/nn/parallel/global.h" -#include "infini_train/include/tensor.h" - -using namespace infini_train; -namespace nn = infini_train::nn; - -// ============================================================================ -// Test 1: TransformerConfig Validation -// ============================================================================ -void TestConfigValidation() { - std::cout << "\n=== Test 1: TransformerConfig Validation ===" << std::endl; - - bool all_passed = true; - - // Test GPT2 config - auto gpt2_config = gpt2::GPT2Config(); - if (gpt2_config.attention_type != nn::AttentionType::kStandard) { - std::cout << "FAIL: GPT2 config should use Standard attention" << std::endl; - all_passed = false; - } - if (gpt2_config.activation_type != nn::MLPType::kGELU) { - std::cout << "FAIL: GPT2 config should use GELU activation" << std::endl; - all_passed = false; - } - if (gpt2_config.norm_type != nn::NormType::kLayerNorm) { - std::cout << "FAIL: GPT2 config should use LayerNorm" << std::endl; - all_passed = false; - } - if (!gpt2_config.add_bias_linear) { - std::cout << "FAIL: GPT2 config should have bias enabled" << std::endl; - all_passed = false; - } - if (!gpt2_config.tie_weights) { - std::cout << "FAIL: GPT2 config should have tied weights" << std::endl; - all_passed = false; - } - - // Test LLaMA3 config - auto llama3_config = llama3::LLaMA3Config(); - if (llama3_config.attention_type != nn::AttentionType::kRoPE) { - std::cout << "FAIL: LLaMA3 config should use RoPE attention" << std::endl; - all_passed = false; - } - if (llama3_config.activation_type != nn::MLPType::kSwiGLU) { - std::cout << "FAIL: LLaMA3 config should use SwiGLU activation" << std::endl; - all_passed = false; - } - if (llama3_config.norm_type != nn::NormType::kRMSNorm) { - std::cout << "FAIL: LLaMA3 config should use RMSNorm" << std::endl; - all_passed = false; - } - if (llama3_config.add_bias_linear) { - std::cout << "FAIL: LLaMA3 config should have bias disabled" << std::endl; - all_passed = false; - } - if (llama3_config.tie_weights) { - std::cout << "FAIL: LLaMA3 config should not have tied weights" << std::endl; - all_passed = false; - } - - // Test GQA detection - if (!llama3_config.UseGQA()) { - std::cout << "FAIL: LLaMA3 config should detect GQA (n_kv_head < n_head)" << std::endl; - all_passed = false; - } - if (gpt2_config.UseGQA()) { - std::cout << "FAIL: GPT2 config should not detect GQA (n_kv_head == n_head)" << std::endl; - all_passed = false; - } - - if (all_passed) { - std::cout << "SUCCESS: All config validations passed!" << std::endl; - } -} - -// ============================================================================ -// Test 2: Embedding Layer -// ============================================================================ -void TestEmbedding() { - std::cout << "\n=== Test 2: Embedding Layer ===" << std::endl; - - const int64_t vocab_size = 1000; - const int64_t embedding_dim = 128; - const int64_t batch_size = 2; - const int64_t seq_len = 16; - - try { - auto embedding = std::make_shared(vocab_size, embedding_dim); - - // Check parameters - auto params = embedding->Parameters(); - if (params.size() != 1) { - std::cout << "FAIL: Embedding should have 1 parameter, got " << params.size() << std::endl; - return; - } - - // Check weight shape - auto weight = embedding->parameter(nn::Embedding::kParamWeightName); - if (weight->Dims() != std::vector{vocab_size, embedding_dim}) { - std::cout << "FAIL: Embedding weight shape mismatch" << std::endl; - return; - } - - // Forward pass - auto input = std::make_shared(std::vector{batch_size, seq_len}, DataType::kINT64); - auto output = (*embedding)({input}); - - if (output.size() != 1) { - std::cout << "FAIL: Embedding forward should return 1 tensor" << std::endl; - return; - } - - const auto &out_dims = output[0]->Dims(); - if (out_dims != std::vector{batch_size, seq_len, embedding_dim}) { - std::cout << "FAIL: Embedding output shape mismatch. Expected [" << batch_size << ", " << seq_len << ", " - << embedding_dim << "], got [" << out_dims[0] << ", " << out_dims[1] << ", " << out_dims[2] << "]" - << std::endl; - return; - } - - std::cout << "SUCCESS: Embedding layer works correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 3: Normalization Layers (LayerNorm vs RMSNorm) -// ============================================================================ -void TestNormalization() { - std::cout << "\n=== Test 3: Normalization Layers ===" << std::endl; - - const int64_t hidden_size = 64; - const int64_t batch_size = 2; - const int64_t seq_len = 8; - - try { - // Test LayerNorm - auto layernorm = std::make_shared(std::vector{hidden_size}); - auto ln_params = layernorm->Parameters(); - if (ln_params.size() != 2) { - std::cout << "FAIL: LayerNorm should have 2 parameters (weight, bias), got " << ln_params.size() - << std::endl; - return; - } - - // Test RMSNorm - auto rmsnorm = std::make_shared(hidden_size); - auto rms_params = rmsnorm->Parameters(); - if (rms_params.size() != 1) { - std::cout << "FAIL: RMSNorm should have 1 parameter (weight), got " << rms_params.size() << std::endl; - return; - } - - // Forward pass for both - auto input - = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); - - auto ln_output = (*layernorm)({input}); - auto rms_output = (*rmsnorm)({input}); - - if (ln_output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: LayerNorm output shape mismatch" << std::endl; - return; - } - - if (rms_output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: RMSNorm output shape mismatch" << std::endl; - return; - } - - std::cout << "SUCCESS: Normalization layers work correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 4: MLP Layer (GELU vs SwiGLU) -// ============================================================================ -void TestMlp() { - std::cout << "\n=== Test 4: MLP Layer ===" << std::endl; - - const int64_t hidden_size = 64; - const int64_t batch_size = 2; - const int64_t seq_len = 8; - - try { - // Test GPT2-style MLP (GELU) - nn::TransformerConfig gpt2_mlp_config; - gpt2_mlp_config.n_embd = hidden_size; - gpt2_mlp_config.activation_type = nn::MLPType::kGELU; - gpt2_mlp_config.ffn_expansion_ratio = 4.0f; - gpt2_mlp_config.add_bias_linear = true; - - auto gpt2_mlp = std::make_shared(gpt2_mlp_config); - auto gpt2_params = gpt2_mlp->Parameters(); - - // GPT2 MLP should have: c_fc.weight, c_fc.bias, c_proj.weight, c_proj.bias - if (gpt2_params.size() != 4) { - std::cout << "FAIL: GPT2 MLP should have 4 parameters, got " << gpt2_params.size() << std::endl; - return; - } - - // Test LLaMA3-style MLP (SwiGLU) - nn::TransformerConfig llama3_mlp_config; - llama3_mlp_config.n_embd = hidden_size; - llama3_mlp_config.activation_type = nn::MLPType::kSwiGLU; - llama3_mlp_config.ffn_expansion_ratio = 4.0f; - llama3_mlp_config.add_bias_linear = false; - llama3_mlp_config.ffn_dim_multiplier = 1.5f; - llama3_mlp_config.multiple_of = 256; - - auto llama3_mlp = std::make_shared(llama3_mlp_config); - auto llama3_params = llama3_mlp->Parameters(); - - // LLaMA3 MLP should have: c_fc.weight, c_fc2.weight, c_proj.weight (no bias) - if (llama3_params.size() != 3) { - std::cout << "FAIL: LLaMA3 MLP should have 3 parameters, got " << llama3_params.size() << std::endl; - return; - } - - // Forward pass - auto input - = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); - - auto gpt2_output = (*gpt2_mlp)({input}); - auto llama3_output = (*llama3_mlp)({input}); - - // Output should have same hidden dimension - if (gpt2_output[0]->Dims()[2] != hidden_size) { - std::cout << "FAIL: GPT2 MLP output hidden dim mismatch" << std::endl; - return; - } - - if (llama3_output[0]->Dims()[2] != hidden_size) { - std::cout << "FAIL: LLaMA3 MLP output hidden dim mismatch" << std::endl; - return; - } - - std::cout << "SUCCESS: MLP layers work correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 5: CausalSelfAttention -// ============================================================================ -void TestAttention() { - std::cout << "\n=== Test 5: CausalSelfAttention ===" << std::endl; - - const int64_t hidden_size = 64; - const int64_t batch_size = 2; - const int64_t seq_len = 8; - const int64_t n_head = 4; - - try { - // Test standard attention (GPT2-style) - nn::TransformerConfig standard_config; - standard_config.n_embd = hidden_size; - standard_config.n_head = n_head; - standard_config.n_kv_head = n_head; - standard_config.attention_type = nn::AttentionType::kStandard; - standard_config.add_bias_linear = true; - - auto standard_attn = std::make_shared(standard_config); - auto standard_params = standard_attn->Parameters(); - - // Should have c_attn (QKV combined) and c_proj with biases - if (standard_params.size() != 4) { - std::cout << "FAIL: Standard attention should have 4 parameters, got " << standard_params.size() - << std::endl; - return; - } - - // Test RoPE attention with GQA (LLaMA3-style) - nn::TransformerConfig rope_config; - rope_config.n_embd = hidden_size; - rope_config.n_head = n_head; - rope_config.n_kv_head = 2; // GQA: fewer KV heads - rope_config.attention_type = nn::AttentionType::kRoPE; - rope_config.add_bias_linear = false; - - auto rope_attn = std::make_shared(rope_config); - auto rope_params = rope_attn->Parameters(); - - // RoPE attention without bias should have fewer params - if (rope_params.empty()) { - std::cout << "FAIL: RoPE attention should have parameters" << std::endl; - return; - } - - // Forward pass - auto input - = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); - - auto standard_output = (*standard_attn)({input}); - if (standard_output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: Standard attention output shape mismatch" << std::endl; - return; - } - - std::cout << "SUCCESS: CausalSelfAttention works correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 6: TransformerLayer -// ============================================================================ -void TestTransformerLayer() { - std::cout << "\n=== Test 6: TransformerLayer ===" << std::endl; - - const int64_t hidden_size = 64; - const int64_t batch_size = 2; - const int64_t seq_len = 8; - - try { - // Test GPT2-style layer - auto gpt2_config = gpt2::GPT2Config(); - gpt2_config.n_embd = hidden_size; - gpt2_config.n_head = 4; - gpt2_config.n_layer = 1; - - auto gpt2_layer = std::make_shared(gpt2_config); - auto gpt2_params = gpt2_layer->Parameters(); - - if (gpt2_params.empty()) { - std::cout << "FAIL: GPT2 TransformerLayer should have parameters" << std::endl; - return; - } - - // Forward pass - auto input - = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); - - auto output = (*gpt2_layer)({input}); - if (output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: TransformerLayer output shape mismatch" << std::endl; - return; - } - - // Test LLaMA3-style layer - auto llama3_config = llama3::LLaMA3Config(); - llama3_config.n_embd = hidden_size; - llama3_config.n_head = 4; - llama3_config.n_kv_head = 2; - llama3_config.n_layer = 1; - - auto llama3_layer = std::make_shared(llama3_config); - auto llama3_params = llama3_layer->Parameters(); - - if (llama3_params.empty()) { - std::cout << "FAIL: LLaMA3 TransformerLayer should have parameters" << std::endl; - return; - } - - std::cout << "SUCCESS: TransformerLayer works correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 7: TransformerModel Instantiation (GPT2) -// ============================================================================ -void TestGpt2Model() { - std::cout << "\n=== Test 7: GPT2 Model Instantiation ===" << std::endl; - - auto config = gpt2::GPT2Config(); - // Use smaller config for faster testing - config.n_layer = 2; - config.n_head = 4; - config.n_embd = 64; - - try { - auto model = std::make_shared(config); - - if (model == nullptr) { - std::cout << "FAIL: Failed to create GPT2 model" << std::endl; - return; - } - - auto params = model->Parameters(); - if (params.empty()) { - std::cout << "FAIL: GPT2 model has no parameters" << std::endl; - return; - } - - std::cout << "SUCCESS: GPT2 model created with " << params.size() << " parameters!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 8: TransformerModel Instantiation (LLaMA3) -// ============================================================================ -void TestLlama3Model() { - std::cout << "\n=== Test 8: LLaMA3 Model Instantiation ===" << std::endl; - - auto config = llama3::LLaMA3Config(); - // Use smaller config for faster testing - config.n_layer = 2; - config.n_head = 4; - config.n_kv_head = 2; - config.n_embd = 64; - - try { - auto model = std::make_shared(config); - - if (model == nullptr) { - std::cout << "FAIL: Failed to create LLaMA3 model" << std::endl; - return; - } - - auto params = model->Parameters(); - if (params.empty()) { - std::cout << "FAIL: LLaMA3 model has no parameters" << std::endl; - return; - } - - std::cout << "SUCCESS: LLaMA3 model created with " << params.size() << " parameters!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 9: RoPE Utilities -// ============================================================================ -void TestRopeUtils() { - std::cout << "\n=== Test 9: RoPE Utilities ===" << std::endl; - - const int64_t head_dim = 64; - const int64_t seq_len = 128; - - try { - // Test precompute freqs_cis - auto freqs_cis = PrecomputeFreqsCis(head_dim, seq_len); - - // freqs_cis shape: [seq_len, head_dim/2, 2] (cos and sin stacked on last dim) - const auto &dims = freqs_cis->Dims(); - if (dims.size() != 3) { - std::cout << "FAIL: freqs_cis should be 3D, got " << dims.size() << "D" << std::endl; - return; - } - if (dims[0] != seq_len) { - std::cout << "FAIL: freqs_cis seq_len mismatch. Expected " << seq_len << ", got " << dims[0] << std::endl; - return; - } - if (dims[1] != head_dim / 2) { - std::cout << "FAIL: freqs_cis head_dim/2 mismatch. Expected " << head_dim / 2 << ", got " << dims[1] - << std::endl; - return; - } - if (dims[2] != 2) { - std::cout << "FAIL: freqs_cis last dim should be 2 (cos, sin), got " << dims[2] << std::endl; - return; - } - - std::cout << "SUCCESS: RoPE utilities work correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 10: Model StateDict -// ============================================================================ -void TestStateDict() { - std::cout << "\n=== Test 10: Model StateDict ===" << std::endl; - - nn::TransformerConfig config; - config.n_layer = 1; - config.n_head = 2; - config.n_kv_head = 2; // Must set explicitly - config.n_embd = 32; - config.vocab_size = 1000; - config.attention_type = nn::AttentionType::kStandard; - config.activation_type = nn::MLPType::kGELU; - config.norm_type = nn::NormType::kLayerNorm; - config.add_bias_linear = true; - - try { - auto model = std::make_shared(config); - auto state_dict = model->StateDict(); - - if (state_dict.empty()) { - std::cout << "FAIL: StateDict should not be empty" << std::endl; - return; - } - - // StateDict includes both parameters and buffers, so it should have >= parameters count - auto params = model->Parameters(); - auto buffers = model->Buffers(); - - if (state_dict.size() < params.size()) { - std::cout << "FAIL: StateDict size (" << state_dict.size() << ") should be >= parameter count (" - << params.size() << ")" << std::endl; - return; - } - - // Expected: state_dict.size() == params.size() + buffers.size() - size_t expected_size = params.size() + buffers.size(); - if (state_dict.size() != expected_size) { - std::cout << "FAIL: StateDict size (" << state_dict.size() << ") should equal params (" << params.size() - << ") + buffers (" << buffers.size() << ") = " << expected_size << std::endl; - return; - } - - std::cout << "SUCCESS: StateDict works correctly with " << state_dict.size() << " entries (" << params.size() - << " params + " << buffers.size() << " buffers)!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 11: MoE Layer -// ============================================================================ -void TestMoELayer() { - std::cout << "\n=== Test 11: MoE Layer ===" << std::endl; - - nn::TransformerConfig config; - config.n_embd = 32; - config.n_head = 2; - config.n_kv_head = 2; - config.activation_type = nn::MLPType::kGELU; - config.add_bias_linear = true; - config.ffn_type = nn::FFNType::kMoE; - config.moe_config = nn::MoEConfig{}; - config.moe_config->num_experts = 2; - config.moe_config->router_topk = 1; - - auto moe = std::make_shared(config); - auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); - input->Uniform(); - - auto output = (*moe)({input}); - CHECK_EQ(output.size(), 1); - CHECK(output[0]->Dims() == input->Dims()); - - auto params = moe->Parameters(); - CHECK(!params.empty()); - - std::cout << "SUCCESS: MoE layer forward works correctly!" << std::endl; -} - -void TestMoELayerTop2() { - std::cout << "\n=== Test 12: MoE Layer Top-2 ===" << std::endl; - - nn::TransformerConfig config; - config.n_embd = 32; - config.n_head = 2; - config.n_kv_head = 2; - config.activation_type = nn::MLPType::kSwiGLU; - config.add_bias_linear = false; - config.ffn_type = nn::FFNType::kMoE; - config.moe_config = nn::MoEConfig{}; - config.moe_config->num_experts = 4; - config.moe_config->router_topk = 2; - config.moe_config->moe_ffn_hidden_size = 48; - - auto moe = std::make_shared(config); - auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); - input->Uniform(); - - auto output = (*moe)({input}); - CHECK_EQ(output.size(), 1); - CHECK(output[0]->Dims() == input->Dims()); - - auto state = moe->StateDict(); - CHECK(state.contains("experts.expert_0.c_fc.weight")); - CHECK(state.contains("experts.expert_0.c_fc2.weight")); - CHECK(state.contains("experts.expert_0.c_proj.weight")); - CHECK(state.at("experts.expert_0.c_fc.weight")->Dims() == std::vector({48, config.n_embd})); - CHECK(state.at("experts.expert_0.c_fc2.weight")->Dims() == std::vector({48, config.n_embd})); - CHECK(state.at("experts.expert_0.c_proj.weight")->Dims() == std::vector({config.n_embd, 48})); - - std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl; -} - -// ============================================================================ -// Main -// ============================================================================ -int main(int argc, char *argv[]) { - google::InitGoogleLogging(argv[0]); - - nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, 1); - - std::cout << "========================================" << std::endl; - std::cout << " Transformer architecture Tests" << std::endl; - std::cout << "========================================" << std::endl; - - TestConfigValidation(); - TestEmbedding(); - TestNormalization(); - TestMlp(); - TestAttention(); - TestTransformerLayer(); - TestGpt2Model(); - TestLlama3Model(); - TestRopeUtils(); - TestStateDict(); - TestMoELayer(); - TestMoELayerTop2(); - - std::cout << "\n========================================" << std::endl; - std::cout << " All Tests Completed" << std::endl; - std::cout << "========================================" << std::endl; - - return 0; -} diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index ba62e1e3..ad7a9da3 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -4,10 +4,13 @@ #include "gtest/gtest.h" +#include "infini_train/include/autograd/topk.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" +#include "infini_train/include/nn/modules/transformer/moe/router.h" #include "infini_train/include/nn/modules/transformer/transformer.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" #include "infini_train/include/nn/modules/transformer/utils.h" @@ -189,4 +192,160 @@ TEST_P(TransformerModuleTest, StateDict) { EXPECT_GE(state_dict.size(), params.size()); } + +TEST_P(TransformerModuleTest, MoELayerTop1) { + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kGELU; + config.add_bias_linear = true; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 2; + config.moe_config->router_topk = 1; + config.moe_config->router_pre_softmax = true; + + auto moe = std::make_shared(config); + moe->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*moe)({input}); + ASSERT_EQ(output.size(), 1); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + EXPECT_FALSE(moe->Parameters().empty()); +} + +TEST_P(TransformerModuleTest, MoELayerTop2SwiGLU) { + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kSwiGLU; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + config.moe_config->moe_ffn_hidden_size = 48; + + auto moe = std::make_shared(config); + moe->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*moe)({input}); + ASSERT_EQ(output.size(), 1); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + + auto state = moe->StateDict(); + ASSERT_TRUE(state.contains("experts.expert_0.c_fc.weight")); + ASSERT_TRUE(state.contains("experts.expert_0.c_fc2.weight")); + ASSERT_TRUE(state.contains("experts.expert_0.c_proj.weight")); + EXPECT_EQ(state.at("experts.expert_0.c_fc.weight")->Dims(), (std::vector{48, config.n_embd})); + EXPECT_EQ(state.at("experts.expert_0.c_fc2.weight")->Dims(), (std::vector{48, config.n_embd})); + EXPECT_EQ(state.at("experts.expert_0.c_proj.weight")->Dims(), (std::vector{config.n_embd, 48})); +} + +TEST_P(TransformerModuleTest, TopKRouterMegatronOutputs) { + nn::TransformerConfig config; + config.n_embd = 32; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + + auto router = std::make_shared(config); + router->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*router)({input}); + ASSERT_EQ(output.size(), 2); + EXPECT_EQ(output[0]->Dims(), (std::vector{2, 4, 4})); + EXPECT_EQ(output[1]->Dims(), (std::vector{2, 4, 4})); + EXPECT_EQ(output[0]->Dtype(), DataType::kFLOAT32); + EXPECT_EQ(output[1]->Dtype(), DataType::kBOOL); +} + +TEST_P(TransformerModuleTest, TopKTorchInterface) { + ONLY_CPU(); + const float data[] = {1.0f, 5.0f, 2.0f, 4.0f, 3.0f, 0.0f}; + auto input = std::make_shared(data, std::vector{2, 3}, DataType::kFLOAT32); + + auto largest_topk = std::make_shared(2, 1, true, true); + auto largest_values = largest_topk->Apply({input})[0]; + auto largest_indices = largest_topk->TopIndices(); + ASSERT_EQ(largest_values->Dims(), (std::vector{2, 2})); + ASSERT_EQ(largest_indices->Dims(), (std::vector{2, 2})); + const auto *largest_values_ptr = static_cast(largest_values->DataPtr()); + const auto *largest_indices_ptr = static_cast(largest_indices->DataPtr()); + EXPECT_FLOAT_EQ(largest_values_ptr[0], 5.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[1], 2.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[2], 4.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[3], 3.0f); + EXPECT_EQ(largest_indices_ptr[0], 1); + EXPECT_EQ(largest_indices_ptr[1], 2); + EXPECT_EQ(largest_indices_ptr[2], 0); + EXPECT_EQ(largest_indices_ptr[3], 1); + + auto smallest_topk = std::make_shared(1, 0, false, true); + auto smallest_values = smallest_topk->Apply({input})[0]; + auto smallest_indices = smallest_topk->TopIndices(); + ASSERT_EQ(smallest_values->Dims(), (std::vector{1, 3})); + ASSERT_EQ(smallest_indices->Dims(), (std::vector{1, 3})); + const auto *smallest_values_ptr = static_cast(smallest_values->DataPtr()); + const auto *smallest_indices_ptr = static_cast(smallest_indices->DataPtr()); + EXPECT_FLOAT_EQ(smallest_values_ptr[0], 1.0f); + EXPECT_FLOAT_EQ(smallest_values_ptr[1], 3.0f); + EXPECT_FLOAT_EQ(smallest_values_ptr[2], 0.0f); + EXPECT_EQ(smallest_indices_ptr[0], 0); + EXPECT_EQ(smallest_indices_ptr[1], 1); + EXPECT_EQ(smallest_indices_ptr[2], 1); +} + +TEST_P(TransformerModuleTest, TopKRouterNormalization) { + ONLY_CPU(); + auto make_router = [](nn::MoEConfig::RouterScoreFunction score_function, bool pre_softmax) { + nn::TransformerConfig config; + config.n_embd = 2; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 3; + config.moe_config->router_topk = 2; + config.moe_config->router_score_function = score_function; + config.moe_config->router_pre_softmax = pre_softmax; + auto router = std::make_shared(config); + auto weight = router->parameter(nn::moe::TopKRouter::kParamWeightName); + auto *weight_ptr = static_cast(weight->DataPtr()); + weight_ptr[0] = 1.0f; + weight_ptr[1] = 0.0f; + weight_ptr[2] = 2.0f; + weight_ptr[3] = 0.0f; + weight_ptr[4] = 0.0f; + weight_ptr[5] = 0.0f; + return router; + }; + + const float input_data[] = {1.0f, 1.0f}; + auto input = std::make_shared(input_data, std::vector{1, 1, 2}, DataType::kFLOAT32); + + auto softmax_router = make_router(nn::MoEConfig::RouterScoreFunction::kSoftmax, false); + auto softmax_output = (*softmax_router)({input}); + const auto *softmax_probs = static_cast(softmax_output[0]->DataPtr()); + EXPECT_NEAR(softmax_probs[0] + softmax_probs[1] + softmax_probs[2], 1.0f, 1e-5f); + EXPECT_GT(softmax_probs[1], softmax_probs[0]); + EXPECT_FLOAT_EQ(softmax_probs[2], 0.0f); + + auto sigmoid_router = make_router(nn::MoEConfig::RouterScoreFunction::kSigmoid, true); + auto sigmoid_output = (*sigmoid_router)({input}); + const auto *sigmoid_probs = static_cast(sigmoid_output[0]->DataPtr()); + EXPECT_NEAR(sigmoid_probs[0] + sigmoid_probs[1] + sigmoid_probs[2], 1.0f, 1e-5f); + EXPECT_GT(sigmoid_probs[1], sigmoid_probs[0]); + EXPECT_FLOAT_EQ(sigmoid_probs[2], 0.0f); +} + INFINI_TRAIN_REGISTER_TEST(TransformerModuleTest); From 36565152fcf53883aca7f32fcb02910fdc83eae6 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 27 May 2026 02:51:20 +0000 Subject: [PATCH 5/9] refactor: rename topk_mask to topk and align with torch.topk API --- infini_train/include/autograd/topk.h | 40 ++++++ infini_train/include/autograd/topk_mask.h | 29 ---- infini_train/src/autograd/topk.cc | 39 ++++++ infini_train/src/autograd/topk_mask.cc | 32 ----- infini_train/src/kernels/cpu/topk.cc | 124 +++++++++++++++++ infini_train/src/kernels/cpu/topk_mask.cc | 88 ------------ infini_train/src/kernels/cuda/topk.cu | 155 +++++++++++++++++++++ infini_train/src/kernels/cuda/topk_mask.cu | 118 ---------------- 8 files changed, 358 insertions(+), 267 deletions(-) create mode 100644 infini_train/include/autograd/topk.h delete mode 100644 infini_train/include/autograd/topk_mask.h create mode 100644 infini_train/src/autograd/topk.cc delete mode 100644 infini_train/src/autograd/topk_mask.cc create mode 100644 infini_train/src/kernels/cpu/topk.cc delete mode 100644 infini_train/src/kernels/cpu/topk_mask.cc create mode 100644 infini_train/src/kernels/cuda/topk.cu delete mode 100644 infini_train/src/kernels/cuda/topk_mask.cu diff --git a/infini_train/include/autograd/topk.h b/infini_train/include/autograd/topk.h new file mode 100644 index 00000000..7752efca --- /dev/null +++ b/infini_train/include/autograd/topk.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +// FIXME(dcj): Align this API with torch.topk and return both values and indices from Forward once +// InfiniTrain autograd supports marking individual outputs as non-differentiable. Today indices +// are exposed through TopIndices() to avoid waiting for gradients on metadata outputs. +class TopK : public Function { +public: + static constexpr char kType[] = "TopKFunction"; + + explicit TopK(int64_t topk, int64_t dim = -1, bool largest = true, bool sorted = true) + : Function(kType), topk_(topk), dim_(dim), largest_(largest), sorted_(sorted) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + + std::shared_ptr TopIndices() const; + +private: + int64_t topk_ = 1; + int64_t dim_ = -1; + bool largest_ = true; + bool sorted_ = true; + std::shared_ptr top_indices_; + std::vector input_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/topk_mask.h b/infini_train/include/autograd/topk_mask.h deleted file mode 100644 index 355ef400..00000000 --- a/infini_train/include/autograd/topk_mask.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include -#include - -#include "infini_train/include/autograd/function.h" - -namespace infini_train { -class Tensor; -} - -namespace infini_train::autograd { - -class TopKMask : public Function { -public: - static constexpr char kType[] = "TopKMaskFunction"; - - explicit TopKMask(int64_t topk) : Function(kType), topk_(topk) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - int64_t topk_ = 1; -}; - -} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/topk.cc b/infini_train/src/autograd/topk.cc new file mode 100644 index 00000000..4e0420b8 --- /dev/null +++ b/infini_train/src/autograd/topk.cc @@ -0,0 +1,39 @@ +#include "infini_train/include/autograd/topk.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> TopK::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + CHECK_GT(topk_, 0); + const auto &input = input_tensors[0]; + auto device = input->GetDevice().type(); + auto topk_outputs = Dispatcher::Instance().Call>>( + {device, "TopKForward"}, input, topk_, dim_, largest_, sorted_); + CHECK_EQ(topk_outputs.size(), 2); + top_indices_ = topk_outputs[1]; + return {topk_outputs[0]}; +} + +void TopK::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + input_dims_ = input_tensors[0]->Dims(); + saved_tensors_ = {top_indices_}; +} + +std::vector> TopK::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &top_grad = grad_outputs[0]; + const auto &top_indices = saved_tensors_[0]; + auto device = top_grad->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "TopKBackward"}, top_grad, top_indices, + input_dims_, dim_)}; +} + +std::shared_ptr TopK::TopIndices() const { return top_indices_; } + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/topk_mask.cc b/infini_train/src/autograd/topk_mask.cc deleted file mode 100644 index 16dc6629..00000000 --- a/infini_train/src/autograd/topk_mask.cc +++ /dev/null @@ -1,32 +0,0 @@ -#include "infini_train/include/autograd/topk_mask.h" - -#include "glog/logging.h" - -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/tensor.h" - -namespace infini_train::autograd { - -std::vector> TopKMask::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - CHECK_GT(topk_, 0); - const auto &input = input_tensors[0]; - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "TopKMaskForward"}, input, topk_)}; -} - -void TopKMask::SetupContext(const std::vector> &, - const std::vector> &output_tensors) { - saved_tensors_ = {output_tensors[0]}; -} - -std::vector> TopKMask::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(grad_outputs.size(), 1); - const auto &grad_output = grad_outputs[0]; - const auto &mask_values = saved_tensors_[0]; - auto device = grad_output->GetDevice().type(); - return { - Dispatcher::Instance().Call>({device, "TopKMaskBackward"}, grad_output, mask_values)}; -} - -} // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/topk.cc b/infini_train/src/kernels/cpu/topk.cc new file mode 100644 index 00000000..9e191143 --- /dev/null +++ b/infini_train/src/kernels/cpu/topk.cc @@ -0,0 +1,124 @@ +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::vector> TopKForward(const std::shared_ptr &input, int64_t topk, int64_t dim, + bool largest, bool sorted) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKForward currently supports float32 only"; + CHECK_GE(input->Dims().size(), 1); + (void)sorted; + + const auto &dims = input->Dims(); + if (dim < 0) { + dim += static_cast(dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(dims.size())); + + const int64_t dim_size = dims[dim]; + CHECK_GT(dim_size, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, dim_size); + + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < dims.size(); ++idx) { inner_size *= dims[idx]; } + + auto topk_dims = dims; + topk_dims[dim] = topk; + auto top_values = std::make_shared(topk_dims, input->Dtype(), input->GetDevice()); + auto top_indices = std::make_shared(topk_dims, DataType::kINT64, input->GetDevice()); + + const float *in = static_cast(input->DataPtr()); + float *values = static_cast(top_values->DataPtr()); + int64_t *indices = static_cast(top_indices->DataPtr()); + for (int64_t outer = 0; outer < outer_size; ++outer) { + for (int64_t inner = 0; inner < inner_size; ++inner) { + std::vector selected_indices(dim_size, false); + for (int64_t selected = 0; selected < topk; ++selected) { + int64_t best_idx = -1; + float best_value + = largest ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); + for (int64_t idx = 0; idx < dim_size; ++idx) { + if (selected_indices[idx]) { + continue; + } + const float value = in[outer * dim_size * inner_size + idx * inner_size + inner]; + const bool better = largest ? value > best_value : value < best_value; + if (better) { + best_value = value; + best_idx = idx; + } + } + CHECK_GE(best_idx, 0); + selected_indices[best_idx] = true; + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + values[out_offset] = best_value; + indices[out_offset] = best_idx; + } + } + } + + return {top_values, top_indices}; +} + +std::shared_ptr TopKBackward(const std::shared_ptr &grad_values, const std::shared_ptr &indices, + const std::vector &input_dims, int64_t dim) { + CHECK(indices->Dtype() == DataType::kINT64) << "CPU TopKBackward expects int64 indices"; + CHECK(grad_values->Dims() == indices->Dims()); + CHECK(!input_dims.empty()); + if (dim < 0) { + dim += static_cast(input_dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(input_dims.size())); + + const int64_t dim_size = input_dims[dim]; + const int64_t topk = indices->Dims()[dim]; + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= input_dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < input_dims.size(); ++idx) { inner_size *= input_dims[idx]; } + + auto grad_input = std::make_shared(input_dims, grad_values->Dtype(), grad_values->GetDevice()); + std::memset(grad_input->DataPtr(), 0, grad_input->SizeInBytes()); + + const size_t elem_size = kDataTypeToSize.at(grad_values->Dtype()); + const auto *src = static_cast(grad_values->DataPtr()); + auto *dst = static_cast(grad_input->DataPtr()); + const auto *idx_ptr = static_cast(indices->DataPtr()); + for (int64_t outer = 0; outer < outer_size; ++outer) { + for (int64_t inner = 0; inner < inner_size; ++inner) { + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + const int64_t selected_idx = idx_ptr[out_offset]; + CHECK_GE(selected_idx, 0); + CHECK_LT(selected_idx, dim_size); + std::memcpy(dst + (outer * dim_size * inner_size + selected_idx * inner_size + inner) * elem_size, + src + out_offset * elem_size, elem_size); + } + } + } + + return grad_input; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_TOPK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_TOPK_KERNEL(TopKForward) +REGISTER_CPU_TOPK_KERNEL(TopKBackward) + +#undef REGISTER_CPU_TOPK_KERNEL diff --git a/infini_train/src/kernels/cpu/topk_mask.cc b/infini_train/src/kernels/cpu/topk_mask.cc deleted file mode 100644 index 6a7191b9..00000000 --- a/infini_train/src/kernels/cpu/topk_mask.cc +++ /dev/null @@ -1,88 +0,0 @@ -#include -#include -#include - -#include "glog/logging.h" - -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/tensor.h" - -namespace infini_train::kernels::cpu { - -std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { - CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskForward currently supports float32 only"; - CHECK_GE(input->Dims().size(), 1); - - const auto &dims = input->Dims(); - const int64_t num_experts = dims.back(); - CHECK_GT(num_experts, 0); - CHECK_GT(topk, 0); - CHECK_LE(topk, num_experts); - const int64_t rows = input->NumElements() / num_experts; - - auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); - output->Fill(0.0f); - - const float *in = static_cast(input->DataPtr()); - float *out = static_cast(output->DataPtr()); - for (int64_t row = 0; row < rows; ++row) { - const int64_t row_offset = row * num_experts; - std::vector selected_experts(num_experts, false); - float selected_sum = 0.0f; - for (int64_t selected = 0; selected < topk; ++selected) { - int64_t best_idx = -1; - float best_value = -std::numeric_limits::infinity(); - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - if (selected_experts[expert_idx]) { - continue; - } - const float value = in[row_offset + expert_idx]; - if (value > best_value) { - best_value = value; - best_idx = expert_idx; - } - } - CHECK_GE(best_idx, 0); - selected_experts[best_idx] = true; - out[row_offset + best_idx] = best_value; - selected_sum += best_value; - } - if (topk > 1 && selected_sum != 0.0f) { - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - out[row_offset + expert_idx] - = out[row_offset + expert_idx] == 0.0f ? 0.0f : out[row_offset + expert_idx] / selected_sum; - } - } - } - - return output; -} - -std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, - const std::shared_ptr &mask_values) { - CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskBackward currently supports float32 only"; - CHECK(mask_values->Dtype() == DataType::kFLOAT32); - CHECK(grad_output->Dims() == mask_values->Dims()); - - auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); - grad_input->Fill(0.0f); - - const float *grad = static_cast(grad_output->DataPtr()); - const float *mask = static_cast(mask_values->DataPtr()); - float *out = static_cast(grad_input->DataPtr()); - for (int64_t i = 0; i < static_cast(grad_output->NumElements()); ++i) { - out[i] = mask[i] != 0.0f ? grad[i] : 0.0f; - } - - return grad_input; -} - -} // namespace infini_train::kernels::cpu - -#define REGISTER_CPU_TOPK_MASK_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) - -REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskForward) -REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskBackward) - -#undef REGISTER_CPU_TOPK_MASK_KERNEL diff --git a/infini_train/src/kernels/cuda/topk.cu b/infini_train/src/kernels/cuda/topk.cu new file mode 100644 index 00000000..32044c3f --- /dev/null +++ b/infini_train/src/kernels/cuda/topk.cu @@ -0,0 +1,155 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void TopKForwardKernel(const T *__restrict__ input, T *__restrict__ top_values, + int64_t *__restrict__ top_indices, int64_t rows, int64_t dim_size, int64_t inner_size, + int64_t topk, bool largest) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t outer = row / inner_size; + const int64_t inner = row % inner_size; + for (int64_t idx = 0; idx < dim_size; ++idx) { + const float value = static_cast(input[outer * dim_size * inner_size + idx * inner_size + inner]); + int64_t rank = 0; + for (int64_t other_idx = 0; other_idx < dim_size; ++other_idx) { + const float other_value + = static_cast(input[outer * dim_size * inner_size + other_idx * inner_size + inner]); + const bool ranks_before = largest ? (other_value > value || (other_value == value && other_idx < idx)) + : (other_value < value || (other_value == value && other_idx < idx)); + if (ranks_before) { + ++rank; + } + } + if (rank < topk) { + const int64_t out_offset = outer * topk * inner_size + rank * inner_size + inner; + top_values[out_offset] = input[outer * dim_size * inner_size + idx * inner_size + inner]; + top_indices[out_offset] = idx; + } + } +} + +std::vector> TopKForward(const std::shared_ptr &input, int64_t topk, int64_t dim, + bool largest, bool sorted) { + CHECK_GE(input->Dims().size(), 1); + (void)sorted; + const auto &dims = input->Dims(); + if (dim < 0) { + dim += static_cast(dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(dims.size())); + + const int64_t dim_size = dims[dim]; + CHECK_GT(dim_size, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, dim_size); + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < dims.size(); ++idx) { inner_size *= dims[idx]; } + const int64_t rows = outer_size * inner_size; + + auto topk_dims = dims; + topk_dims[dim] = topk; + auto top_values = std::make_shared(topk_dims, input->Dtype(), input->GetDevice()); + auto top_indices = std::make_shared(topk_dims, DataType::kINT64, input->GetDevice()); + + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + input->Dtype(), + [=]() { + TopKForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(top_values->DataPtr()), + static_cast(top_indices->DataPtr()), rows, dim_size, inner_size, topk, largest); + }, + "CUDA TopKForward"); + + return {top_values, top_indices}; +} + +template +__global__ void TopKBackwardKernel(const T *__restrict__ grad_values, const int64_t *__restrict__ indices, + T *__restrict__ grad_input, int64_t rows, int64_t dim_size, int64_t inner_size, + int64_t topk) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t outer = row / inner_size; + const int64_t inner = row % inner_size; + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + const int64_t selected_idx = indices[out_offset]; + grad_input[outer * dim_size * inner_size + selected_idx * inner_size + inner] = grad_values[out_offset]; + } +} + +std::shared_ptr TopKBackward(const std::shared_ptr &grad_values, const std::shared_ptr &indices, + const std::vector &input_dims, int64_t dim) { + CHECK(indices->Dtype() == DataType::kINT64) << "CUDA TopKBackward expects int64 indices"; + CHECK(grad_values->Dims() == indices->Dims()); + CHECK(!input_dims.empty()); + if (dim < 0) { + dim += static_cast(input_dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(input_dims.size())); + + const int64_t dim_size = input_dims[dim]; + const int64_t topk = indices->Dims()[dim]; + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= input_dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < input_dims.size(); ++idx) { inner_size *= input_dims[idx]; } + const int64_t rows = outer_size * inner_size; + + auto grad_input = std::make_shared(input_dims, grad_values->Dtype(), grad_values->GetDevice()); + auto device = grad_values->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + CUDA_CHECK(cudaMemsetAsync(grad_input->DataPtr(), 0, grad_input->SizeInBytes(), stream)); + + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + core::cuda::DispatchCudaFunc( + grad_values->Dtype(), + [=]() { + TopKBackwardKernel<<>>( + static_cast(grad_values->DataPtr()), static_cast(indices->DataPtr()), + static_cast(grad_input->DataPtr()), rows, dim_size, inner_size, topk); + }, + "CUDA TopKBackward"); + + return grad_input; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_TOPK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_TOPK_KERNEL(TopKForward) +REGISTER_CUDA_TOPK_KERNEL(TopKBackward) + +#undef REGISTER_CUDA_TOPK_KERNEL diff --git a/infini_train/src/kernels/cuda/topk_mask.cu b/infini_train/src/kernels/cuda/topk_mask.cu deleted file mode 100644 index e38c793e..00000000 --- a/infini_train/src/kernels/cuda/topk_mask.cu +++ /dev/null @@ -1,118 +0,0 @@ -#include "glog/logging.h" - -#include "infini_train/include/common/cuda/common_cuda.h" -#include "infini_train/include/core/runtime/device_guard.h" -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/tensor.h" - -#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" -#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" - -namespace infini_train::kernels::cuda { - -template -__global__ void TopKMaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, - int64_t num_experts, int64_t topk) { - int64_t row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= rows) { - return; - } - - const int64_t offset = row * num_experts; - float selected_sum = 0.0f; - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - const float value = static_cast(input[offset + expert_idx]); - int64_t rank = 0; - for (int64_t other_idx = 0; other_idx < num_experts; ++other_idx) { - const float other_value = static_cast(input[offset + other_idx]); - if (other_value > value || (other_value == value && other_idx < expert_idx)) { - ++rank; - } - } - const bool selected = rank < topk; - output[offset + expert_idx] = selected ? input[offset + expert_idx] : T(0.0f); - selected_sum += selected ? value : 0.0f; - } - if (topk > 1 && selected_sum != 0.0f) { - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - if (static_cast(output[offset + expert_idx]) != 0.0f) { - output[offset + expert_idx] = T(static_cast(output[offset + expert_idx]) / selected_sum); - } - } - } -} - -std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { - CHECK_GE(input->Dims().size(), 1); - const auto &dims = input->Dims(); - const int64_t num_experts = dims.back(); - CHECK_GT(num_experts, 0); - CHECK_GT(topk, 0); - CHECK_LE(topk, num_experts); - const int64_t rows = input->NumElements() / num_experts; - - auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); - - auto device = input->GetDevice(); - const auto &stream = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) - ->cuda_stream(); - const int threads = 256; - const int blocks = static_cast((rows + threads - 1) / threads); - - core::cuda::DispatchCudaFunc( - input->Dtype(), - [=]() { - TopKMaskForwardKernel<<>>( - static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts, topk); - }, - "CUDA TopKMaskForward"); - - return output; -} - -template -__global__ void TopKMaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, - T *__restrict__ grad_input, int64_t total_elements) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_elements) { - return; - } - grad_input[idx] = static_cast(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f); -} - -std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, - const std::shared_ptr &mask_values) { - CHECK(grad_output->Dims() == mask_values->Dims()); - CHECK(grad_output->Dtype() == mask_values->Dtype()); - auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); - - auto device = grad_output->GetDevice(); - const auto &stream = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) - ->cuda_stream(); - const int64_t total_elements = grad_output->NumElements(); - const int threads = 256; - const int blocks = static_cast((total_elements + threads - 1) / threads); - - core::cuda::DispatchCudaFunc( - grad_output->Dtype(), - [=]() { - TopKMaskBackwardKernel<<>>( - static_cast(grad_output->DataPtr()), static_cast(mask_values->DataPtr()), - static_cast(grad_input->DataPtr()), total_elements); - }, - "CUDA TopKMaskBackward"); - - return grad_input; -} - -} // namespace infini_train::kernels::cuda - -#define REGISTER_CUDA_TOPK_MASK_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) - -REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskForward) -REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskBackward) - -#undef REGISTER_CUDA_TOPK_MASK_KERNEL From 8d5c406f1f69e9e8bb85f8008f40901839e9b048 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 27 May 2026 09:28:27 +0000 Subject: [PATCH 6/9] refactor: refactor TopKRouter module to align with Megatron interface --- .../nn/modules/transformer/moe/moe_utils.h | 10 ++++ .../modules/transformer/transformer_config.h | 31 ++++++----- .../nn/modules/transformer/moe/moe_layer.cc | 9 ++-- .../nn/modules/transformer/moe/moe_utils.cc | 52 +++++++++++++++++++ .../src/nn/modules/transformer/moe/router.cc | 17 +++--- 5 files changed, 95 insertions(+), 24 deletions(-) diff --git a/infini_train/include/nn/modules/transformer/moe/moe_utils.h b/infini_train/include/nn/modules/transformer/moe/moe_utils.h index e0dd3744..6ce26f44 100644 --- a/infini_train/include/nn/modules/transformer/moe/moe_utils.h +++ b/infini_train/include/nn/modules/transformer/moe/moe_utils.h @@ -1,9 +1,19 @@ #pragma once +#include +#include +#include + #include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/tensor.h" namespace infini_train::nn::moe { +std::vector> TopkRoutingWithScoreFunction(const std::shared_ptr &logits, int64_t topk, + bool use_pre_softmax, + std::optional scaling_factor, + const MoEConfig::RouterScoreFunction &score_function); + const MoEConfig &RequireMoEConfig(const TransformerConfig &config); } // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 2072acb6..d2de5e74 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -25,30 +25,33 @@ enum class NormType { kRMSNorm // RMSNorm }; -enum class MoERouterType { - kTopK // Top-k router. -}; +struct MoEConfig { + enum class RouterScoreFunction { + kSoftmax, + kSigmoid, + }; -enum class MoEDispatcherType { - kLocal, // No cross-rank token exchange - kAllGather // Reserved for expert parallel MoE -}; + enum class DispatcherType { + kAllGather, // Megatron-style AllGather dispatcher. Degenerates to local dispatch when TP=EP=1. + kAllToAll // Megatron-style AllToAll dispatcher for expert parallel MoE. + }; -enum class MoEExpertImpl { - kSequential // Run local experts sequentially -}; + enum class ExpertImpl { + kSequential // Run local experts sequentially + }; -struct MoEConfig { int64_t num_experts = 0; int64_t expert_parallel_size = 1; int64_t router_topk = 1; + bool router_pre_softmax = false; + std::optional router_topk_scaling_factor = std::nullopt; + RouterScoreFunction router_score_function = RouterScoreFunction::kSoftmax; float aux_loss_coeff = 0.0f; std::optional expert_capacity_factor = std::nullopt; bool pad_expert_input_to_capacity = false; int64_t moe_ffn_hidden_size = 0; - MoERouterType router_type = MoERouterType::kTopK; - MoEDispatcherType dispatcher_type = MoEDispatcherType::kLocal; - MoEExpertImpl expert_impl = MoEExpertImpl::kSequential; + DispatcherType dispatcher_type = DispatcherType::kAllGather; + ExpertImpl expert_impl = ExpertImpl::kSequential; }; struct TransformerConfig { diff --git a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc index 8efd51c0..6add37ef 100644 --- a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc +++ b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc @@ -15,8 +15,8 @@ namespace infini_train::nn::moe { MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); CHECK(config_.ffn_type == FFNType::kMoE); - CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) - << "Current InfiniTrain MoE implementation supports local dispatch only"; + CHECK(moe_config.dispatcher_type == MoEConfig::DispatcherType::kAllGather) + << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; modules_[kRouterLayerName] = std::make_shared(config_); modules_[kExpertsLayerName] = std::make_shared(config_); @@ -25,8 +25,9 @@ MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), co std::vector> MoELayer::Forward(const std::vector> &input_tensors) { CHECK_EQ(input_tensors.size(), 1); auto hidden_states = input_tensors[0]; - auto routing_probs = (*modules_.at(kRouterLayerName))({hidden_states})[0]; - return (*modules_.at(kExpertsLayerName))({hidden_states, routing_probs}); + auto router_output = (*modules_.at(kRouterLayerName))({hidden_states}); + CHECK_EQ(router_output.size(), 2); + return (*modules_.at(kExpertsLayerName))({hidden_states, router_output[0], router_output[1]}); } } // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc index 80ef01c1..976e9eff 100644 --- a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc +++ b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc @@ -2,8 +2,60 @@ #include "glog/logging.h" +#include "infini_train/include/autograd/local_token_dispatcher.h" +#include "infini_train/include/autograd/scatter.h" +#include "infini_train/include/autograd/topk.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/nn/functional.h" + namespace infini_train::nn::moe { +std::vector> +TopkRoutingWithScoreFunction(const std::shared_ptr &logits, int64_t topk, bool use_pre_softmax, + std::optional scaling_factor, + const MoEConfig::RouterScoreFunction &score_function) { + + // Megatron TopKRouter returns dense tensors: + // routing_probs: [num_tokens, num_experts] + // routing_map: [num_tokens, num_experts], bool + std::shared_ptr top_probs; + std::shared_ptr top_indices; + + if (score_function == MoEConfig::RouterScoreFunction::kSoftmax) { + if (use_pre_softmax) { + auto scores = function::Softmax(logits, -1); + auto topk_function = std::make_shared(topk); + top_probs = topk_function->Apply({scores})[0]; + top_indices = topk_function->TopIndices(); + } else { + auto topk_function = std::make_shared(topk); + auto top_scores = topk_function->Apply({logits})[0]; + top_indices = topk_function->TopIndices(); + top_probs = function::Softmax(top_scores, -1); + } + } else if (score_function == MoEConfig::RouterScoreFunction::kSigmoid) { + auto sigmoid_scores = function::Sigmoid(logits); + auto topk_function = std::make_shared(topk); + top_probs = topk_function->Apply({sigmoid_scores})[0]; + top_indices = topk_function->TopIndices(); + if (topk > 1) { + top_probs = top_probs / (top_probs->Sum(-1, true) + 1e-20f); + } + } else { + LOG(FATAL) << "Unsupported MoE router score function"; + } + + if (scaling_factor.has_value()) { + top_probs = top_probs * scaling_factor.value(); + } + + auto routing_probs = std::make_shared(logits->Dims())->Apply({top_probs, top_indices})[0]; + auto routing_map_values = std::make_shared(top_indices->Equals(top_indices)->To(DataType::kBOOL)); + auto routing_map = Dispatcher::Instance().Call>( + {logits->GetDevice().type(), "ScatterForward"}, routing_map_values, top_indices, logits->Dims()); + return {routing_probs, routing_map}; +} + const MoEConfig &RequireMoEConfig(const TransformerConfig &config) { CHECK(config.moe_config.has_value()) << "MoE layer requires TransformerConfig::moe_config"; return config.moe_config.value(); diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc index 851c57be..25208684 100644 --- a/infini_train/src/nn/modules/transformer/moe/router.cc +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -6,7 +6,8 @@ #include "glog/logging.h" #include "infini_train/include/autograd/linear.h" -#include "infini_train/include/autograd/topk_mask.h" +#include "infini_train/include/autograd/scatter.h" +#include "infini_train/include/autograd/topk.h" #include "infini_train/include/nn/functional.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" @@ -16,11 +17,9 @@ namespace infini_train::nn::moe { TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); - CHECK(moe_config.router_type == MoERouterType::kTopK); CHECK_GT(moe_config.num_experts, 0); CHECK_GT(moe_config.router_topk, 0); CHECK_LE(moe_config.router_topk, moe_config.num_experts); - parameters_[kParamWeightName] = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, device_) @@ -43,10 +42,16 @@ std::vector> TopKRouter::Forward(const std::vector()->Apply(linear_inputs)[0]; - auto scores = function::Softmax(logits, -1); + const auto &moe_config = RequireMoEConfig(config_); - auto routing_probs = std::make_shared(moe_config.router_topk)->Apply({scores})[0]; - return {routing_probs}; + + auto routing_results + = TopkRoutingWithScoreFunction(logits, moe_config.router_topk, moe_config.router_pre_softmax, + moe_config.router_topk_scaling_factor, moe_config.router_score_function); + + auto routing_probs = routing_results[0]; + auto routing_map = routing_results[1]; + return {routing_probs, routing_map}; } } // namespace infini_train::nn::moe From dd7390b4b6357bbdd1e2839968b28e62db56ad48 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Fri, 29 May 2026 09:33:12 +0000 Subject: [PATCH 7/9] feat: implement MoETokenDispatcher base class and MoEAllGatherTokenDispatcher --- infini_train/include/autograd/scatter_add.h | 31 +++++ .../nn/modules/transformer/moe/moe_utils.h | 21 ++++ .../transformer/moe/token_dispatcher.h | 67 ++++++++++ infini_train/src/autograd/scatter_add.cc | 35 ++++++ infini_train/src/kernels/cpu/concat.cc | 16 +-- infini_train/src/kernels/cpu/transform.cc | 11 +- .../src/nn/modules/transformer/moe/experts.cc | 39 ++++-- .../nn/modules/transformer/moe/moe_utils.cc | 118 +++++++++++++++++- .../transformer/moe/token_dispatcher.cc | 95 ++++++++++++++ 9 files changed, 408 insertions(+), 25 deletions(-) create mode 100644 infini_train/include/autograd/scatter_add.h create mode 100644 infini_train/include/nn/modules/transformer/moe/token_dispatcher.h create mode 100644 infini_train/src/autograd/scatter_add.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc diff --git a/infini_train/include/autograd/scatter_add.h b/infini_train/include/autograd/scatter_add.h new file mode 100644 index 00000000..3adc1586 --- /dev/null +++ b/infini_train/include/autograd/scatter_add.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class ScatterAdd : public Function { +public: + static constexpr char kType[] = "ScatterAddFunction"; + + ScatterAdd(int64_t dim, const std::vector &output_dims) + : Function(kType), dim_(dim), output_dims_(output_dims) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + int64_t dim_ = 0; + std::vector output_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/transformer/moe/moe_utils.h b/infini_train/include/nn/modules/transformer/moe/moe_utils.h index 6ce26f44..f6941049 100644 --- a/infini_train/include/nn/modules/transformer/moe/moe_utils.h +++ b/infini_train/include/nn/modules/transformer/moe/moe_utils.h @@ -9,11 +9,32 @@ namespace infini_train::nn::moe { +struct PermutationMetadata { + std::shared_ptr sorted_indices; + std::shared_ptr gather_indices; + std::shared_ptr route_indices; + std::shared_ptr tokens_per_expert; + std::vector tokens_per_expert_host; +}; + +struct PermutationResult { + std::shared_ptr permuted_hidden_states; + std::shared_ptr permuted_probs; + PermutationMetadata metadata; +}; + std::vector> TopkRoutingWithScoreFunction(const std::shared_ptr &logits, int64_t topk, bool use_pre_softmax, std::optional scaling_factor, const MoEConfig::RouterScoreFunction &score_function); const MoEConfig &RequireMoEConfig(const TransformerConfig &config); +PermutationMetadata BuildPermutationMetadata(const std::shared_ptr &routing_map); +PermutationResult Permute(const std::shared_ptr &hidden_states_2d, + const std::shared_ptr &routing_probs_2d, + const std::shared_ptr &routing_map_2d); +std::shared_ptr Unpermute(const std::shared_ptr &permuted_hidden_states, + const std::shared_ptr &permuted_probs, const PermutationMetadata &metadata, + const std::vector &restore_shape); } // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/token_dispatcher.h b/infini_train/include/nn/modules/transformer/moe/token_dispatcher.h new file mode 100644 index 00000000..f9e3c614 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/token_dispatcher.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +class MoETokenDispatcher { +public: + virtual ~MoETokenDispatcher() = default; + + const PermutationResult &Dispatch(const std::shared_ptr &tokens, const std::shared_ptr &routing_map, + const std::shared_ptr &probs); + std::shared_ptr Combine(const std::shared_ptr &hidden_states) const; + +protected: + explicit MoETokenDispatcher(const TransformerConfig &config); + + virtual std::vector> DispatchPreprocess(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) + = 0; + virtual std::vector> TokenDispatch(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) const + = 0; + virtual const PermutationResult &DispatchPostprocess(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) + = 0; + virtual std::shared_ptr CombinePreprocess(const std::shared_ptr &hidden_states) const = 0; + virtual std::shared_ptr TokenCombine(const std::shared_ptr &hidden_states) const = 0; + virtual std::shared_ptr CombinePostprocess(const std::shared_ptr &hidden_states) const = 0; + + TransformerConfig config_; + PermutationResult dispatch_; + std::vector hidden_dims_; + std::shared_ptr routing_map_; + std::shared_ptr local_map_; + std::shared_ptr local_probs_; + int64_t num_tokens_ = 0; + int64_t hidden_size_ = 0; +}; + +class MoEAllGatherTokenDispatcher : public MoETokenDispatcher { +public: + MoEAllGatherTokenDispatcher(int64_t num_local_experts, const TransformerConfig &config); + +private: + std::vector> DispatchPreprocess(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) override; + std::vector> TokenDispatch(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) const override; + const PermutationResult &DispatchPostprocess(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) override; + std::shared_ptr CombinePreprocess(const std::shared_ptr &hidden_states) const override; + std::shared_ptr TokenCombine(const std::shared_ptr &hidden_states) const override; + std::shared_ptr CombinePostprocess(const std::shared_ptr &hidden_states) const override; + + int64_t num_local_experts_ = 0; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/autograd/scatter_add.cc b/infini_train/src/autograd/scatter_add.cc new file mode 100644 index 00000000..428f4f08 --- /dev/null +++ b/infini_train/src/autograd/scatter_add.cc @@ -0,0 +1,35 @@ +#include "infini_train/include/autograd/scatter_add.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> ScatterAdd::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + const auto &values = input_tensors[0]; + const auto &indices = input_tensors[1]; + auto device = values->GetDevice().type(); + auto output = Dispatcher::Instance().Call>({device, "GatherBackward"}, values, indices, + dim_, output_dims_); + return {output}; +} + +void ScatterAdd::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + saved_tensors_ = {input_tensors[1]}; +} + +std::vector> ScatterAdd::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &indices = saved_tensors_[0]; + auto device = grad_output->GetDevice().type(); + auto grad_values + = Dispatcher::Instance().Call>({device, "GatherForward"}, grad_output, indices, dim_); + return {grad_values, nullptr}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/concat.cc b/infini_train/src/kernels/cpu/concat.cc index b421063f..169cc557 100644 --- a/infini_train/src/kernels/cpu/concat.cc +++ b/infini_train/src/kernels/cpu/concat.cc @@ -1,7 +1,6 @@ -#include +#include #include #include -#include #include #include "glog/logging.h" @@ -42,23 +41,24 @@ std::shared_ptr ConcatForward(const std::vector> const int64_t K_total = std::accumulate(Ks.begin(), Ks.end(), int64_t{0}); output_dims[dim] = K_total; - auto output = std::make_shared(output_dims, DataType::kFLOAT32); + auto output = std::make_shared(output_dims, dtype, device); const int64_t outer_size = std::accumulate(output_dims.begin(), output_dims.begin() + dim, 1LL, std::multiplies()); const int64_t inner_size = std::accumulate(output_dims.begin() + dim + 1, output_dims.end(), 1LL, std::multiplies()); - const size_t elem_size = sizeof(float); + const size_t elem_size = kDataTypeToSize.at(dtype); - float *dst_ptr_base = static_cast(output->DataPtr()); + auto *dst_ptr_base = static_cast(output->DataPtr()); for (int64_t n = 0; n < outer_size; ++n) { int64_t offset_k = 0; - float *dst_block = dst_ptr_base + n * K_total * inner_size; + auto *dst_block = dst_ptr_base + n * K_total * inner_size * elem_size; for (size_t i = 0; i < inputs.size(); ++i) { const int64_t Ki = Ks[i]; - const float *src_ptr = static_cast(inputs[i]->DataPtr()) + n * Ki * inner_size; - float *dst_ptr = dst_block + offset_k * inner_size; + const auto *src_ptr + = static_cast(inputs[i]->DataPtr()) + n * Ki * inner_size * elem_size; + auto *dst_ptr = dst_block + offset_k * inner_size * elem_size; std::memcpy(dst_ptr, src_ptr, static_cast(Ki) * inner_size * elem_size); offset_k += Ki; } diff --git a/infini_train/src/kernels/cpu/transform.cc b/infini_train/src/kernels/cpu/transform.cc index 1a810b44..48063c7a 100644 --- a/infini_train/src/kernels/cpu/transform.cc +++ b/infini_train/src/kernels/cpu/transform.cc @@ -1,4 +1,6 @@ #include +#include +#include #include #include "glog/logging.h" @@ -167,14 +169,15 @@ std::shared_ptr RepeatInterleaveForward(const std::shared_ptr &i output_dims[dim] = dim_size * repeat; auto output = std::make_shared(output_dims, input->Dtype(), input->GetDevice()); - const float *input_ptr = static_cast(input->DataPtr()); - float *output_ptr = static_cast(output->DataPtr()); + const size_t elem_size = kDataTypeToSize.at(input->Dtype()); + const auto *input_ptr = static_cast(input->DataPtr()); + auto *output_ptr = static_cast(output->DataPtr()); for (int64_t o = 0; o < outer; ++o) { for (int64_t i = 0; i < dim_size; ++i) { for (int r = 0; r < repeat; ++r) { - std::memcpy(output_ptr + ((o * dim_size * repeat + i * repeat + r) * inner), - input_ptr + ((o * dim_size + i) * inner), sizeof(float) * inner); + std::memcpy(output_ptr + ((o * dim_size * repeat + i * repeat + r) * inner * elem_size), + input_ptr + ((o * dim_size + i) * inner * elem_size), elem_size * inner); } } } diff --git a/infini_train/src/nn/modules/transformer/moe/experts.cc b/infini_train/src/nn/modules/transformer/moe/experts.cc index 8f3b1be8..7566c48f 100644 --- a/infini_train/src/nn/modules/transformer/moe/experts.cc +++ b/infini_train/src/nn/modules/transformer/moe/experts.cc @@ -6,19 +6,21 @@ #include "glog/logging.h" +#include "infini_train/include/nn/functional.h" #include "infini_train/include/nn/modules/transformer/mlp.h" #include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/moe/token_dispatcher.h" #include "infini_train/include/tensor.h" namespace infini_train::nn::moe { SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); - CHECK(moe_config.expert_impl == MoEExpertImpl::kSequential); + CHECK(moe_config.expert_impl == MoEConfig::ExpertImpl::kSequential); CHECK_EQ(moe_config.expert_parallel_size, 1) << "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only"; - CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) - << "Current InfiniTrain MoE implementation supports local dispatch only"; + CHECK(moe_config.dispatcher_type == MoEConfig::DispatcherType::kAllGather) + << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; num_local_experts_ = moe_config.num_experts; CHECK_GT(num_local_experts_, 0); @@ -29,22 +31,35 @@ SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule( } std::vector> SequentialMLP::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 2); + CHECK_EQ(input_tensors.size(), 3); auto hidden_states = input_tensors[0]; auto routing_probs = input_tensors[1]; - CHECK_EQ(routing_probs->Dims().back(), num_local_experts_); + auto routing_map = input_tensors[2]; + std::unique_ptr dispatcher + = std::make_unique(num_local_experts_, config_); + const auto &dispatch = dispatcher->Dispatch(hidden_states, routing_map, routing_probs); - std::shared_ptr output = nullptr; - const int64_t expert_dim = static_cast(routing_probs->Dims().size()) - 1; + std::vector> expert_outputs; + int64_t start = 0; for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + const int64_t num_tokens_for_expert = dispatch.metadata.tokens_per_expert_host[expert_idx]; + const int64_t end = start + num_tokens_for_expert; + if (num_tokens_for_expert == 0) { + start = end; + continue; + } + + auto expert_input = dispatch.permuted_hidden_states->Slice(0, start, end); auto expert_name = std::string(kExpertNamePrefix) + std::to_string(expert_idx); - auto expert_output = (*modules_.at(expert_name))({hidden_states})[0]; - auto expert_prob = routing_probs->Slice(expert_dim, expert_idx, expert_idx + 1); - auto weighted_output = expert_output * expert_prob; - output = output == nullptr ? weighted_output : output + weighted_output; + expert_outputs.push_back((*modules_.at(expert_name))({expert_input})[0]); + start = end; } + CHECK_EQ(start, dispatch.permuted_hidden_states->Dims()[0]); + CHECK(!expert_outputs.empty()) << "No tokens were dispatched to any local expert"; - return {output}; + auto permuted_expert_output + = expert_outputs.size() == 1 ? expert_outputs[0] : nn::function::Concat(expert_outputs, 0); + return {dispatcher->Combine(permuted_expert_output)}; } } // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc index 976e9eff..040b29df 100644 --- a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc +++ b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc @@ -1,9 +1,11 @@ #include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include + #include "glog/logging.h" -#include "infini_train/include/autograd/local_token_dispatcher.h" #include "infini_train/include/autograd/scatter.h" +#include "infini_train/include/autograd/scatter_add.h" #include "infini_train/include/autograd/topk.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/nn/functional.h" @@ -61,4 +63,118 @@ const MoEConfig &RequireMoEConfig(const TransformerConfig &config) { return config.moe_config.value(); } +PermutationMetadata BuildPermutationMetadata(const std::shared_ptr &routing_map) { + CHECK(routing_map->Dtype() == DataType::kBOOL); + CHECK_EQ(routing_map->Dims().size(), 2); + + const int64_t num_tokens = routing_map->Dims()[0]; + const int64_t num_experts = routing_map->Dims()[1]; + CHECK_GT(num_tokens, 0); + CHECK_GT(num_experts, 0); + + Tensor routing_map_cpu_storage = routing_map->To(Device()); + auto routing_map_cpu = std::make_shared(routing_map_cpu_storage); + const auto *routing_map_ptr = static_cast(routing_map_cpu->DataPtr()); + + std::vector sorted_indices_host; + std::vector route_indices_host; + std::vector tokens_per_expert_host; + sorted_indices_host.reserve(routing_map->NumElements()); + route_indices_host.reserve(routing_map->NumElements()); + tokens_per_expert_host.reserve(num_experts); + + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + int64_t tokens_for_expert = 0; + for (int64_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + if (routing_map_ptr[token_idx * num_experts + expert_idx]) { + sorted_indices_host.push_back(token_idx); + route_indices_host.push_back(token_idx * num_experts + expert_idx); + ++tokens_for_expert; + } + } + tokens_per_expert_host.push_back(tokens_for_expert); + } + + const int64_t num_dispatched_tokens = static_cast(sorted_indices_host.size()); + auto sorted_indices_cpu + = std::make_shared(std::vector{num_dispatched_tokens}, DataType::kINT64, Device()); + auto route_indices_cpu + = std::make_shared(std::vector{num_dispatched_tokens}, DataType::kINT64, Device()); + auto gather_indices_cpu + = std::make_shared(std::vector{num_dispatched_tokens, 1}, DataType::kINT64, Device()); + auto tokens_per_expert_cpu + = std::make_shared(std::vector{num_experts}, DataType::kINT64, Device()); + + auto *sorted_indices_ptr = static_cast(sorted_indices_cpu->DataPtr()); + auto *route_indices_ptr = static_cast(route_indices_cpu->DataPtr()); + auto *gather_indices_ptr = static_cast(gather_indices_cpu->DataPtr()); + auto *tokens_per_expert_ptr = static_cast(tokens_per_expert_cpu->DataPtr()); + for (int64_t idx = 0; idx < num_dispatched_tokens; ++idx) { + sorted_indices_ptr[idx] = sorted_indices_host[idx]; + route_indices_ptr[idx] = route_indices_host[idx]; + gather_indices_ptr[idx] = sorted_indices_host[idx]; + } + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + tokens_per_expert_ptr[expert_idx] = tokens_per_expert_host[expert_idx]; + } + + auto to_device = [&](const std::shared_ptr &cpu_tensor) -> std::shared_ptr { + if (routing_map->GetDevice().type() == Device::DeviceType::kCPU) { + return cpu_tensor; + } + return std::make_shared(cpu_tensor->To(routing_map->GetDevice())); + }; + + return {to_device(sorted_indices_cpu), to_device(gather_indices_cpu), to_device(route_indices_cpu), + to_device(tokens_per_expert_cpu), tokens_per_expert_host}; +} + +PermutationResult Permute(const std::shared_ptr &hidden_states_2d, + const std::shared_ptr &routing_probs_2d, + const std::shared_ptr &routing_map_2d) { + CHECK_EQ(hidden_states_2d->Dims().size(), 2); + CHECK(routing_probs_2d->Dims() == routing_map_2d->Dims()); + CHECK(routing_map_2d->Dtype() == DataType::kBOOL); + + const int64_t hidden_size = hidden_states_2d->Dims()[1]; + auto metadata = BuildPermutationMetadata(routing_map_2d); + const int64_t num_dispatched_tokens = metadata.sorted_indices->Dims()[0]; + + std::shared_ptr permuted_hidden_states; + std::shared_ptr permuted_probs; + if (num_dispatched_tokens == 0) { + permuted_hidden_states = std::make_shared(std::vector{0, hidden_size}, + hidden_states_2d->Dtype(), hidden_states_2d->GetDevice()); + permuted_probs = std::make_shared(std::vector{0}, routing_probs_2d->Dtype(), + routing_probs_2d->GetDevice()); + } else { + auto gather_indices = metadata.gather_indices; + if (hidden_size != 1) { + gather_indices = metadata.gather_indices->RepeatInterleave(hidden_size, 1); + } + permuted_hidden_states = hidden_states_2d->Gather(0, gather_indices); + permuted_probs = routing_probs_2d->View({static_cast(routing_probs_2d->NumElements())}) + ->Gather(0, metadata.route_indices); + } + + return {permuted_hidden_states, permuted_probs, metadata}; +} + +std::shared_ptr Unpermute(const std::shared_ptr &permuted_hidden_states, + const std::shared_ptr &permuted_probs, const PermutationMetadata &metadata, + const std::vector &restore_shape) { + CHECK_EQ(permuted_hidden_states->Dims().size(), 2); + CHECK_EQ(permuted_probs->Dims().size(), 1); + CHECK_EQ(permuted_hidden_states->Dims()[0], permuted_probs->Dims()[0]); + CHECK_EQ(restore_shape.size(), 2); + + auto weighted = permuted_hidden_states * permuted_probs->View({permuted_probs->Dims()[0], 1}); + auto scatter_indices = metadata.gather_indices; + const int64_t hidden_size = restore_shape[1]; + if (hidden_size != 1) { + scatter_indices = metadata.gather_indices->RepeatInterleave(hidden_size, 1); + } + return std::make_shared(0, restore_shape)->Apply({weighted, scatter_indices})[0]; +} + } // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc b/infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc new file mode 100644 index 00000000..667dba8f --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc @@ -0,0 +1,95 @@ +#include "infini_train/include/nn/modules/transformer/moe/token_dispatcher.h" + +#include +#include + +#include "glog/logging.h" + +namespace infini_train::nn::moe { + +MoETokenDispatcher::MoETokenDispatcher(const TransformerConfig &config) : config_(config) {} + +const PermutationResult &MoETokenDispatcher::Dispatch(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) { + auto preprocessed = DispatchPreprocess(tokens, routing_map, probs); + auto dispatched = TokenDispatch(preprocessed[0], preprocessed[1]); + return DispatchPostprocess(dispatched[0], dispatched[1]); +} + +std::shared_ptr MoETokenDispatcher::Combine(const std::shared_ptr &hidden_states) const { + auto preprocessed = CombinePreprocess(hidden_states); + auto combined = TokenCombine(preprocessed); + return CombinePostprocess(combined); +} + +MoEAllGatherTokenDispatcher::MoEAllGatherTokenDispatcher(int64_t num_local_experts, const TransformerConfig &config) + : MoETokenDispatcher(config), num_local_experts_(num_local_experts) { + CHECK_GT(num_local_experts_, 0); +} + +std::vector> +MoEAllGatherTokenDispatcher::DispatchPreprocess(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) { + CHECK(probs->Dims() == routing_map->Dims()); + CHECK(routing_map->Dtype() == DataType::kBOOL); + CHECK_GE(tokens->Dims().size(), 2); + + hidden_dims_ = tokens->Dims(); + hidden_size_ = hidden_dims_.back(); + CHECK_GT(hidden_size_, 0); + num_tokens_ = tokens->NumElements() / hidden_size_; + CHECK_EQ(probs->Dims().back(), num_local_experts_); + CHECK_EQ(probs->NumElements(), static_cast(num_tokens_ * num_local_experts_)); + + routing_map_ = routing_map->View({num_tokens_, num_local_experts_}); + auto hidden_states_2d = tokens->View({num_tokens_, hidden_size_}); + auto probs_2d = probs->View({num_tokens_, num_local_experts_}); + return {hidden_states_2d, probs_2d}; +} + +std::vector> +MoEAllGatherTokenDispatcher::TokenDispatch(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) const { + // AllGather dispatcher will gather tokens across TP*EP ranks here. For the current single-rank + // path (tp_size=1, ep_size=1), no communication is required. + return {hidden_states, probs}; +} + +const PermutationResult &MoEAllGatherTokenDispatcher::DispatchPostprocess(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) { + CHECK(routing_map_ != nullptr); + CHECK_EQ(hidden_states->Dims().size(), 2); + CHECK_EQ(probs->Dims().size(), 2); + CHECK_EQ(hidden_states->Dims()[0], probs->Dims()[0]); + CHECK_EQ(probs->Dims()[1], num_local_experts_); + + // With ep_size=1 all experts are local, so the local expert map/probs are the gathered map/probs. + // Future EP support should slice [local_expert_start, local_expert_end) after AllGather. + local_map_ = routing_map_; + local_probs_ = probs; + dispatch_ = Permute(hidden_states, local_probs_, local_map_); + routing_map_ = nullptr; + return dispatch_; +} + +std::shared_ptr +MoEAllGatherTokenDispatcher::CombinePreprocess(const std::shared_ptr &hidden_states) const { + CHECK(local_map_ != nullptr); + CHECK(local_probs_ != nullptr); + return Unpermute(hidden_states, dispatch_.permuted_probs, dispatch_.metadata, + std::vector{num_tokens_, hidden_size_}); +} + +std::shared_ptr MoEAllGatherTokenDispatcher::TokenCombine(const std::shared_ptr &hidden_states) const { + // AllGather dispatcher will reduce-scatter combined token outputs here. For ep_size=1 this is a no-op. + return hidden_states; +} + +std::shared_ptr +MoEAllGatherTokenDispatcher::CombinePostprocess(const std::shared_ptr &hidden_states) const { + return hidden_states->View(hidden_dims_); +} + +} // namespace infini_train::nn::moe From ba7c33bac1c47bf19695d58692a467ac8f59416f Mon Sep 17 00:00:00 2001 From: kilinchange Date: Tue, 2 Jun 2026 07:46:09 +0000 Subject: [PATCH 8/9] feat: add tiny_mixtral example --- CMakeLists.txt | 8 + example/tiny_mixtral/checkpoint_loader.cc | 167 ++++++++++++++++++ example/tiny_mixtral/checkpoint_loader.h | 21 +++ example/tiny_mixtral/config.h | 76 ++++++++ example/tiny_mixtral/main.cc | 136 ++++++++++++++ .../modules/transformer/transformer_config.h | 4 +- .../src/nn/modules/transformer/moe/experts.cc | 2 +- .../nn/modules/transformer/moe/moe_layer.cc | 2 +- 8 files changed, 412 insertions(+), 4 deletions(-) create mode 100644 example/tiny_mixtral/checkpoint_loader.cc create mode 100644 example/tiny_mixtral/checkpoint_loader.h create mode 100644 example/tiny_mixtral/config.h create mode 100644 example/tiny_mixtral/main.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index e25de71d..cc3d5ee1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,6 +199,14 @@ add_executable(gpt2 ) link_infini_train_exe(gpt2) +add_executable(tiny_mixtral + example/tiny_mixtral/main.cc + example/common/tiny_shakespeare_dataset.cc + example/common/utils.cc + example/tiny_mixtral/checkpoint_loader.cc +) +link_infini_train_exe(tiny_mixtral) + add_executable(llama3 example/llama3/main.cc example/common/tiny_shakespeare_dataset.cc diff --git a/example/tiny_mixtral/checkpoint_loader.cc b/example/tiny_mixtral/checkpoint_loader.cc new file mode 100644 index 00000000..1e27ac53 --- /dev/null +++ b/example/tiny_mixtral/checkpoint_loader.cc @@ -0,0 +1,167 @@ +#include "example/tiny_mixtral/checkpoint_loader.h" + +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/datatype.h" +#include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/tensor.h" + +#include "example/common/utils.h" +#include "example/tiny_mixtral/config.h" + +namespace nn = infini_train::nn; + +namespace { + +constexpr int32_t kTinyMixtralLLMCMagic = 20260513; +constexpr int32_t kTinyMixtralLLMCVersion = 2; +constexpr int64_t kLLMCHeaderEntries = 256; + +} // namespace + +namespace tiny_mixtral { + +namespace { + +template +void CompareCheckpointValue(const std::string &name, const T &checkpoint_value, const T &runtime_value) { + CHECK_EQ(checkpoint_value, runtime_value) << name << " value from checkpoint (" << checkpoint_value + << ") is not equal to runtime config value (" << runtime_value << ")"; +} + +} // namespace + +nn::TransformerConfig ConfigFromLLMC(const std::string &filepath) { + std::ifstream ifs(filepath, std::ios::binary); + CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath; + const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs); + CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath; + CHECK_EQ(infini_train::BytesToType(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic); + CHECK_EQ(infini_train::BytesToType(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion); + + auto config = TinyMixtralConfig(); + config.block_size = infini_train::BytesToType(header, 2 * sizeof(int32_t)); + config.vocab_size = infini_train::BytesToType(header, 3 * sizeof(int32_t)); + config.original_vocab_size = config.vocab_size; + config.n_layer = infini_train::BytesToType(header, 4 * sizeof(int32_t)); + config.n_head = infini_train::BytesToType(header, 5 * sizeof(int32_t)); + config.n_kv_head = infini_train::BytesToType(header, 6 * sizeof(int32_t)); + config.n_embd = infini_train::BytesToType(header, 7 * sizeof(int32_t)); + config.ffn_expansion_ratio = infini_train::BytesToType(header, 9 * sizeof(int32_t)); + // Header slots 10 and 11 store dense-MLP helpers; MoE expert size is stored in moe_ffn_hidden_size. + config.norm_eps = infini_train::BytesToType(header, 12 * sizeof(int32_t)); + config.rope_theta = infini_train::BytesToType(header, 13 * sizeof(int32_t)); + config.use_scaled_rope = infini_train::BytesToType(header, 14 * sizeof(int32_t)) != 0; + + nn::MoEConfig moe_config; + moe_config.num_experts = infini_train::BytesToType(header, 8 * sizeof(int32_t)); + moe_config.expert_parallel_size = 1; + moe_config.router_topk = infini_train::BytesToType(header, 15 * sizeof(int32_t)); + moe_config.moe_ffn_hidden_size = infini_train::BytesToType(header, 16 * sizeof(int32_t)); + moe_config.token_dispatcher_type = nn::MoEConfig::TokenDispatcherType::kAllGather; + moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential; + config.moe_config = moe_config; + SanitizeTinyMixtralConfig(config); + return config; +} + +void CheckLLMCConfig(const std::string &filepath, const nn::TransformerConfig &expected_config) { + SanitizeTinyMixtralConfig(expected_config); + const auto checkpoint_config = ConfigFromLLMC(filepath); + CompareCheckpointValue("block_size", checkpoint_config.block_size, expected_config.block_size); + CompareCheckpointValue("vocab_size", checkpoint_config.vocab_size, expected_config.vocab_size); + CompareCheckpointValue("original_vocab_size", checkpoint_config.original_vocab_size, + expected_config.original_vocab_size); + CompareCheckpointValue("n_layer", checkpoint_config.n_layer, expected_config.n_layer); + CompareCheckpointValue("n_head", checkpoint_config.n_head, expected_config.n_head); + CompareCheckpointValue("n_kv_head", checkpoint_config.n_kv_head, expected_config.n_kv_head); + CompareCheckpointValue("n_embd", checkpoint_config.n_embd, expected_config.n_embd); + CompareCheckpointValue("ffn_expansion_ratio", checkpoint_config.ffn_expansion_ratio, + expected_config.ffn_expansion_ratio); + CompareCheckpointValue("norm_eps", checkpoint_config.norm_eps, expected_config.norm_eps); + CompareCheckpointValue("rope_theta", checkpoint_config.rope_theta, expected_config.rope_theta); + CompareCheckpointValue("use_scaled_rope", checkpoint_config.use_scaled_rope, expected_config.use_scaled_rope); + + CHECK(expected_config.moe_config.has_value()) << "tiny Mixtral runtime config requires MoE config"; + const auto &checkpoint_moe = checkpoint_config.moe_config.value(); + const auto &expected_moe = expected_config.moe_config.value(); + CompareCheckpointValue("num_experts", checkpoint_moe.num_experts, expected_moe.num_experts); + CompareCheckpointValue("router_topk", checkpoint_moe.router_topk, expected_moe.router_topk); + CompareCheckpointValue("moe_ffn_hidden_size", checkpoint_moe.moe_ffn_hidden_size, expected_moe.moe_ffn_hidden_size); +} + +std::shared_ptr LoadFromLLMC(const std::string &filepath, + const nn::TransformerConfig &expected_config) { + CheckLLMCConfig(filepath, expected_config); + auto model = std::make_shared(expected_config); + + std::ifstream ifs(filepath, std::ios::binary); + CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath; + const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs); + CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath; + CHECK_EQ(infini_train::BytesToType(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic); + CHECK_EQ(infini_train::BytesToType(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion); + + const auto &config = expected_config; + auto state = model->StateDict(); + auto read_tensor_by_state_key = [&](const std::string &name) { + CHECK(state.contains(name)) << "Model state_dict does not contain " << name; + std::shared_ptr tensor = state.at(name); + CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32) + << "Only float32 tiny Mixtral LLMC files are supported: " << name; + infini_train::ReadMatrixAllFloat(ifs, static_cast(tensor->DataPtr()), tensor->NumElements(), 1); + CHECK(ifs) << "Failed to read tensor " << name; + }; + + auto read_projection_into_packed_qkv = [&](const std::string &packed_qkv_name, int64_t row_offset, int64_t num_rows, + const std::string &projection_name) { + CHECK(state.contains(packed_qkv_name)) << "Model state_dict does not contain " << packed_qkv_name; + std::shared_ptr tensor = state.at(packed_qkv_name); + CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32) + << "Only float32 tiny Mixtral LLMC files are supported: " << projection_name; + CHECK_EQ(tensor->Dims().size(), 2); + CHECK_GE(row_offset, 0); + CHECK_GT(num_rows, 0); + CHECK_LE(row_offset + num_rows, tensor->Dims()[0]); + const int64_t cols = tensor->Dims()[1]; + auto *data = static_cast(tensor->DataPtr()) + row_offset * cols; + infini_train::ReadMatrixAllFloat(ifs, data, num_rows, cols); + CHECK(ifs) << "Failed to read tensor rows " << projection_name; + }; + + const auto &moe_config = config.moe_config.value(); + read_tensor_by_state_key("transformer.wte.weight"); + for (int64_t layer = 0; layer < config.n_layer; ++layer) { + const std::string prefix = "transformer.h." + std::to_string(layer); + read_tensor_by_state_key(prefix + ".ln_1.weight"); + const auto c_attn_name = prefix + ".attn.c_attn.weight"; + const int64_t head_dim = config.n_embd / config.n_head; + const int64_t q_rows = config.n_head * head_dim; + const int64_t kv_rows = config.n_kv_head * head_dim; + read_projection_into_packed_qkv(c_attn_name, 0, q_rows, c_attn_name + ".q_proj"); + read_projection_into_packed_qkv(c_attn_name, q_rows, kv_rows, c_attn_name + ".k_proj"); + read_projection_into_packed_qkv(c_attn_name, q_rows + kv_rows, kv_rows, c_attn_name + ".v_proj"); + read_tensor_by_state_key(prefix + ".attn.c_proj.weight"); + read_tensor_by_state_key(prefix + ".ln_2.weight"); + read_tensor_by_state_key(prefix + ".mlp.router.weight"); + for (int64_t expert = 0; expert < moe_config.num_experts; ++expert) { + const std::string expert_prefix = prefix + ".mlp.experts.expert_" + std::to_string(expert); + read_tensor_by_state_key(expert_prefix + ".c_fc2.weight"); // Mixtral w1/gate_proj + read_tensor_by_state_key(expert_prefix + ".c_fc.weight"); // Mixtral w3/up_proj + read_tensor_by_state_key(expert_prefix + ".c_proj.weight"); // Mixtral w2/down_proj + } + } + read_tensor_by_state_key("transformer.ln_f.weight"); + read_tensor_by_state_key("lm_head.weight"); + + CHECK_EQ(ifs.peek(), std::ifstream::traits_type::eof()) << "Unexpected trailing bytes in tiny Mixtral LLMC file"; + return model; +} + +} // namespace tiny_mixtral diff --git a/example/tiny_mixtral/checkpoint_loader.h b/example/tiny_mixtral/checkpoint_loader.h new file mode 100644 index 00000000..738538ad --- /dev/null +++ b/example/tiny_mixtral/checkpoint_loader.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn { +class TransformerModel; +} // namespace infini_train::nn + +namespace tiny_mixtral { + +infini_train::nn::TransformerConfig ConfigFromLLMC(const std::string &filepath); + +void CheckLLMCConfig(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config); + +std::shared_ptr +LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config); + +} // namespace tiny_mixtral diff --git a/example/tiny_mixtral/config.h b/example/tiny_mixtral/config.h new file mode 100644 index 00000000..0d7096d4 --- /dev/null +++ b/example/tiny_mixtral/config.h @@ -0,0 +1,76 @@ +#pragma once + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace nn = infini_train::nn; + +namespace tiny_mixtral { + +inline nn::TransformerConfig TinyMixtralConfig() { + nn::TransformerConfig config; + config.block_size = 32768; // Same as Mixtral/Megatron --max-position-embeddings. + config.vocab_size = 128256; // Validation data uses LLaMA3 token ids; real Mixtral uses 32000. + config.original_vocab_size = 128256; + config.n_layer = 2; // Tiny scale; Megatron --num-layers 32. + config.n_head = 4; // Tiny scale; preserves the Megatron 4:1 GQA ratio. + config.n_kv_head = 1; // Tiny scale; Megatron --num-query-groups 8. + config.n_embd = 32; // Tiny scale; Megatron --hidden-size 4096. + config.attention_type = nn::AttentionType::kRoPE; + config.activation_type = nn::MLPType::kSwiGLU; + config.ffn_type = nn::FFNType::kMoE; + config.norm_type = nn::NormType::kRMSNorm; + config.add_bias_linear = false; + config.add_bias_lm_head = false; + config.tie_weights = false; + config.ffn_expansion_ratio = 3.5f; + config.norm_eps = 1e-5f; + config.rope_theta = 1000000.0f; + config.use_scaled_rope = false; + + nn::MoEConfig moe_config; + moe_config.num_experts = 8; + moe_config.expert_parallel_size = 1; // Single-rank validation scale. + moe_config.router_topk = 2; + moe_config.moe_ffn_hidden_size = 112; // Tiny scale; Megatron --ffn-hidden-size 14336. + moe_config.token_dispatcher_type = nn::MoEConfig::TokenDispatcherType::kAllGather; // Single-rank validation path. + moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential; // Local correctness path. + config.moe_config = moe_config; + return config; +} + +inline void SanitizeTinyMixtralConfig(const nn::TransformerConfig &c) { + CHECK_GT(c.block_size, 0); + CHECK_GT(c.vocab_size, 0); + CHECK_GE(c.vocab_size, c.original_vocab_size); + CHECK_GT(c.n_layer, 0); + CHECK_GT(c.n_head, 0); + CHECK_GT(c.n_kv_head, 0); + CHECK_LE(c.n_kv_head, c.n_head); + CHECK_EQ(c.n_head % c.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA"; + CHECK_GT(c.n_embd, 0); + CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head"; + CHECK(c.attention_type == nn::AttentionType::kRoPE) << "tiny Mixtral requires RoPE attention"; + CHECK(c.activation_type == nn::MLPType::kSwiGLU) << "tiny Mixtral requires SwiGLU activation"; + CHECK(c.ffn_type == nn::FFNType::kMoE) << "tiny Mixtral requires MoE FFN"; + CHECK(c.norm_type == nn::NormType::kRMSNorm) << "tiny Mixtral requires RMSNorm"; + CHECK(!c.add_bias_linear) << "tiny Mixtral has no bias in linear layers"; + CHECK(!c.add_bias_lm_head) << "tiny Mixtral has no bias in lm_head"; + CHECK(!c.tie_weights) << "tiny Mixtral does not tie embedding and lm_head weights"; + CHECK(!c.use_scaled_rope) << "tiny Mixtral precision validation keeps scaled RoPE disabled"; + CHECK(c.moe_config.has_value()) << "tiny Mixtral requires MoE config"; + + const auto &moe = c.moe_config.value(); + CHECK_GT(moe.num_experts, 0); + CHECK_EQ(moe.expert_parallel_size, 1) << "tiny Mixtral single-rank validation expects EP=1"; + CHECK_GT(moe.router_topk, 0); + CHECK_LE(moe.router_topk, moe.num_experts); + CHECK_GT(moe.moe_ffn_hidden_size, 0); + CHECK(moe.token_dispatcher_type == nn::MoEConfig::TokenDispatcherType::kAllGather) + << "tiny Mixtral uses the Megatron-style AllGather dispatcher"; + CHECK(moe.expert_impl == nn::MoEConfig::ExpertImpl::kSequential) + << "tiny Mixtral validation uses SequentialMLP experts"; +} + +} // namespace tiny_mixtral diff --git a/example/tiny_mixtral/main.cc b/example/tiny_mixtral/main.cc new file mode 100644 index 00000000..dc2b5136 --- /dev/null +++ b/example/tiny_mixtral/main.cc @@ -0,0 +1,136 @@ +#include +#include +#include +#include +#include +#include + +#include "gflags/gflags.h" +#include "glog/logging.h" + +#include "example/common/tiny_shakespeare_dataset.h" +#include "example/tiny_mixtral/checkpoint_loader.h" +#include "example/tiny_mixtral/config.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dataloader.h" +#include "infini_train/include/device.h" +#include "infini_train/include/nn/modules/loss.h" +#include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +DEFINE_string(input_bin, "", "input .bin to train on"); +DEFINE_uint32(batch_size, 4, "batch size"); +DEFINE_uint32(sequence_length, 64, "sequence length"); +DEFINE_uint32(num_iteration, 10, "number of training iterations"); +DEFINE_double(learning_rate, 1e-4, "SGD learning rate"); +DEFINE_string(llmc_filepath, "", + "optional PyTorch-generated tiny Mixtral LLMC model file path to load before training"); +DEFINE_string(device, "cpu", "Training device: cpu or cuda."); +DEFINE_uint32(log_interval, 1, "Print train loss every N steps. 0 disables step loss logging."); +DEFINE_bool(print_timing, false, "Print training-loop elapsed time and token throughput."); + +namespace { + +using infini_train::Device; +using infini_train::Tensor; + +void ValidateRuntimeFlags(const infini_train::nn::TransformerConfig &config) { + CHECK(!FLAGS_input_bin.empty()) << "tiny Mixtral training requires --input_bin"; + CHECK_GT(FLAGS_batch_size, 0); + CHECK_GT(FLAGS_sequence_length, 0); + CHECK_LE(FLAGS_sequence_length, config.block_size) << "sequence_length must be <= model max positions (block_size)"; +} + +} // namespace + +int main(int argc, char *argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + google::InitGoogleLogging(argv[0]); + + infini_train::nn::parallel::global::InitAllEnv( + /*nthread_per_process=*/1, + /*tensor_parallel_size=*/1, + /*sequence_parallel_enabled=*/false, + /*pipeline_parallel_size=*/1, + /*virtual_pipeline_parallel_size=*/1); + + infini_train::nn::TransformerConfig model_config = tiny_mixtral::TinyMixtralConfig(); + tiny_mixtral::SanitizeTinyMixtralConfig(model_config); + std::shared_ptr model = nullptr; + if (!FLAGS_llmc_filepath.empty()) { + model = tiny_mixtral::LoadFromLLMC(FLAGS_llmc_filepath, model_config); + } else { + model = std::make_shared(model_config); + } + ValidateRuntimeFlags(model_config); + + Device train_device; + if (FLAGS_device == "cuda") { + train_device = Device(Device::DeviceType::kCUDA, 0); + model->To(train_device); + } else { + CHECK_EQ(FLAGS_device, "cpu") << "Unsupported training device: " << FLAGS_device; + train_device = Device(); + } + + infini_train::DistributedDataLoader train_loader( + std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size, + /*ddp_rank=*/0, /*ddp_world_size=*/1); + auto train_iter = train_loader.begin(); + + auto loss_fn = std::make_shared(); + auto optimizer + = infini_train::optimizers::SGD::Create(static_cast(FLAGS_learning_rate))(model->Parameters()); + + auto device_impl = infini_train::core::GetDeviceGuardImpl(train_device.type()); + std::vector step_duration_ms; + step_duration_ms.reserve(FLAGS_num_iteration); + const double tokens_per_step = static_cast(FLAGS_batch_size) * FLAGS_sequence_length; + for (uint32_t step = 0; step < FLAGS_num_iteration; ++step) { + device_impl->SynchronizeDevice(train_device); + const auto step_start_time = std::chrono::steady_clock::now(); + + optimizer->ZeroGrad(); + if (train_iter == train_loader.end()) { + train_iter = train_loader.begin(); + } + auto [x_cpu, y_cpu] = *train_iter; + ++train_iter; + auto x = std::make_shared(x_cpu->To(train_device)); + auto y = std::make_shared(y_cpu->To(train_device)); + auto logits = (*model)({x})[0]; + auto loss = (*loss_fn)({logits, y})[0]; + loss->Backward(); + optimizer->Step(); + + device_impl->SynchronizeDevice(train_device); + const auto step_end_time = std::chrono::steady_clock::now(); + const double duration_ms = std::chrono::duration(step_end_time - step_start_time).count(); + step_duration_ms.push_back(duration_ms); + + if (FLAGS_log_interval > 0 && ((step + 1) % FLAGS_log_interval == 0 || step + 1 == FLAGS_num_iteration)) { + auto loss_cpu = loss->To(Device()); + const float lossf = static_cast(loss_cpu.DataPtr())[0]; + std::cout << std::format( + "step {:4d}/{} | train loss {:.6f} | norm -1.0000 | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s)", step + 1, + FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_ms, tokens_per_step / (duration_ms / 1e3)) + << std::endl; + } + } + if (!step_duration_ms.empty()) { + double duration_sum_ms = 0.0; + for (size_t idx = step_duration_ms.size() > 1 ? 1 : 0; idx < step_duration_ms.size(); ++idx) { + duration_sum_ms += step_duration_ms[idx]; + } + const size_t averaged_steps + = step_duration_ms.size() > 1 ? step_duration_ms.size() - 1 : step_duration_ms.size(); + std::cout << std::format("final {} iters avg: {:.3f}ms", averaged_steps, duration_sum_ms / averaged_steps) + << std::endl; + } + + gflags::ShutDownCommandLineFlags(); + google::ShutdownGoogleLogging(); + return 0; +} diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index d2de5e74..713ce58f 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -31,7 +31,7 @@ struct MoEConfig { kSigmoid, }; - enum class DispatcherType { + enum class TokenDispatcherType { kAllGather, // Megatron-style AllGather dispatcher. Degenerates to local dispatch when TP=EP=1. kAllToAll // Megatron-style AllToAll dispatcher for expert parallel MoE. }; @@ -50,7 +50,7 @@ struct MoEConfig { std::optional expert_capacity_factor = std::nullopt; bool pad_expert_input_to_capacity = false; int64_t moe_ffn_hidden_size = 0; - DispatcherType dispatcher_type = DispatcherType::kAllGather; + TokenDispatcherType token_dispatcher_type = TokenDispatcherType::kAllGather; ExpertImpl expert_impl = ExpertImpl::kSequential; }; diff --git a/infini_train/src/nn/modules/transformer/moe/experts.cc b/infini_train/src/nn/modules/transformer/moe/experts.cc index 7566c48f..fa8681da 100644 --- a/infini_train/src/nn/modules/transformer/moe/experts.cc +++ b/infini_train/src/nn/modules/transformer/moe/experts.cc @@ -19,7 +19,7 @@ SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule( CHECK(moe_config.expert_impl == MoEConfig::ExpertImpl::kSequential); CHECK_EQ(moe_config.expert_parallel_size, 1) << "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only"; - CHECK(moe_config.dispatcher_type == MoEConfig::DispatcherType::kAllGather) + CHECK(moe_config.token_dispatcher_type == MoEConfig::TokenDispatcherType::kAllGather) << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; num_local_experts_ = moe_config.num_experts; diff --git a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc index 6add37ef..1e15fe81 100644 --- a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc +++ b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc @@ -15,7 +15,7 @@ namespace infini_train::nn::moe { MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); CHECK(config_.ffn_type == FFNType::kMoE); - CHECK(moe_config.dispatcher_type == MoEConfig::DispatcherType::kAllGather) + CHECK(moe_config.token_dispatcher_type == MoEConfig::TokenDispatcherType::kAllGather) << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; modules_[kRouterLayerName] = std::make_shared(config_); From 6405669727ad72b7b912c962d7c5e94519f8fd3b Mon Sep 17 00:00:00 2001 From: kilinchange Date: Mon, 8 Jun 2026 09:42:44 +0000 Subject: [PATCH 9/9] test: integrate tiny_mixtral into automated test pipeline --- example/tiny_mixtral/config.h | 10 ++-- example/tiny_mixtral/main.cc | 55 +++++++++++++----- scripts/run_models_and_profile.bash | 88 ++++++++++++++++++++++++----- scripts/test_config.json | 50 ++++++++++++++-- 4 files changed, 164 insertions(+), 39 deletions(-) diff --git a/example/tiny_mixtral/config.h b/example/tiny_mixtral/config.h index 0d7096d4..5293fa93 100644 --- a/example/tiny_mixtral/config.h +++ b/example/tiny_mixtral/config.h @@ -13,10 +13,10 @@ inline nn::TransformerConfig TinyMixtralConfig() { config.block_size = 32768; // Same as Mixtral/Megatron --max-position-embeddings. config.vocab_size = 128256; // Validation data uses LLaMA3 token ids; real Mixtral uses 32000. config.original_vocab_size = 128256; - config.n_layer = 2; // Tiny scale; Megatron --num-layers 32. - config.n_head = 4; // Tiny scale; preserves the Megatron 4:1 GQA ratio. - config.n_kv_head = 1; // Tiny scale; Megatron --num-query-groups 8. - config.n_embd = 32; // Tiny scale; Megatron --hidden-size 4096. + config.n_layer = 32; + config.n_head = 4; // Scaled down; preserves Mixtral 4:1 GQA ratio. + config.n_kv_head = 1; // Scaled down with n_head; real Mixtral uses 8 KV heads. + config.n_embd = 512; // Scaled down from Mixtral/Megatron --hidden-size 4096. config.attention_type = nn::AttentionType::kRoPE; config.activation_type = nn::MLPType::kSwiGLU; config.ffn_type = nn::FFNType::kMoE; @@ -33,7 +33,7 @@ inline nn::TransformerConfig TinyMixtralConfig() { moe_config.num_experts = 8; moe_config.expert_parallel_size = 1; // Single-rank validation scale. moe_config.router_topk = 2; - moe_config.moe_ffn_hidden_size = 112; // Tiny scale; Megatron --ffn-hidden-size 14336. + moe_config.moe_ffn_hidden_size = 1792; // Scaled down as 512 * 3.5; real Mixtral uses 14336. moe_config.token_dispatcher_type = nn::MoEConfig::TokenDispatcherType::kAllGather; // Single-rank validation path. moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential; // Local correctness path. config.moe_config = moe_config; diff --git a/example/tiny_mixtral/main.cc b/example/tiny_mixtral/main.cc index dc2b5136..e38d0346 100644 --- a/example/tiny_mixtral/main.cc +++ b/example/tiny_mixtral/main.cc @@ -11,6 +11,7 @@ #include "example/common/tiny_shakespeare_dataset.h" #include "example/tiny_mixtral/checkpoint_loader.h" #include "example/tiny_mixtral/config.h" +#include "infini_train/include/autocast.h" #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" @@ -21,13 +22,15 @@ #include "infini_train/include/tensor.h" DEFINE_string(input_bin, "", "input .bin to train on"); -DEFINE_uint32(batch_size, 4, "batch size"); +DEFINE_uint32(micro_batch_size, 4, "micro batch size per training step"); +DEFINE_uint32(global_batch_size, 4, "global batch size across gradient accumulation and data parallelism"); DEFINE_uint32(sequence_length, 64, "sequence length"); DEFINE_uint32(num_iteration, 10, "number of training iterations"); DEFINE_double(learning_rate, 1e-4, "SGD learning rate"); DEFINE_string(llmc_filepath, "", "optional PyTorch-generated tiny Mixtral LLMC model file path to load before training"); DEFINE_string(device, "cpu", "Training device: cpu or cuda."); +DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); DEFINE_uint32(log_interval, 1, "Print train loss every N steps. 0 disables step loss logging."); DEFINE_bool(print_timing, false, "Print training-loop elapsed time and token throughput."); @@ -36,9 +39,15 @@ namespace { using infini_train::Device; using infini_train::Tensor; +constexpr char kDtypeFP32[] = "float32"; +constexpr char kDtypeBF16[] = "bfloat16"; + void ValidateRuntimeFlags(const infini_train::nn::TransformerConfig &config) { CHECK(!FLAGS_input_bin.empty()) << "tiny Mixtral training requires --input_bin"; - CHECK_GT(FLAGS_batch_size, 0); + CHECK_GT(FLAGS_micro_batch_size, 0); + CHECK_GT(FLAGS_global_batch_size, 0); + CHECK_EQ(FLAGS_global_batch_size % FLAGS_micro_batch_size, 0) + << "global_batch_size must be divisible by micro_batch_size"; CHECK_GT(FLAGS_sequence_length, 0); CHECK_LE(FLAGS_sequence_length, config.block_size) << "sequence_length must be <= model max positions (block_size)"; } @@ -76,10 +85,19 @@ int main(int argc, char *argv[]) { } infini_train::DistributedDataLoader train_loader( - std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size, + std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_micro_batch_size, /*ddp_rank=*/0, /*ddp_world_size=*/1); auto train_iter = train_loader.begin(); + infini_train::DataType dtype; + if (FLAGS_dtype == kDtypeFP32) { + dtype = infini_train::DataType::kFLOAT32; + } else if (FLAGS_dtype == kDtypeBF16) { + dtype = infini_train::DataType::kBFLOAT16; + } else { + LOG(FATAL) << "Datatype " << FLAGS_dtype << " not supported."; + } + auto loss_fn = std::make_shared(); auto optimizer = infini_train::optimizers::SGD::Create(static_cast(FLAGS_learning_rate))(model->Parameters()); @@ -87,22 +105,31 @@ int main(int argc, char *argv[]) { auto device_impl = infini_train::core::GetDeviceGuardImpl(train_device.type()); std::vector step_duration_ms; step_duration_ms.reserve(FLAGS_num_iteration); - const double tokens_per_step = static_cast(FLAGS_batch_size) * FLAGS_sequence_length; + const uint32_t grad_accum_steps = FLAGS_global_batch_size / FLAGS_micro_batch_size; + const double tokens_per_step = static_cast(FLAGS_global_batch_size) * FLAGS_sequence_length; for (uint32_t step = 0; step < FLAGS_num_iteration; ++step) { device_impl->SynchronizeDevice(train_device); const auto step_start_time = std::chrono::steady_clock::now(); optimizer->ZeroGrad(); - if (train_iter == train_loader.end()) { - train_iter = train_loader.begin(); + float lossf = 0.0f; + for (uint32_t micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { + infini_train::AutocastGuard autocast_guard(train_device.type(), dtype); + if (train_iter == train_loader.end()) { + train_iter = train_loader.begin(); + } + auto [x_cpu, y_cpu] = *train_iter; + ++train_iter; + auto x = std::make_shared(x_cpu->To(train_device)); + auto y = std::make_shared(y_cpu->To(train_device)); + auto logits = (*model)({x})[0]; + auto loss = (*loss_fn)({logits, y})[0]; + auto loss_cpu = loss->To(Device()); + lossf += static_cast(loss_cpu.DataPtr())[0] / grad_accum_steps; + loss = loss / static_cast(grad_accum_steps); + autocast_guard.Disable(); + loss->Backward(); } - auto [x_cpu, y_cpu] = *train_iter; - ++train_iter; - auto x = std::make_shared(x_cpu->To(train_device)); - auto y = std::make_shared(y_cpu->To(train_device)); - auto logits = (*model)({x})[0]; - auto loss = (*loss_fn)({logits, y})[0]; - loss->Backward(); optimizer->Step(); device_impl->SynchronizeDevice(train_device); @@ -111,8 +138,6 @@ int main(int argc, char *argv[]) { step_duration_ms.push_back(duration_ms); if (FLAGS_log_interval > 0 && ((step + 1) % FLAGS_log_interval == 0 || step + 1 == FLAGS_num_iteration)) { - auto loss_cpu = loss->To(Device()); - const float lossf = static_cast(loss_cpu.DataPtr())[0]; std::cout << std::format( "step {:4d}/{} | train loss {:.6f} | norm -1.0000 | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s)", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_ms, tokens_per_step / (duration_ms / 1e3)) diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index e3c67293..351e6755 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -66,12 +66,15 @@ read_var() { jq -r --arg k "$key" '.variables[$k] // empty' "$CONFIG_FILE" } -BUILD_DIR="$(read_var BUILD_DIR)"; : "${BUILD_DIR:=../build}" -LOG_DIR="$(read_var LOG_DIR)"; : "${LOG_DIR:=logs}" -PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_logs}" -COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}" -RUN_CTEST="$(read_var RUN_CTEST)"; : "${RUN_CTEST:=true}" -CTEST_CMD="$(read_var CTEST_CMD)"; : "${CTEST_CMD:=ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1}" +BUILD_DIR="$(read_var BUILD_DIR)"; : "${BUILD_DIR:=../build}" +LOG_DIR="$(read_var LOG_DIR)"; : "${LOG_DIR:=logs}" +PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_logs}" +COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}" +RUN_CTEST="$(read_var RUN_CTEST)"; : "${RUN_CTEST:=true}" +CTEST_CMD="$(read_var CTEST_CMD)"; : "${CTEST_CMD:=ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1}" +GPT2_TEST_GROUPS="$(read_var GPT2_TEST_GROUPS)"; : "${GPT2_TEST_GROUPS:=basic,zero,lora}" +LLAMA3_TEST_GROUPS="$(read_var LLAMA3_TEST_GROUPS)"; : "${LLAMA3_TEST_GROUPS:=basic,zero,lora}" +MIXTRAL_TEST_GROUPS="$(read_var MIXTRAL_TEST_GROUPS)"; : "${MIXTRAL_TEST_GROUPS:=moe}" mkdir -p "$BUILD_DIR" "$LOG_DIR" "$PROFILE_LOG_DIR" @@ -219,6 +222,54 @@ args_string_for_test() { ' "$CONFIG_FILE" | paste -sd' ' - } +tag_enabled_for_model() { + local tag="$1" + local enabled_tags="$2" + + if [[ "$enabled_tags" == "*" ]]; then + return 0 + fi + + IFS=',' read -r -a tags <<< "$enabled_tags" + for raw_tag in "${tags[@]}"; do + local enabled_tag + enabled_tag="$(normalize_tag "$raw_tag")" + if [[ "$enabled_tag" == "$tag" ]]; then + return 0 + fi + done + return 1 +} + +model_has_selected_group() { + local enabled_tags="$1" + + for ((gi=0; gi&2 + exit 1 + fi + if [[ ! -f "$MIXTRAL_LLMC_FILEPATH" ]]; then + echo "Error: missing MIXTRAL_LLMC_FILEPATH: $MIXTRAL_LLMC_FILEPATH" >&2 + exit 1 + fi + fi +} + # Run tests num_builds=$(jq '.builds | length' "$CONFIG_FILE") num_groups=$(jq '.test_groups | length' "$CONFIG_FILE") @@ -226,7 +277,7 @@ num_groups=$(jq '.test_groups | length' "$CONFIG_FILE") selected_group_count=0 for ((gi=0; gi