-
Notifications
You must be signed in to change notification settings - Fork 45
fix: correct LoRA initialization and forward pass under tensor parallelism #150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
a7df3a6
9467166
f4f2220
adbc82d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| #include "infini_train/include/nn/lora/lora_parallel_linear.h" | ||
| #include "infini_train/include/nn/modules/linear.h" | ||
| #include "infini_train/include/nn/modules/module.h" | ||
| #include "infini_train/include/nn/parallel/global.h" | ||
| #include "infini_train/include/nn/parallel/tensor_parallel.h" | ||
| #include "infini_train/include/tensor.h" | ||
|
|
||
|
|
@@ -392,10 +393,30 @@ void LoadLoRAWeights(std::shared_ptr<Module> model, const std::string &filepath) | |
| auto cpu_tensor = std::make_shared<Tensor>(dims, DataType::kFLOAT32, Device(Device::DeviceType::kCPU, 0)); | ||
| file.read(reinterpret_cast<char *>(cpu_tensor->DataPtr()), num_elements * sizeof(float)); | ||
|
|
||
| // Load into model | ||
| // Load into model, slicing sharded tensors by tp_rank if shapes differ | ||
| auto it = model_state_dict.find(name); | ||
| if (it != model_state_dict.end()) { | ||
| it->second->CopyFrom(cpu_tensor); | ||
| auto &dst = it->second; | ||
| const auto &dst_dims = dst->Dims(); | ||
| if (dst_dims == dims) { | ||
| dst->CopyFrom(cpu_tensor); | ||
| } else { | ||
| // Determine which dim is sharded: find first dim where sizes differ | ||
| int shard_dim = -1; | ||
| for (int d = 0; d < static_cast<int>(dims.size()); ++d) { | ||
| if (d < static_cast<int>(dst_dims.size()) && dst_dims[d] != dims[d]) { | ||
| shard_dim = d; | ||
| break; | ||
| } | ||
| } | ||
| CHECK(shard_dim >= 0) << "LoadLoRAWeights: shape mismatch for " << name | ||
| << " but no differing dim found"; | ||
| int tp_size = parallel::global::GetTensorParallelSize(); | ||
| int64_t shard_size = dims[shard_dim] / tp_size; | ||
| int64_t start = parallel::tp_rank * shard_size; | ||
| auto sliced = cpu_tensor->Slice(shard_dim, start, start + shard_size); | ||
| dst->CopyFrom(sliced); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 现在模型实现中,Attention 里面第一个 Linear 把 QKV 合并了(实际排列上是 [Q | K | V]),LoadFromLLMC 里面有对应的处理(如 TP=4,则切成 [Q0 Q1 Q2 Q3 | K0 K1 K2 K3 | V0 V1 V2 V3],然后依次拼接 [Qi | Ki | Vi] load 到每个 tp rank 上)。所以这块的切分逻辑对于这种情况没法适用,得想想看怎么能特别处理一下。 |
||
| } | ||
| } else { | ||
| LOG(WARNING) << "LoRA parameter not found in model: " << name; | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在这里做各个 tp rank 各自 init,可能会导致 DP 组之间无法达到权重一致的结果。现在的情况可能会因为都用默认 seed 导致恰好能对上,但是从代码看,语义上还是各初始化各的,DP 组之间有可能权重对不齐。原则上应该保证 DDP model 最后再做一步参数的广播/复制,从而确保在实际执行 forward 前各个能对应上的 dp rank 上面模型参数是相同的,但这里我不知道改的话会不会影响其他 lora 基建;最简单的就是在
`tp_group->Broadcast({parameters_[kParamLoraAName]}, 0);` 后面再加个 dp_group 的 broadcast。