Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions csrc/engine/infer_engine.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include "infer_engine.hpp"
#include "../config/config_factory.hpp"
#include "spdlog/spdlog.h"
#include <future>
#include <algorithm>
#include <stdexcept>
#include <string>

namespace infinilm::engine {

Expand All @@ -15,8 +17,16 @@ InferEngine::InferEngine(
const cache::CacheConfig *cache_config,
bool enable_graph_compiling,
backends::AttentionBackend attention_backend,
std::optional<infinicore::DataType> kv_cache_dtype) // Changed parameter
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
std::optional<infinicore::DataType> kv_cache_dtype,
const std::string &weight_load_mode) // Changed parameter
: communication_group_(distributed_config, device_type),
attention_backend_(attention_backend),
weight_load_mode_(weight_load_mode),
weight_load_group_size_(2),
weight_load_clone_(weight_load_mode == "grouped-clone") {
if (weight_load_mode_ != "sync" && weight_load_mode_ != "async" && weight_load_mode_ != "grouped" && weight_load_mode_ != "grouped-clone") {
throw std::invalid_argument("weight_load_mode must be one of: sync, async, grouped, grouped-clone");
}
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}
Expand Down Expand Up @@ -57,15 +67,32 @@ void InferEngine::load_param(const std::string &name, const infinicore::Tensor &
}

void InferEngine::load_params(const std::unordered_map<std::string, infinicore::Tensor> &params) {
std::vector<std::future<void>> futures;
futures.reserve(workers_.size());
for (auto &worker : workers_) {
futures.emplace_back(std::async(std::launch::async, [&worker, &params] {
worker->load_params(params);
}));
if (workers_.size() <= 1 || weight_load_mode_ == "sync") {
for (auto &worker : workers_) {
worker->load_params(params, weight_load_clone_);
}
return;
}

if (weight_load_mode_ == "async") {
for (auto &worker : workers_) {
worker->load_params_async(params, weight_load_clone_);
}
for (auto &worker : workers_) {
worker->wait();
}
return;
}
for (auto &future : futures) {
future.get();

const size_t group_size = std::max<size_t>(1, std::min(weight_load_group_size_, workers_.size()));
for (size_t group_start = 0; group_start < workers_.size(); group_start += group_size) {
const size_t group_end = std::min(group_start + group_size, workers_.size());
for (size_t i = group_start; i < group_end; ++i) {
workers_[i]->load_params_async(params, weight_load_clone_);
}
for (size_t i = group_start; i < group_end; ++i) {
workers_[i]->wait();
}
}
}

Expand Down
7 changes: 6 additions & 1 deletion csrc/engine/infer_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "rank_worker.hpp"

#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

Expand All @@ -28,7 +29,8 @@ class InferEngine {
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt,
const std::string &weight_load_mode = "async");

// Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param);
Expand Down Expand Up @@ -63,6 +65,9 @@ class InferEngine {
std::unique_ptr<cache::CacheConfig> cache_config_;
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default;
std::string weight_load_mode_ = "async";
size_t weight_load_group_size_ = 2;
bool weight_load_clone_ = false;
};

} // namespace infinilm::engine
65 changes: 55 additions & 10 deletions csrc/engine/rank_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,39 @@
#include <iostream>
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <string>

namespace infinilm::engine {

namespace {

infinicore::Tensor clone_tensor_for_weight_load(const infinicore::Tensor &tensor) {
auto cloned = infinicore::Tensor::empty(
tensor->shape(),
tensor->dtype(),
tensor->device(),
false);
cloned->copy_from(tensor);
return cloned;
}

std::unordered_map<std::string, infinicore::Tensor> clone_params_for_weight_load(
const std::unordered_map<std::string, infinicore::Tensor> &params,
bool clone_weights) {
if (!clone_weights) {
return params;
}

std::unordered_map<std::string, infinicore::Tensor> cloned_params;
cloned_params.reserve(params.size());
for (const auto &[name, tensor] : params) {
cloned_params.emplace(name, clone_tensor_for_weight_load(tensor));
}
return cloned_params;
}

} // namespace

RankWorker::RankWorker(
std::shared_ptr<infinilm::global_state::InfinilmConfig> infinilm_config,
const distributed::RankInfo &rank_info,
Expand Down Expand Up @@ -91,26 +121,37 @@ void RankWorker::load_param(const std::string &name,
//------------------------------------------------------
// load_params -- synchronous batch load
//------------------------------------------------------
void RankWorker::load_params(const std::unordered_map<std::string, infinicore::Tensor> &params) {
void RankWorker::load_params(const std::unordered_map<std::string, infinicore::Tensor> &params, bool clone_weights) {
load_params_async(params, clone_weights);

std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return job_done_ || should_exit_; });

if (should_exit_) {
throw std::runtime_error("RankWorker stopped while loading parameters");
}
}

//------------------------------------------------------
// load_params_async -- submit batch load without waiting
//------------------------------------------------------
void RankWorker::load_params_async(const std::unordered_map<std::string, infinicore::Tensor> &params, bool clone_weights) {
{
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot load_params");
throw std::runtime_error("RankWorker is closing; cannot load_params_async");
}
if (has_job_ && !job_done_) {
throw std::runtime_error("RankWorker already has a pending job");
}

pending_params_ = params;
pending_weight_load_clone_ = clone_weights;
job_cmd_ = Command::LOAD_BATCH;
has_job_ = true;
job_done_ = false;
}
cv_.notify_all();

std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return job_done_ || should_exit_; });

if (should_exit_) {
throw std::runtime_error("RankWorker stopped while loading parameters");
}
}

//------------------------------------------------------
Expand Down Expand Up @@ -292,6 +333,7 @@ void RankWorker::thread_loop() {
std::string local_param_name;
infinicore::Tensor local_param;
std::unordered_map<std::string, infinicore::Tensor> local_params;
bool local_weight_load_clone = false;
Input local_args;
std::unique_ptr<cache::CacheConfig> local_cache_config;

Expand All @@ -311,6 +353,8 @@ void RankWorker::thread_loop() {
local_param = pending_param_;
} else if (local_cmd == Command::LOAD_BATCH) {
local_params = std::move(pending_params_);
local_weight_load_clone = pending_weight_load_clone_;
pending_weight_load_clone_ = false;
pending_params_.clear();
} else if (local_cmd == Command::PREPROCESS) {

Expand Down Expand Up @@ -350,7 +394,8 @@ void RankWorker::thread_loop() {

} else if (local_cmd == Command::LOAD_BATCH) {
try {
model_->load_parameters_no_sync(local_params);
auto params_for_load = clone_params_for_weight_load(local_params, local_weight_load_clone);
model_->load_parameters_no_sync(params_for_load);
infinicore::context::syncStream();
} catch (const std::exception &e) {
{
Expand Down
5 changes: 4 additions & 1 deletion csrc/engine/rank_worker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ class RankWorker {
void load_param(const std::string &name,
const infinicore::Tensor &param);

void load_params(const std::unordered_map<std::string, infinicore::Tensor> &params);
void load_params(const std::unordered_map<std::string, infinicore::Tensor> &params, bool clone_weights = false);

void load_params_async(const std::unordered_map<std::string, infinicore::Tensor> &params, bool clone_weights = false);

void process_weights_after_loading();

Expand Down Expand Up @@ -144,6 +146,7 @@ class RankWorker {
std::string pending_param_name_;
infinicore::Tensor pending_param_;
std::unordered_map<std::string, infinicore::Tensor> pending_params_;
bool pending_weight_load_clone_ = false;
Input pending_args_;
std::unique_ptr<cache::CacheConfig> pending_cache_config_;

Expand Down
9 changes: 6 additions & 3 deletions csrc/pybind11/engine/engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,26 @@ inline void bind_infer_engine(py::module &m) {
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
bool enable_graph_compiling,
const std::string &attention_backend,
std::optional<infinicore::DataType> kv_cache_dtype) {
std::optional<infinicore::DataType> kv_cache_dtype,
const std::string &weight_load_mode) {
return std::make_shared<InferEngine>(
model_path,
dist,
dev,
cache_cfg ? cache_cfg.get() : nullptr,
enable_graph_compiling,
infinilm::backends::parse_attention_backend(attention_backend),
kv_cache_dtype);
kv_cache_dtype,
weight_load_mode);
}),
py::arg("model_path") = "",
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none(),
py::arg("enable_graph_compiling") = false,
py::arg("attention_backend") = "default",
py::arg("kv_cache_dtype") = py::none())
py::arg("kv_cache_dtype") = py::none(),
py::arg("weight_load_mode") = "async")
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
Expand Down
3 changes: 3 additions & 0 deletions examples/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def __init__(
cache_config=None,
enable_graph=False,
attn_backend="default",
weight_load_mode="async",
) -> None:
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
Expand All @@ -187,6 +188,7 @@ def __init__(
enable_graph_compiling=enable_graph,
attention_backend=attn_backend,
kv_cache_dtype=cfg.kv_cache_dtype,
weight_load_mode=weight_load_mode,
)

# ---------------------------------------------------------------------------- #
Expand Down Expand Up @@ -327,6 +329,7 @@ def run(
cache_config=cache_config,
enable_graph=enable_graph,
attn_backend=attn_backend,
weight_load_mode=cfg.weight_load_mode,
)

# ---------------------------------------------------------------------------- #
Expand Down
3 changes: 3 additions & 0 deletions examples/test_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def test(
attn_backend="default",
image_path=None,
skip_load=False,
weight_load_mode="async",
):
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
Expand All @@ -39,6 +40,7 @@ def test(
enable_graph=enable_graph,
attn_backend=attn_backend,
skip_load=skip_load,
weight_load_mode=weight_load_mode,
)

conversations = [
Expand Down Expand Up @@ -103,4 +105,5 @@ def test(
attn_backend=cfg.attn,
image_path=cfg.image,
skip_load=cfg.skip_load,
weight_load_mode=cfg.weight_load_mode,
)
9 changes: 8 additions & 1 deletion python/infinilm/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class BaseConfig:
"""InfiniLM Unified Config - Command line argument parser"""

def __init__(self):

self.parser = argparse.ArgumentParser(description="InfiniLM Unified Config")
self._add_common_args()
self.args, self.extra = self.parser.parse_known_args()
Expand All @@ -67,6 +66,7 @@ def __init__(self):
self.max_cache_len = self.args.max_cache_len
self.kv_cache_dtype = self.args.kv_cache_dtype
self.skip_load = self.args.skip_load
self.weight_load_mode = self.args.weight_load_mode

self.batch_size = self.args.batch_size
self.max_batch_size = self.args.max_batch_size
Expand Down Expand Up @@ -146,6 +146,13 @@ def _add_common_args(self):
self.parser.add_argument(
"--skip-load", action="store_true", help="skip loading model weights"
)
self.parser.add_argument(
"--weight-load-mode",
type=str,
default="async",
choices=["async", "sync", "grouped", "grouped-clone"],
help="weight loading mode: async keeps old behavior; grouped-clone is the stable 103B option",
)

# --- Length and infer parameters ---
self.parser.add_argument("--batch-size", type=int, default=1)
Expand Down
2 changes: 2 additions & 0 deletions python/infinilm/infer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
enable_graph_compiling=False,
attention_backend="default",
kv_cache_dtype=None,
weight_load_mode="async",
):
self.hf_config = read_hf_config(model_path)
self.hf_generation_config = read_hf_generation_config(model_path)
Expand All @@ -87,6 +88,7 @@ def __init__(
if kv_cache_dtype is not None
else None
),
weight_load_mode,
)
self.use_cache = False

Expand Down
2 changes: 2 additions & 0 deletions python/infinilm/llm/cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def allocate_blocks(
"""
if block_table is None:
block_table = []
if mm_token_index_mappings is None:
mm_token_index_mappings = []

# Static args
num_tokens = len(token_ids)
Expand Down
Loading