Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
776 changes: 181 additions & 595 deletions docs/lora_usage_guide.md

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions infini_train/include/nn/parallel/process_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ class ProcessGroup {
function::ReduceOpType reduce_op = function::ReduceOpType::kSum,
bool async_op = false) const;

// root_rank_in_group is the root's ProcessGroup-local rank. Broadcast updates tensors in place.
virtual std::shared_ptr<Work> Broadcast(const std::vector<std::shared_ptr<Tensor>> &tensors, int root_rank_in_group,
bool async_op = false) const;

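// The root provides input_tensors in rank-major order: the tensor destined for output_tensors[tensor_index]
// on group rank `rank` is input_tensors[rank * output_tensors.size() + tensor_index].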
virtual std::shared_ptr<Work> Scatter(const std::vector<std::shared_ptr<Tensor>> &output_tensors,
const std::vector<std::shared_ptr<Tensor>> &input_tensors,
int root_rank_in_group, bool async_op = false) const;

virtual std::shared_ptr<Work> Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
bool async_op = false) const;

Expand All @@ -60,13 +69,13 @@ class ProcessGroup {

// Legacy communication APIs (Single-stream)
virtual std::vector<std::shared_ptr<Tensor>>
BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const;
BroadCast_(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const;

virtual std::vector<std::shared_ptr<Tensor>>
ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads, Device destination) const;

virtual std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
std::vector<Device> devices, int64_t dim) const;
virtual std::vector<std::shared_ptr<Tensor>> Scatter_(const std::shared_ptr<Tensor> &tensor,
std::vector<Device> devices, int64_t dim) const;

virtual std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, Device destination,
int64_t dim) const;
Expand Down
4 changes: 2 additions & 2 deletions infini_train/src/autograd/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ std::vector<std::shared_ptr<Tensor>> Scatter::Forward(const std::vector<std::sha
const auto &input = input_tensors[0];
std::vector<std::shared_ptr<Tensor>> output_tensors;
auto device = input->GetDevice().type();
output_tensors = pg_->Scatter(input, target_gpus_, dim_);
output_tensors = pg_->Scatter_(input, target_gpus_, dim_);
return output_tensors;
}

Expand Down Expand Up @@ -83,7 +83,7 @@ std::vector<std::shared_ptr<Tensor>> Broadcast::Forward(const std::vector<std::s
}

// TODO(dcj): mark non differentiable
return pg_->BroadCast(input_tensors);
return pg_->BroadCast_(input_tensors);
}

void Broadcast::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
Expand Down
161 changes: 90 additions & 71 deletions infini_train/src/nn/lora/lora_parallel_linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "infini_train/include/nn/init.h"
#include "infini_train/include/nn/modules/linear.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/tensor.h"
Expand Down Expand Up @@ -89,22 +90,34 @@ LoRAColumnParallelLinear::LoRAColumnParallelLinear(std::shared_ptr<parallel::Col
}

void LoRAColumnParallelLinear::InitLoRAWeights() {
// LoRA weights stored directly in parameters_
// Following PEFT pattern conceptually:
// lora_A: [rank, in_features] - replicated
// lora_A: [rank, in_features] - replicated across TP ranks
// lora_B: [out_features_per_partition, rank] - sharded like base weight

// lora_A: [rank, in_features]
parameters_[kParamLoraAName]
= std::make_shared<Tensor>(std::vector<int64_t>{config_.rank, in_features_}, DataType::kFLOAT32, device_)
->RequiresGrad();
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);

if (parallel::global::GetTensorParallelSize() > 1) {
const auto global_rank = device_.Rank().GlobalRank();
auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type())
->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
const int tp_rank = tp_group->GetGroupRank(global_rank);

if (tp_rank == 0) {
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在这里做各个 tp rank 各自 init,可能会导致 DP 组之间无法达到权重一致的结果。现在的情况可能会因为都用默认 seed 导致恰好能对上,但是从代码看,语义上还是各初始化各的,DP 组之间有可能权重对不齐。原则上应该保证 DDP model 最后再做一步参数的广播/复制,从而确保在实际执行 forward 前各个能对应上的 dp rank 上面模型参数是相同的,但这里我不知道改的话会不会影响其他 lora 基建;最简单的做法就是在 tp_group->Broadcast({parameters_[kParamLoraAName]}, 0); 后面再加一个 dp_group 的 broadcast。

} else {
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
}
}
tp_group->Broadcast({parameters_[kParamLoraAName]}, 0);
} else {
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
} else {
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
}
}

// lora_B: [out_per_partition, rank] - sharded like base weight
parameters_[kParamLoraBName]
= std::make_shared<Tensor>(std::vector<int64_t>{out_features_per_partition_, config_.rank}, DataType::kFLOAT32,
device_)
Expand All @@ -126,39 +139,35 @@ LoRAColumnParallelLinear::Forward(const std::vector<std::shared_ptr<Tensor>> &in
<< "Forward() on merged LoRA with requires_grad=true. Call UnmergeWeights() before training.";

if (!merged_) {
// 1. Compute base output via parent class
auto base_result = ColumnParallelLinear::Forward(input_tensors);
auto base_output = base_result[0];

// 2. Compute LoRA output using the SAME input that base module uses
// Match base input path exactly: use direct input if input_is_parallel_ or sequence_parallel_,
// otherwise copy to TP region
auto lora_input = (input_is_parallel_ || sequence_parallel_)
? input_tensors[0]
: parallel::CopyToTPRegionFunc(input_tensors[0])[0];
// Inline base + LoRA matmuls, add locally, then single collective op.
// This avoids 2 separate AllGather ops which cause floating-point divergence.
auto input = (input_is_parallel_ || sequence_parallel_) ? input_tensors[0]
: parallel::CopyToTPRegionFunc(input_tensors[0])[0];
if (sequence_parallel_) {
// Base uses GatherFromSPRegionFunc to gather sequence dimension
lora_input = parallel::GatherFromSPRegionFunc(lora_input)[0];
input = parallel::GatherFromSPRegionFunc(input)[0];
}

// Compute LoRA: lora_A: [rank, in_features], lora_B: [out_per_partition, rank]
auto lora_proj = std::make_shared<autograd::Linear>()->Apply({lora_input, parameters_[kParamLoraAName]})[0];
// Base matmul (bias folded in when applicable, matching ColumnParallelLinear::Forward)
auto base_shard = std::make_shared<autograd::Linear>()->Apply(
(bias_ && !skip_bias_add_)
? std::vector<std::shared_ptr<Tensor>>{input, parameters_.at(kParamWeightName),
parameters_[kParamBiasName]}
: std::vector<std::shared_ptr<Tensor>>{input, parameters_.at(kParamWeightName)})[0];

// LoRA matmul (local)
// Wrap replicated lora_A through CopyToTPRegion so its gradient gets AllReduced in backward
auto lora_A = parallel::CopyToTPRegionFunc(parameters_[kParamLoraAName])[0];
auto lora_proj = std::make_shared<autograd::Linear>()->Apply({input, lora_A})[0];
auto lora_output = std::make_shared<autograd::Linear>()->Apply({lora_proj, parameters_[kParamLoraBName]})[0];

// Match base output layout (gather if base gathers)
if (gather_output_) {
lora_output = parallel::GatherFromTPRegionFunc(lora_output)[0];
}

auto scaled_lora = lora_output->Mul(config_.Scaling());
// Local add before collective
auto combined = base_shard->Add(lora_output->Mul(config_.Scaling()));

// 3. Add LoRA contribution to base output
// Both should now have the same sequence dimension
auto output = base_output->Add(scaled_lora);
// Single collective op
auto output = gather_output_ ? parallel::GatherFromTPRegionFunc(combined)[0] : combined;

// Return in same format as base module
return skip_bias_add_
? std::vector<std::shared_ptr<Tensor>>{output, bias_ ? parameters_[kParamBiasName] : nullptr}
? std::vector<std::shared_ptr<Tensor>>{output, bias_ ? parameters_.at(kParamBiasName) : nullptr}
: std::vector<std::shared_ptr<Tensor>>{output};
}

Expand Down Expand Up @@ -294,10 +303,30 @@ void LoRARowParallelLinear::InitLoRAWeights() {
= std::make_shared<Tensor>(std::vector<int64_t>{config_.rank, in_features_per_partition_}, DataType::kFLOAT32,
device_)
->RequiresGrad();
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
if (parallel::global::GetTensorParallelSize() > 1) {
const auto global_rank = device_.Rank().GlobalRank();
auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type())
->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
const int tp_rank = tp_group->GetGroupRank(global_rank);

std::vector<std::shared_ptr<Tensor>> scatter_inputs;
if (tp_rank == 0) {
auto full_lora_A = std::make_shared<Tensor>(std::vector<int64_t>{config_.rank, in_features_},
DataType::kFLOAT32, device_);
if (config_.use_kaiming_a) {
init::KaimingUniform(full_lora_A, config_.kaiming_a_param);
} else {
init::Normal(full_lora_A, 0.0f, 0.02f);
}
scatter_inputs = full_lora_A->Split(in_features_per_partition_, 1);
}
tp_group->Scatter({parameters_[kParamLoraAName]}, scatter_inputs, 0);
} else {
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
} else {
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
}
}

// lora_B: [out_features, rank]
Expand All @@ -321,42 +350,32 @@ LoRARowParallelLinear::Forward(const std::vector<std::shared_ptr<Tensor>> &input
<< "Forward() on merged LoRA with requires_grad=true. Call UnmergeWeights() before training.";

if (!merged_) {
// Get effective input - match what base module uses
auto effective_input = input_tensors[0];
const int64_t in_dim = effective_input->Dims().back();

if (!input_is_parallel_) {
// base would scatter; lora must match
effective_input = parallel::ScatterToTPRegionFunc(effective_input)[0];
CHECK_EQ(effective_input->Dims().back(), in_features_per_partition_);
} else {
// input_is_parallel_=true means caller promised shard input
CHECK_EQ(in_dim, in_features_per_partition_)
<< "RowParallel expects sharded input when input_is_parallel_=true. "
<< "Got full in_dim=" << in_dim << " (likely upstream gathered TP output).";
// Inline base + LoRA matmuls, add locally, then single collective op.
// This avoids 2 separate AllReduce ops which cause floating-point divergence.
auto input = input_is_parallel_ ? input_tensors[0] : parallel::ScatterToTPRegionFunc(input_tensors[0])[0];

// Base matmul (no bias — RowParallel adds bias AFTER collective)
auto base_shard = std::make_shared<autograd::Linear>()->Apply({input, parameters_.at(kParamWeightName)})[0];

// LoRA matmul (local)
// Wrap replicated lora_B through CopyToTPRegion so its gradient gets AllReduced in backward
auto lora_proj = std::make_shared<autograd::Linear>()->Apply({input, parameters_[kParamLoraAName]})[0];
auto lora_B = parallel::CopyToTPRegionFunc(parameters_[kParamLoraBName])[0];
auto lora_output = std::make_shared<autograd::Linear>()->Apply({lora_proj, lora_B})[0];

// Local add before collective
auto combined = base_shard->Add(lora_output->Mul(config_.Scaling()));

// Single collective op
auto output = reduce_output_ ? (sequence_parallel_ ? parallel::ReduceScatterToSPRegionFunc(combined)[0]
: parallel::ReduceFromTPRegionFunc(combined)[0])
: combined;

// Bias after collective (matching RowParallelLinear::Forward)
if (bias_ && !skip_bias_add_) {
output = output->Add(parameters_[kParamBiasName]);
}

// 1) base output - use effective_input
auto base_result = RowParallelLinear::Forward({effective_input});
auto base_output = base_result[0];

// 2) lora branch uses the SAME effective_input
auto lora_proj
= std::make_shared<autograd::Linear>()->Apply({effective_input, parameters_[kParamLoraAName]})[0];
auto lora_output = std::make_shared<autograd::Linear>()->Apply({lora_proj, parameters_[kParamLoraBName]})[0];

// 3) apply same reduction as base
auto lora_out = lora_output;
if (reduce_output_) {
lora_out = sequence_parallel_ ? parallel::ReduceScatterToSPRegionFunc(lora_out)[0]
: parallel::ReduceFromTPRegionFunc(lora_out)[0];
}

auto scaled_lora = lora_out->Mul(config_.Scaling());
CHECK_EQ(base_output->NumElements(), scaled_lora->NumElements());
auto output = base_output->Add(scaled_lora);

// Return in same format as base module
return skip_bias_add_
? std::vector<std::shared_ptr<Tensor>>{output, bias_ ? parameters_[kParamBiasName] : nullptr}
: std::vector<std::shared_ptr<Tensor>>{output};
Expand Down
25 changes: 23 additions & 2 deletions infini_train/src/nn/lora/lora_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "infini_train/include/nn/lora/lora_parallel_linear.h"
#include "infini_train/include/nn/modules/linear.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/tensor.h"

Expand Down Expand Up @@ -392,10 +393,30 @@ void LoadLoRAWeights(std::shared_ptr<Module> model, const std::string &filepath)
auto cpu_tensor = std::make_shared<Tensor>(dims, DataType::kFLOAT32, Device(Device::DeviceType::kCPU, 0));
file.read(reinterpret_cast<char *>(cpu_tensor->DataPtr()), num_elements * sizeof(float));

// Load into model
// Load into model, slicing sharded tensors by tp_rank if shapes differ
auto it = model_state_dict.find(name);
if (it != model_state_dict.end()) {
it->second->CopyFrom(cpu_tensor);
auto &dst = it->second;
const auto &dst_dims = dst->Dims();
if (dst_dims == dims) {
dst->CopyFrom(cpu_tensor);
} else {
// Determine which dim is sharded: find first dim where sizes differ
int shard_dim = -1;
for (int d = 0; d < static_cast<int>(dims.size()); ++d) {
if (d < static_cast<int>(dst_dims.size()) && dst_dims[d] != dims[d]) {
shard_dim = d;
break;
}
}
CHECK(shard_dim >= 0) << "LoadLoRAWeights: shape mismatch for " << name
<< " but no differing dim found";
int tp_size = parallel::global::GetTensorParallelSize();
int64_t shard_size = dims[shard_dim] / tp_size;
int64_t start = parallel::tp_rank * shard_size;
auto sliced = cpu_tensor->Slice(shard_dim, start, start + shard_size);
dst->CopyFrom(sliced);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在模型实现中,Attention 里面第一个 Linear 把 QKV 合并了(实际排列上是 [Q | K | V]),LoadFromLLMC 里面有对应的处理(如 TP=4,则切成 [Q0 Q1 Q2 Q3 | K0 K1 K2 K3 | V0 V1 V2 V3],然后依次拼接 [Qi | Ki | Vi] load 到每个 tp rank 上)。所以这块的切分逻辑对于这种情况没法适用,得想想看怎么能特别处理一下

}
} else {
LOG(WARNING) << "LoRA parameter not found in model: " << name;
}
Expand Down
Loading
Loading