diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 77d07df4..85d7af07 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -1,7 +1,9 @@ #include "infer_engine.hpp" #include "../config/config_factory.hpp" #include "spdlog/spdlog.h" -#include +#include +#include +#include namespace infinilm::engine { @@ -15,8 +17,16 @@ InferEngine::InferEngine( const cache::CacheConfig *cache_config, bool enable_graph_compiling, backends::AttentionBackend attention_backend, - std::optional kv_cache_dtype) // Changed parameter - : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { + std::optional kv_cache_dtype, + const std::string &weight_load_mode) // Changed parameter + : communication_group_(distributed_config, device_type), + attention_backend_(attention_backend), + weight_load_mode_(weight_load_mode), + weight_load_group_size_(2), + weight_load_clone_(weight_load_mode == "grouped-clone") { + if (weight_load_mode_ != "sync" && weight_load_mode_ != "async" && weight_load_mode_ != "grouped" && weight_load_mode_ != "grouped-clone") { + throw std::invalid_argument("weight_load_mode must be one of: sync, async, grouped, grouped-clone"); + } if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); } @@ -57,15 +67,32 @@ void InferEngine::load_param(const std::string &name, const infinicore::Tensor & } void InferEngine::load_params(const std::unordered_map ¶ms) { - std::vector> futures; - futures.reserve(workers_.size()); - for (auto &worker : workers_) { - futures.emplace_back(std::async(std::launch::async, [&worker, ¶ms] { - worker->load_params(params); - })); + if (workers_.size() <= 1 || weight_load_mode_ == "sync") { + for (auto &worker : workers_) { + worker->load_params(params, weight_load_clone_); + } + return; + } + + if (weight_load_mode_ == "async") { + for (auto &worker : workers_) { + worker->load_params_async(params, weight_load_clone_); + } + for (auto &worker : workers_) { + worker->wait(); + } + return; } - for (auto &future : futures) { - future.get(); + + const size_t group_size = std::max(1, std::min(weight_load_group_size_, workers_.size())); + for (size_t group_start = 0; group_start < workers_.size(); group_start += group_size) { + const size_t group_end = std::min(group_start + group_size, workers_.size()); + for (size_t i = group_start; i < group_end; ++i) { + workers_[i]->load_params_async(params, weight_load_clone_); + } + for (size_t i = group_start; i < group_end; ++i) { + workers_[i]->wait(); + } } } diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 70c3c164..e69b07e1 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -9,6 +9,7 @@ #include "rank_worker.hpp" #include +#include #include #include @@ -28,7 +29,8 @@ class InferEngine { const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, - std::optional kv_cache_dtype = std::nullopt); + std::optional kv_cache_dtype = std::nullopt, + const std::string &weight_load_mode = "async"); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor ¶m); @@ -63,6 +65,9 @@ class InferEngine { std::unique_ptr cache_config_; std::shared_ptr model_config_; backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default; + std::string weight_load_mode_ = "async"; + size_t weight_load_group_size_ = 2; + bool weight_load_clone_ = false; }; } // namespace infinilm::engine diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 87568fd6..63a1e463 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -7,9 +7,39 @@ #include #include #include +#include namespace infinilm::engine { +namespace { + +infinicore::Tensor clone_tensor_for_weight_load(const infinicore::Tensor &tensor) { + auto cloned = infinicore::Tensor::empty( + tensor->shape(), + tensor->dtype(), + tensor->device(), + false); + cloned->copy_from(tensor); + return cloned; +} + +std::unordered_map clone_params_for_weight_load( + const std::unordered_map ¶ms, + bool clone_weights) { + if (!clone_weights) { + return params; + } + + std::unordered_map cloned_params; + cloned_params.reserve(params.size()); + for (const auto &[name, tensor] : params) { + cloned_params.emplace(name, clone_tensor_for_weight_load(tensor)); + } + return cloned_params; +} + +} // namespace + RankWorker::RankWorker( std::shared_ptr infinilm_config, const distributed::RankInfo &rank_info, @@ -91,26 +121,37 @@ void RankWorker::load_param(const std::string &name, //------------------------------------------------------ // load_params -- synchronous batch load //------------------------------------------------------ -void RankWorker::load_params(const std::unordered_map ¶ms) { +void RankWorker::load_params(const std::unordered_map ¶ms, bool clone_weights) { + load_params_async(params, clone_weights); + + std::unique_lock lk(mutex_); + cv_.wait(lk, [&] { return job_done_ || should_exit_; }); + + if (should_exit_) { + throw std::runtime_error("RankWorker stopped while loading parameters"); + } +} + +//------------------------------------------------------ +// load_params_async -- submit batch load without waiting +//------------------------------------------------------ +void RankWorker::load_params_async(const std::unordered_map ¶ms, bool clone_weights) { { std::lock_guard lock(mutex_); if (should_exit_) { - throw std::runtime_error("RankWorker is closing; cannot load_params"); + throw std::runtime_error("RankWorker is closing; cannot load_params_async"); + } + if (has_job_ && !job_done_) { + throw std::runtime_error("RankWorker already has a pending job"); } pending_params_ = params; + pending_weight_load_clone_ = clone_weights; job_cmd_ = Command::LOAD_BATCH; has_job_ = true; job_done_ = false; } cv_.notify_all(); - - std::unique_lock lk(mutex_); - cv_.wait(lk, [&] { return job_done_ || should_exit_; }); - - if (should_exit_) { - throw std::runtime_error("RankWorker stopped while loading parameters"); - } } //------------------------------------------------------ @@ -292,6 +333,7 @@ void RankWorker::thread_loop() { std::string local_param_name; infinicore::Tensor local_param; std::unordered_map local_params; + bool local_weight_load_clone = false; Input local_args; std::unique_ptr local_cache_config; @@ -311,6 +353,8 @@ void RankWorker::thread_loop() { local_param = pending_param_; } else if (local_cmd == Command::LOAD_BATCH) { local_params = std::move(pending_params_); + local_weight_load_clone = pending_weight_load_clone_; + pending_weight_load_clone_ = false; pending_params_.clear(); } else if (local_cmd == Command::PREPROCESS) { @@ -350,7 +394,8 @@ void RankWorker::thread_loop() { } else if (local_cmd == Command::LOAD_BATCH) { try { - model_->load_parameters_no_sync(local_params); + auto params_for_load = clone_params_for_weight_load(local_params, local_weight_load_clone); + model_->load_parameters_no_sync(params_for_load); infinicore::context::syncStream(); } catch (const std::exception &e) { { diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index c214536f..ce0cdb71 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -85,7 +85,9 @@ class RankWorker { void load_param(const std::string &name, const infinicore::Tensor ¶m); - void load_params(const std::unordered_map ¶ms); + void load_params(const std::unordered_map ¶ms, bool clone_weights = false); + + void load_params_async(const std::unordered_map ¶ms, bool clone_weights = false); void process_weights_after_loading(); @@ -144,6 +146,7 @@ class RankWorker { std::string pending_param_name_; infinicore::Tensor pending_param_; std::unordered_map pending_params_; + bool pending_weight_load_clone_ = false; Input pending_args_; std::unique_ptr pending_cache_config_; diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 8e470984..21bb29ea 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -39,7 +39,8 @@ inline void bind_infer_engine(py::module &m) { std::shared_ptr cache_cfg, bool enable_graph_compiling, const std::string &attention_backend, - std::optional kv_cache_dtype) { + std::optional kv_cache_dtype, + const std::string &weight_load_mode) { return std::make_shared( model_path, dist, @@ -47,7 +48,8 @@ inline void bind_infer_engine(py::module &m) { cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, infinilm::backends::parse_attention_backend(attention_backend), - kv_cache_dtype); + kv_cache_dtype, + weight_load_mode); }), py::arg("model_path") = "", py::arg("distributed_config") = distributed::DistConfig(), @@ -55,7 +57,8 @@ inline void bind_infer_engine(py::module &m) { py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, py::arg("attention_backend") = "default", - py::arg("kv_cache_dtype") = py::none()) + py::arg("kv_cache_dtype") = py::none(), + py::arg("weight_load_mode") = "async") .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") diff --git a/examples/bench.py b/examples/bench.py index 018f7a8a..d9e27291 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -174,6 +174,7 @@ def __init__( cache_config=None, enable_graph=False, attn_backend="default", + weight_load_mode="async", ) -> None: model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # @@ -187,6 +188,7 @@ def __init__( enable_graph_compiling=enable_graph, attention_backend=attn_backend, kv_cache_dtype=cfg.kv_cache_dtype, + weight_load_mode=weight_load_mode, ) # ---------------------------------------------------------------------------- # @@ -327,6 +329,7 @@ def run( cache_config=cache_config, enable_graph=enable_graph, attn_backend=attn_backend, + weight_load_mode=cfg.weight_load_mode, ) # ---------------------------------------------------------------------------- # diff --git a/examples/test_infer.py b/examples/test_infer.py index a90a5d3e..30d7857e 100644 --- a/examples/test_infer.py +++ b/examples/test_infer.py @@ -18,6 +18,7 @@ def test( attn_backend="default", image_path=None, skip_load=False, + weight_load_mode="async", ): model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # @@ -39,6 +40,7 @@ def test( enable_graph=enable_graph, attn_backend=attn_backend, skip_load=skip_load, + weight_load_mode=weight_load_mode, ) conversations = [ @@ -103,4 +105,5 @@ def test( attn_backend=cfg.attn, image_path=cfg.image, skip_load=cfg.skip_load, + weight_load_mode=cfg.weight_load_mode, ) diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index aab5dd45..63f121d3 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -44,7 +44,6 @@ class BaseConfig: """InfiniLM Unified Config - Command line argument parser""" def __init__(self): - self.parser = argparse.ArgumentParser(description="InfiniLM Unified Config") self._add_common_args() self.args, self.extra = self.parser.parse_known_args() @@ -67,6 +66,7 @@ def __init__(self): self.max_cache_len = self.args.max_cache_len self.kv_cache_dtype = self.args.kv_cache_dtype self.skip_load = self.args.skip_load + self.weight_load_mode = self.args.weight_load_mode self.batch_size = self.args.batch_size self.max_batch_size = self.args.max_batch_size @@ -146,6 +146,13 @@ def _add_common_args(self): self.parser.add_argument( "--skip-load", action="store_true", help="skip loading model weights" ) + self.parser.add_argument( + "--weight-load-mode", + type=str, + default="async", + choices=["async", "sync", "grouped", "grouped-clone"], + help="weight loading mode: async keeps old behavior; grouped-clone is the stable 103B option", + ) # --- Length and infer parameters --- self.parser.add_argument("--batch-size", type=int, default=1) diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 75890b0f..614c4f67 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -67,6 +67,7 @@ def __init__( enable_graph_compiling=False, attention_backend="default", kv_cache_dtype=None, + weight_load_mode="async", ): self.hf_config = read_hf_config(model_path) self.hf_generation_config = read_hf_generation_config(model_path) @@ -87,6 +88,7 @@ def __init__( if kv_cache_dtype is not None else None ), + weight_load_mode, ) self.use_cache = False diff --git a/python/infinilm/llm/cache_manager.py b/python/infinilm/llm/cache_manager.py index ea857b02..fbe3d7b3 100644 --- a/python/infinilm/llm/cache_manager.py +++ b/python/infinilm/llm/cache_manager.py @@ -140,6 +140,8 @@ def allocate_blocks( """ if block_table is None: block_table = [] + if mm_token_index_mappings is None: + mm_token_index_mappings = [] # Static args num_tokens = len(token_ids) diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 953b35e5..002af17e 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -56,6 +56,7 @@ class EngineConfig: enable_graph: Whether to enable graph compiling. attn_backend: Attention backend to use ('default', 'flash-attn'). skip_load: Whether to skip loading model weights (for testing). + weight_load_mode: Weight loading mode across tensor-parallel ranks. """ model_path: str @@ -74,6 +75,7 @@ class EngineConfig: enable_graph: bool = False attn_backend: str = "default" skip_load: bool = False + weight_load_mode: str = "async" class LLMEngine: @@ -92,6 +94,7 @@ def __init__(self, config: EngineConfig): distributed_config=DistConfig(config.tensor_parallel_size), enable_graph_compiling=config.enable_graph, attention_backend=config.attn_backend, + weight_load_mode=config.weight_load_mode, ) # Load model weights @@ -363,6 +366,7 @@ def __init__( enable_graph: bool = False, attn_backend: str = "default", skip_load: bool = False, + weight_load_mode: str = "async", ): """Initialize LLM. @@ -400,6 +404,7 @@ def __init__( enable_graph=enable_graph, attn_backend=attn_backend, skip_load=skip_load, + weight_load_mode=weight_load_mode, ) self.engine = LLMEngine(config) self.config = config @@ -553,6 +558,7 @@ def __init__( top_k: int = 1, enable_graph: bool = False, attn_backend: str = "default", + weight_load_mode: str = "async", ): """Initialize AsyncLLMEngine. @@ -589,6 +595,7 @@ def __init__( top_k=top_k, enable_graph=enable_graph, attn_backend=attn_backend, + weight_load_mode=weight_load_mode, ) self.engine = LLMEngine(config) self.config = config diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index 9d7ff298..2fbe19dc 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -1,5 +1,6 @@ import os import json +import gc from typing import Dict, Union, Optional, List import time import torch @@ -105,11 +106,87 @@ def load_state_dict( ) for k in f.keys(): - state_dict[k] = f.get_tensor(k).to(device=device) + state_dict[k] = f.get_tensor(k).to(device=device, dtype=dtype) return state_dict +def iter_safetensors_tensors( + checkpoint_file: Union[str, os.PathLike], + dtype=torch.bfloat16, +): + if not str(checkpoint_file).endswith(".safetensors"): + return + + with safe_open(checkpoint_file, framework="pt", device="cpu") as f: + metadata = f.metadata() + if metadata is not None and metadata.get("format") not in [ + "pt", + "tf", + "flax", + "mlx", + ]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata." + ) + + for k in f.keys(): + yield k, f.get_tensor(k).to(dtype=dtype) + + +def safetensors_has_key(file_list: List[str], key: str) -> bool: + for file_path in file_list: + with safe_open(file_path, framework="pt", device="cpu") as f: + if key in f.keys(): + return True + return False + + +def tensor_nbytes(tensor: torch.Tensor) -> int: + return tensor.numel() * tensor.element_size() + + +def weight_load_verbose() -> bool: + return os.getenv("INFINILM_WEIGHT_LOAD_VERBOSE", "0") not in ( + "", + "0", + "false", + "False", + ) + + +def describe_tensor_batch(tensors: Dict[str, torch.Tensor]) -> str: + total_bytes = sum(tensor_nbytes(tensor) for tensor in tensors.values()) + entries = [ + f"{name}{tuple(tensor.shape)}:{tensor_nbytes(tensor) / (1024**2):.1f}MiB" + for name, tensor in tensors.items() + ] + return f"{len(tensors)} tensors, {total_bytes / (1024**2):.1f}MiB, " + "; ".join( + entries + ) + + +def log_weight_batch(prefix: str, file_path: str, batch: Dict[str, torch.Tensor]): + if weight_load_verbose(): + print( + f"{prefix} from {os.path.basename(file_path)}: {describe_tensor_batch(batch)}", + flush=True, + ) + + +def load_model_tensor_batch( + model: infinicore.nn.Module, + tensors: Dict[str, torch.Tensor], +): + if not tensors: + return + model.load_state_dict( + {name: infinicore.from_torch(tensor) for name, tensor in tensors.items()}, + strict=False, + ) + infinicore.sync_device() + + def get_model_state_dict( model_path: str, device: infinicore.device, @@ -182,43 +259,113 @@ def load_model_state_dict_by_file( already_loaded_keys = [] embed_tokens_torch_unscaled = None - file_list = glob.glob(os.path.join(model_path, "*.safetensors")) + file_list = sorted(glob.glob(os.path.join(model_path, "*.safetensors"))) if len(file_list) > 0: + has_lm_head_weight = safetensors_has_key(file_list, "lm_head.weight") for file_path in tqdm(file_list, desc="Processing files"): tqdm.write(f"Processing: {os.path.basename(file_path)}") # --------------------------------------------------------- # # Load weights from *.safetensors file # --------------------------------------------------------- # + remapper = _WEIGHT_REMAPPER.get(model_type) + if remapper is None: + batch = {} + batch_nbytes = 0 + + def flush_batch(): + nonlocal batch, batch_nbytes + if not batch: + return + log_weight_batch("Loading batch", file_path, batch) + load_model_tensor_batch(model, batch) + batch.clear() + batch_nbytes = 0 + gc.collect() + + for name, tensor in iter_safetensors_tensors( + file_path, dtype=torch_dtype + ): + already_loaded_keys.append(name) + + if name == "model.embed_tokens.weight": + tensor_unscaled = tensor + if scale_emb != 1.0: + tensor = tensor_unscaled * float(scale_emb) + + tensor_bytes = tensor_nbytes(tensor) + if batch and batch_nbytes + tensor_bytes > 512 * 1024 * 1024: + flush_batch() + batch[name] = tensor + batch_nbytes += tensor_bytes + + if ( + "lm_head.weight" in model_keys + and not has_lm_head_weight + and "lm_head.weight" not in already_loaded_keys + ): + tied_tensor_bytes = tensor_nbytes(tensor_unscaled) + if ( + batch + and batch_nbytes + tied_tensor_bytes > 512 * 1024 * 1024 + ): + flush_batch() + batch["lm_head.weight"] = tensor_unscaled + batch_nbytes += tied_tensor_bytes + already_loaded_keys.append("lm_head.weight") + else: + tensor_bytes = tensor_nbytes(tensor) + if batch and batch_nbytes + tensor_bytes > 512 * 1024 * 1024: + flush_batch() + batch[name] = tensor + batch_nbytes += tensor_bytes + + if len(batch) >= 8 or batch_nbytes >= 512 * 1024 * 1024: + flush_batch() + + flush_batch() + continue + model_param = load_state_dict( file_path, device=torch_device, dtype=torch_dtype ) - - # Apply model-specific weight remapping - remapper = _WEIGHT_REMAPPER.get(model_type) - if remapper is not None: - model_param = remapper(model_param, config=model.hf_config) - - already_loaded_keys.extend(model_param.keys()) + model_param = remapper(model_param, config=model.hf_config) # --------------------------------------------------------- # # Scale embed_tokens on torch side before converting # --------------------------------------------------------- # - if "model.embed_tokens.weight" in model_param: - embed_tokens_torch_unscaled = model_param["model.embed_tokens.weight"] - if scale_emb != 1.0: - model_param["model.embed_tokens.weight"] = ( - embed_tokens_torch_unscaled * float(scale_emb) - ) - - # --------------------------------------------------------- # - # model_param_infini references torch.Tensor - # --------------------------------------------------------- # - model_param_infini = {} - for key in model_param.keys(): - model_param_infini[key] = infinicore.from_torch(model_param[key]) - model.load_state_dict(model_param_infini, strict=False) - infinicore.sync_device() + batch = {} + batch_nbytes = 0 + for key, tensor in model_param.items(): + already_loaded_keys.append(key) + + if key == "model.embed_tokens.weight": + embed_tokens_torch_unscaled = tensor + if scale_emb != 1.0: + tensor = embed_tokens_torch_unscaled * float(scale_emb) + + tensor_bytes = tensor_nbytes(tensor) + if batch and batch_nbytes + tensor_bytes > 512 * 1024 * 1024: + log_weight_batch("Loading remapped batch", file_path, batch) + load_model_tensor_batch(model, batch) + batch.clear() + batch_nbytes = 0 + gc.collect() + + batch[key] = tensor + batch_nbytes += tensor_bytes + if len(batch) >= 8 or batch_nbytes >= 512 * 1024 * 1024: + log_weight_batch("Loading remapped batch", file_path, batch) + load_model_tensor_batch(model, batch) + batch.clear() + batch_nbytes = 0 + gc.collect() + + if batch: + log_weight_batch("Loading remapped batch", file_path, batch) + load_model_tensor_batch(model, batch) + del model_param + gc.collect() model.process_weights_after_loading() elif os.path.exists(os.path.join(model_path, "pytorch_model.bin")): @@ -332,18 +479,19 @@ def load_model_state_dict_by_tensor( t2 = time.time() print(f" load weights over! {(t2 - t1) * 1000} ms \n") + # ============================================================================ # Common weight transformation utilities # ============================================================================ + def drop_keys( state_dict: Dict[str, torch.Tensor], substrings: List[str], ) -> Dict[str, torch.Tensor]: """Drop keys containing any of the given substrings.""" return { - k: v for k, v in state_dict.items() - if not any(sub in k for sub in substrings) + k: v for k, v in state_dict.items() if not any(sub in k for sub in substrings) } @@ -425,6 +573,7 @@ def split_fused_weight( return result + def split_fused_weight_with_sizes( state_dict: Dict[str, torch.Tensor], fused_key: str, @@ -465,6 +614,7 @@ def split_fused_weight_with_sizes( return result + # ============================================================================ # Model-specific remap functions # ============================================================================ @@ -511,18 +661,22 @@ def _remap_chatglm(state_dict, config=None): ) # 4. Rename keys - state_dict = rename_keys(state_dict, { - "transformer.encoder.layers.": "model.layers.", - "transformer.embedding.word_embeddings": "model.embed_tokens", - "transformer.encoder.final_layernorm": "model.norm", - "transformer.output_layer": "lm_head", - "self_attention.": "self_attn.", - "self_attn.dense": "self_attn.o_proj", - "mlp.dense_4h_to_h": "mlp.down_proj", - }) + state_dict = rename_keys( + state_dict, + { + "transformer.encoder.layers.": "model.layers.", + "transformer.embedding.word_embeddings": "model.embed_tokens", + "transformer.encoder.final_layernorm": "model.norm", + "transformer.output_layer": "lm_head", + "self_attention.": "self_attn.", + "self_attn.dense": "self_attn.o_proj", + "mlp.dense_4h_to_h": "mlp.down_proj", + }, + ) return state_dict + def _is_baichuan2(config): """ Baichuan1 and Baichuan2 share the same model_type "baichuan" in official HuggingFace configs, @@ -535,6 +689,7 @@ def _is_baichuan2(config): """ return config.get("vocab_size") == 125696 + def _remap_baichuan(state_dict, config=None): """Split Baichuan fused W_pack into q_proj, k_proj, v_proj and apply Baichuan2-specific fixes.""" diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index 3d35941c..af8f952c 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -110,6 +110,7 @@ def __init__( enable_graph: bool = False, attn_backend: str = "default", ignore_eos: bool = False, + weight_load_mode: str = "async", ): """Initialize inference server. @@ -152,6 +153,7 @@ def __init__( self.enable_graph = enable_graph self.attn_backend = attn_backend self.ignore_eos = ignore_eos + self.weight_load_mode = weight_load_mode self.engine: AsyncLLMEngine = None @@ -183,6 +185,7 @@ async def lifespan(app: FastAPI): top_k=self.top_k, enable_graph=self.enable_graph, attn_backend=self.attn_backend, + weight_load_mode=self.weight_load_mode, ) self.engine.start() logger.info(f"Engine initialized with model at {self.model_path}") @@ -575,6 +578,7 @@ def main(): enable_graph=cfg.enable_graph, attn_backend=cfg.attn, ignore_eos=cfg.ignore_eos, + weight_load_mode=cfg.weight_load_mode, ) server.start() diff --git a/test/bench/test_benchmark.py b/test/bench/test_benchmark.py index 60e63dc6..e845e7bc 100644 --- a/test/bench/test_benchmark.py +++ b/test/bench/test_benchmark.py @@ -55,6 +55,7 @@ def __init__( enable_paged_attn=False, enable_graph=False, attn_backend="default", + weight_load_mode="async", ): import transformers import infinicore @@ -119,6 +120,7 @@ def __init__( ), enable_graph_compiling=enable_graph, attention_backend=attn_backend, + weight_load_mode=weight_load_mode, ) # Enable KV cache for generation @@ -1126,6 +1128,7 @@ def main(): cfg.enable_paged_attn, cfg.enable_graph, cfg.attn, + cfg.weight_load_mode, ) # Step 3: Evaluate each subject