From 65d8ecfe05370c6256c93d68db9a5694acf24c9f Mon Sep 17 00:00:00 2001 From: wangpengcheng Date: Thu, 4 Jun 2026 05:28:03 +0000 Subject: [PATCH 1/6] issue/407 - preallocated workspace --- csrc/engine/compiler/paged_compiler.cpp | 6 +- .../compiler/static_batching_compiler.cpp | 6 +- csrc/engine/infer_engine.cpp | 5 +- csrc/engine/infer_engine.hpp | 3 +- csrc/global_state/forward_context.hpp | 4 ++ csrc/global_state/infinilm_config.hpp | 10 ++- csrc/layers/attention/attention.cpp | 62 +++++++++++++---- csrc/layers/attention/attention.hpp | 19 +++++- .../attention/backends/attention_layer.cpp | 9 +-- .../attention/backends/attention_layer.hpp | 3 +- csrc/layers/attention/backends/flash_attn.cpp | 41 ++++++++++-- csrc/layers/attention/backends/flash_attn.hpp | 13 +++- csrc/layers/attention/backends/paged_attn.cpp | 47 ++++++++++++- csrc/layers/attention/backends/paged_attn.hpp | 13 +++- .../layers/attention/backends/static_attn.cpp | 3 +- .../layers/attention/backends/static_attn.hpp | 3 +- .../causal_lm_templates/text_causal_lm.hpp | 54 +++++++++++++-- .../layers/causal_lm_templates/text_model.hpp | 59 +++++++++++++--- csrc/layers/linear/base_linear.cpp | 59 +++++++++++++--- csrc/layers/linear/base_linear.hpp | 5 +- csrc/layers/linear/fused_linear.cpp | 20 ++++++ csrc/layers/linear/fused_linear.hpp | 6 ++ csrc/layers/linear/linear.cpp | 8 +++ csrc/layers/linear/linear.hpp | 3 + csrc/layers/mlp/mlp.cpp | 64 +++++++++++++++--- csrc/layers/mlp/mlp.hpp | 16 +++++ csrc/layers/quantization/awq.cpp | 20 ++++++ csrc/layers/quantization/awq.hpp | 7 ++ .../layers/quantization/base_quantization.hpp | 8 +++ .../quantization/compressed_tensors.cpp | 19 ++++++ .../quantization/compressed_tensors.hpp | 7 ++ csrc/layers/quantization/gptq.cpp | 10 +++ csrc/layers/quantization/gptq.hpp | 7 ++ csrc/layers/quantization/gptq_qy.cpp | 19 ++++++ csrc/layers/quantization/gptq_qy.hpp | 7 ++ .../layers/quantization/none_quantization.cpp | 18 +++++ .../layers/quantization/none_quantization.hpp 
| 7 ++ csrc/models/glm4/glm4_attention.cpp | 2 +- .../minicpm_sala/minicpm_sala_attention.cpp | 2 +- csrc/models/qwen3/qwen3_attention.cpp | 67 ++++++++++++++----- csrc/models/qwen3/qwen3_attention.hpp | 10 +++ .../qwen3_next/qwen3_next_attention.cpp | 2 +- csrc/pybind11/engine/engine.hpp | 14 ++-- examples/bench.py | 4 ++ python/infinilm/base_config.py | 8 ++- python/infinilm/infer_engine.py | 18 ++++- python/infinilm/llm/llm.py | 6 ++ python/infinilm/llm/scheduler.py | 14 +++- test/bench/test_benchmark.py | 3 + 49 files changed, 721 insertions(+), 99 deletions(-) diff --git a/csrc/engine/compiler/paged_compiler.cpp b/csrc/engine/compiler/paged_compiler.cpp index de6ec5d14..392ca5e69 100644 --- a/csrc/engine/compiler/paged_compiler.cpp +++ b/csrc/engine/compiler/paged_compiler.cpp @@ -109,7 +109,11 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input & graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); auto graph = std::get<0>(result->second.compiled); - auto shared_output = std::shared_ptr(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); + // Reuse the GraphTensor output captured at compile time. + // Do not call resume_from_blob_() on workspace-backed logits: + // that registers a second deleter on the same GPU block and + // triggers double free in PinnableBlockAllocator. 
+ auto shared_output = std::get<1>(result->second.compiled); return std::make_tuple(graph, shared_output); } diff --git a/csrc/engine/compiler/static_batching_compiler.cpp b/csrc/engine/compiler/static_batching_compiler.cpp index dcd7f7143..637b3e791 100644 --- a/csrc/engine/compiler/static_batching_compiler.cpp +++ b/csrc/engine/compiler/static_batching_compiler.cpp @@ -56,7 +56,11 @@ StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled( graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); auto graph = std::get<0>(result->second.compiled); - auto shared_output = std::shared_ptr(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); + // Reuse the GraphTensor output captured at compile time. + // Do not call resume_from_blob_() on workspace-backed logits: + // that registers a second deleter on the same GPU block and + // triggers double free in PinnableBlockAllocator. + auto shared_output = std::get<1>(result->second.compiled); return std::make_tuple(graph, shared_output); } } else { diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 77d07df45..44196de34 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -15,7 +15,8 @@ InferEngine::InferEngine( const cache::CacheConfig *cache_config, bool enable_graph_compiling, backends::AttentionBackend attention_backend, - std::optional kv_cache_dtype) // Changed parameter + std::optional kv_cache_dtype, // Changed parameter + size_t max_num_batched_tokens) : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); @@ -23,7 +24,7 @@ InferEngine::InferEngine( // Load model config if model_path is provided, model_path must be valid, and config.json exists this->model_config_ = infinilm::config::ConfigFactory::createConfig(config_str); - auto infinilm_config = 
std::make_shared(attention_backend, this->model_config_); + auto infinilm_config = std::make_shared(attention_backend, this->model_config_, max_num_batched_tokens); // Only support offline int8 kv cache quantization in this version if (kv_cache_dtype.has_value()) { diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 70c3c1640..5aa722524 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -28,7 +28,8 @@ class InferEngine { const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, - std::optional kv_cache_dtype = std::nullopt); + std::optional kv_cache_dtype = std::nullopt, + size_t max_num_batched_tokens = 2048); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor &param); diff --git a/csrc/global_state/forward_context.hpp b/csrc/global_state/forward_context.hpp index 2568fc7ee..2cd8eba11 100644 --- a/csrc/global_state/forward_context.hpp +++ b/csrc/global_state/forward_context.hpp @@ -1,6 +1,7 @@ #pragma once #include "../models/infinilm_model.hpp" +#include <unordered_map> namespace infinilm::global_state { @@ -48,6 +49,9 @@ struct ForwardContext { AttentionMetadata attn_metadata; MultiModalMetadata mm_metadata; std::vector kv_cache_vec; + + // preallocated workspace for some modules + std::unordered_map<std::string, infinicore::Tensor> preallocated_workspace; }; void initialize_forward_context(ForwardContext &forward_context); diff --git a/csrc/global_state/infinilm_config.hpp b/csrc/global_state/infinilm_config.hpp index 9b80706ca..be8da9f8c 100644 --- a/csrc/global_state/infinilm_config.hpp +++ b/csrc/global_state/infinilm_config.hpp @@ -14,13 +14,19 @@ struct InfinilmConfig { public: InfinilmConfig() = default; InfinilmConfig(const infinilm::backends::AttentionBackend &backend, - const std::shared_ptr
&model_config, + size_t max_num_batched_tokens) : attention_backend(backend), - model_config(model_config) {} + model_config(model_config), + max_num_batched_tokens(max_num_batched_tokens) { + const size_t max_position_embeddings = model_config->get("max_position_embeddings"); + ASSERT(max_num_batched_tokens >= 512 && max_num_batched_tokens <= max_position_embeddings); + } public: infinilm::backends::AttentionBackend attention_backend; std::shared_ptr model_config; + size_t max_num_batched_tokens = 0; }; /** diff --git a/csrc/layers/attention/attention.cpp b/csrc/layers/attention/attention.cpp index 1b87f6fbc..2b4abd7b5 100644 --- a/csrc/layers/attention/attention.cpp +++ b/csrc/layers/attention/attention.cpp @@ -1,17 +1,20 @@ #include "attention.hpp" +#include "../../global_state/global_state.hpp" #include "../../utils.hpp" #include "../rotary_embedding/rotary_embedding.hpp" +#include namespace infinilm::layers::attention { Attention::Attention(std::shared_ptr model_config, size_t layer_idx, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { layer_idx_ = layer_idx; hidden_size_ = model_config->get("hidden_size"); head_dim_ = model_config->get("head_dim"); - const auto &dtype{model_config->get_dtype()}; size_t total_num_heads = model_config->get("num_attention_heads"); size_t total_num_kv_heads = model_config->get("num_key_value_heads"); bool use_bias = model_config->get_or("attention_bias", true); @@ -31,18 +34,21 @@ Attention::Attention(std::shared_ptr model_config qkv_proj_ = std::make_shared( hidden_size_, head_dim_, total_num_heads, total_num_kv_heads, "q_proj", "k_proj", "v_proj", register_fn, - quantization_method, use_bias, dtype, device, rank_info); + quantization_method, use_bias, dtype_, device_, rank_info); o_proj_ = this->register_module( "o_proj", total_num_heads * head_dim_, hidden_size_, quantization_method, - use_output_bias, dtype, device, tp_rank, tp_size, 
rank_info.comm); + use_output_bias, dtype_, device_, tp_rank, tp_size, rank_info.comm); - rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device); + rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device_); float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device_); - init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); + init_kv_cache_quant_params(register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_); + + rank_qkv_output_size_ = qkv_proj_->out_features() / static_cast(tp_size); + this->_initialize_preallocated_workspace(); } infinicore::Tensor Attention::forward(const infinicore::Tensor &positions, @@ -62,7 +68,8 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position size_t seq_len = shape[1]; // 1. Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_}); + auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); // 2. Reshape for multi-head attention auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_}); @@ -90,8 +97,9 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped); // 7. 
Project output - auto output = o_proj_->forward(attn_output); - return output; + auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; } infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ids, @@ -106,7 +114,8 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ ASSERT_EQ(batch_size, 1); // 1. Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_}); + auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); // 2. Reshape for multi-head attention auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); @@ -133,8 +142,35 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); // 6. 
Project output - auto output = o_proj_->forward(attn_output); - return output; + auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; +} + +void Attention::_initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string attention_cache_key = std::string("Attention_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_rank_qkv_output_size_" + + std::to_string(rank_qkv_output_size_) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_); + if (preallocated_workspace.find(attention_cache_key) == preallocated_workspace.end()) { + auto attention_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_); + preallocated_workspace[attention_cache_key] = attention_buffer; + } + + auto attention_buffer = preallocated_workspace.at(attention_cache_key); + const auto attention_buffer_shape = attention_buffer->shape(); + ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size); + + max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}})->view({max_num_batched_tokens, rank_qkv_output_size_}); + max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_}); } void init_kv_cache_quant_params(std::function register_fn, diff --git a/csrc/layers/attention/attention.hpp b/csrc/layers/attention/attention.hpp index 31f0d1fa4..b08016ea3 100644 --- a/csrc/layers/attention/attention.hpp +++ 
b/csrc/layers/attention/attention.hpp @@ -5,6 +5,8 @@ #include "../../global_state/global_state.hpp" #include "../linear/linear.hpp" #include "backends/attention_layer.hpp" +#include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/nn/module.hpp" #include "infinicore/nn/rope.hpp" #include "infinicore/tensor.hpp" @@ -37,6 +39,8 @@ class Attention : public infinicore::nn::Module { infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, const infinicore::Tensor &hidden_states) const; + void _initialize_preallocated_workspace(); + protected: std::shared_ptr qkv_proj_; std::shared_ptr o_proj_; @@ -49,13 +53,22 @@ class Attention : public infinicore::nn::Module { size_t num_key_value_heads_; size_t hidden_size_; size_t head_dim_; + infinicore::Device device_; + infinicore::DataType dtype_; // For off-line kv cache quantization INFINICORE_NN_PARAMETER(kv_cache_k_scale); INFINICORE_NN_PARAMETER(kv_cache_v_scale); + +private: + size_t rank_qkv_output_size_; + + // preallocated workspace for Attention + infinicore::Tensor max_qkv_output_; + infinicore::Tensor max_o_output_; }; void init_kv_cache_quant_params(std::function register_fn, - const infinicore::Device &device, - infinicore::nn::Parameter &kv_cache_k_scale, - infinicore::nn::Parameter &kv_cache_v_scale); + const infinicore::Device &device, + infinicore::nn::Parameter &kv_cache_k_scale, + infinicore::nn::Parameter &kv_cache_v_scale); } // namespace infinilm::layers::attention diff --git a/csrc/layers/attention/backends/attention_layer.cpp b/csrc/layers/attention/backends/attention_layer.cpp index fcaefa292..e5e39c10f 100644 --- a/csrc/layers/attention/backends/attention_layer.cpp +++ b/csrc/layers/attention/backends/attention_layer.cpp @@ -9,16 +9,17 @@ AttentionLayer::AttentionLayer(size_t num_heads, size_t layer_idx, infinicore::Tensor k_scale, infinicore::Tensor v_scale, - ::infinilm::backends::AttentionBackend attn_backend) : k_scale_(k_scale), v_scale_(v_scale), 
layer_idx_(layer_idx), attn_backend_(attn_backend) { + ::infinilm::backends::AttentionBackend attn_backend, + const infinicore::Device &device) : k_scale_(k_scale), v_scale_(v_scale), layer_idx_(layer_idx), attn_backend_(attn_backend) { switch (attn_backend) { case ::infinilm::backends::AttentionBackend::STATIC_ATTN: - attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx, device); break; case ::infinilm::backends::AttentionBackend::PAGED_ATTN: - attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx, device); break; case ::infinilm::backends::AttentionBackend::FLASH_ATTN: - attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx, device); break; default: throw std::runtime_error("infinilm::layers::attention::AttentionLayer: unsupported attention backend"); diff --git a/csrc/layers/attention/backends/attention_layer.hpp b/csrc/layers/attention/backends/attention_layer.hpp index 874110629..d83a79796 100644 --- a/csrc/layers/attention/backends/attention_layer.hpp +++ b/csrc/layers/attention/backends/attention_layer.hpp @@ -31,7 +31,8 @@ class AttentionLayer { size_t layer_idx, infinicore::Tensor k_scale, infinicore::Tensor v_scale, - ::infinilm::backends::AttentionBackend attention_backend); + ::infinilm::backends::AttentionBackend attention_backend, + const infinicore::Device &device); infinicore::Tensor forward(infinicore::Tensor &query, infinicore::Tensor &key, diff --git a/csrc/layers/attention/backends/flash_attn.cpp b/csrc/layers/attention/backends/flash_attn.cpp index ec7e37722..c48e5187c 100644 --- a/csrc/layers/attention/backends/flash_attn.cpp +++ 
b/csrc/layers/attention/backends/flash_attn.cpp @@ -1,9 +1,11 @@ #include "flash_attn.hpp" +#include "../../../global_state/global_state.hpp" #include "../../../utils.hpp" #include "infinicore/ops.hpp" #include "infinicore/ops/mha_kvcache.hpp" #include "infinicore/ops/mha_varlen.hpp" +#include namespace infinilm::layers::attention::backends { @@ -11,19 +13,26 @@ FlashAttentionImpl::FlashAttentionImpl(size_t num_heads, size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx) + size_t layer_idx, + const infinicore::Device &device) : num_heads_(num_heads), head_size_(head_size), scale_(scale), num_kv_heads_(num_kv_heads), layer_idx_(layer_idx), - head_dim_(head_size) { + head_dim_(head_size), + device_(device) { const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config(); if (!infinilm_config.model_config) { throw std::runtime_error("infinilm::layers::attention::backends::FlashAttentionImpl: model_config is null"); } - max_position_embeddings_ = infinilm_config.model_config->get("max_position_embeddings"); + + const auto &model_config = infinilm_config.model_config; + dtype_ = model_config->get_dtype(); + max_position_embeddings_ = model_config->get("max_position_embeddings"); + + this->_initialize_preallocated_workspace(); } infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer, @@ -48,8 +57,9 @@ infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer, bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); // 2. 
Compute attention - infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + infinicore::Tensor attn_output; if (is_prefill) { + attn_output = max_attn_output_->narrow({{0, 0, seq_len}}); infinicore::op::mha_varlen_( attn_output, query, @@ -99,4 +109,27 @@ std::tuple FlashAttentionImpl::do_kv_cac return {k_cache_layer, v_cache_layer}; } +void FlashAttentionImpl::_initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string cache_key = std::string("FlashAttentionImpl_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_num_heads_" + + std::to_string(num_heads_) + "_head_dim_" + + std::to_string(head_dim_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + if (preallocated_workspace.find(cache_key) == preallocated_workspace.end()) { + auto flash_attention_impl_buffer = infinicore::Tensor::empty({max_num_batched_tokens, num_heads_, head_dim_}, dtype_, device_); + preallocated_workspace[cache_key] = flash_attention_impl_buffer; + } + + auto flash_attention_impl_buffer = preallocated_workspace.at(cache_key); + const auto buffer_shape = flash_attention_impl_buffer->shape(); + ASSERT(buffer_shape[0] == max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_); + + max_attn_output_ = flash_attention_impl_buffer; +} } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/flash_attn.hpp b/csrc/layers/attention/backends/flash_attn.hpp index 93f61e8ba..7a480168b 100644 --- a/csrc/layers/attention/backends/flash_attn.hpp +++ b/csrc/layers/attention/backends/flash_attn.hpp @@ -1,6 +1,8 @@ #pragma once #include 
"../../../global_state/global_state.hpp" +#include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/tensor.hpp" #include @@ -16,7 +18,8 @@ class FlashAttentionImpl { size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx); + size_t layer_idx, + const infinicore::Device &device); /** * @brief Forward pass with FlashAttention. @@ -43,6 +46,8 @@ class FlashAttentionImpl { const infinicore::Tensor slot_mapping) const; private: + void _initialize_preallocated_workspace(); + size_t num_heads_; size_t head_size_; float scale_; @@ -50,5 +55,11 @@ class FlashAttentionImpl { size_t layer_idx_; size_t head_dim_; // Note: head_dim equals to head_size size_t max_position_embeddings_; + infinicore::Device device_; + infinicore::DataType dtype_; + + // preallocated workspace for FlashAttentionImpl + infinicore::Tensor max_attn_output_; }; + } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.cpp b/csrc/layers/attention/backends/paged_attn.cpp index a0ad70afe..0836accc9 100644 --- a/csrc/layers/attention/backends/paged_attn.cpp +++ b/csrc/layers/attention/backends/paged_attn.cpp @@ -1,21 +1,37 @@ #include "paged_attn.hpp" +#include "../../../global_state/global_state.hpp" #include "../../../utils.hpp" +#include "attention_layer.hpp" #include "infinicore/ops.hpp" +#include + namespace infinilm::layers::attention::backends { PagedAttentionImpl::PagedAttentionImpl(size_t num_heads, size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx) + size_t layer_idx, + const infinicore::Device &device) : num_heads_(num_heads), head_size_(head_size), scale_(scale), num_kv_heads_(num_kv_heads), layer_idx_(layer_idx), - head_dim_(head_size) {} + head_dim_(head_size), + device_(device) { + + const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config(); + if (!infinilm_config.model_config) { + throw 
std::runtime_error("infinilm::layers::attention::backends::PagedAttentionImpl: model_config is null"); + } + + dtype_ = infinilm_config.model_config->get_dtype(); + + this->_initialize_preallocated_workspace(); +} infinicore::Tensor PagedAttentionImpl::forward(const AttentionLayer &layer, const infinicore::Tensor &query, @@ -37,7 +53,7 @@ infinicore::Tensor PagedAttentionImpl::forward(const AttentionLayer &layer, bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); // 2. Compute attention - infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + infinicore::Tensor attn_output = max_attn_output_->narrow({{0, 0, seq_len}}); if (is_prefill) { infinicore::op::paged_attention_prefill_( attn_output, @@ -80,4 +96,29 @@ std::tuple PagedAttentionImpl::do_kv_cac return {k_cache_layer, v_cache_layer}; } + +void PagedAttentionImpl::_initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string cache_key = std::string("PagedAttentionImpl_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_num_heads_" + + std::to_string(num_heads_) + "_head_dim_" + + std::to_string(head_dim_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + if (preallocated_workspace.find(cache_key) == preallocated_workspace.end()) { + auto paged_attention_impl_buffer = infinicore::Tensor::empty({max_num_batched_tokens, num_heads_, head_dim_}, dtype_, device_); + preallocated_workspace[cache_key] = paged_attention_impl_buffer; + } + + auto paged_attention_impl_buffer = preallocated_workspace.at(cache_key); + const auto buffer_shape = paged_attention_impl_buffer->shape(); + ASSERT(buffer_shape[0] == 
max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_); + + max_attn_output_ = paged_attention_impl_buffer; +} + } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.hpp b/csrc/layers/attention/backends/paged_attn.hpp index 4f53ea573..d4408fe3e 100644 --- a/csrc/layers/attention/backends/paged_attn.hpp +++ b/csrc/layers/attention/backends/paged_attn.hpp @@ -1,6 +1,8 @@ #pragma once #include "../../../global_state/global_state.hpp" +#include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/tensor.hpp" #include @@ -16,7 +18,8 @@ class PagedAttentionImpl { size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx); + size_t layer_idx, + const infinicore::Device &device); /** * @brief Forward pass with PagedAttention. @@ -43,11 +46,19 @@ class PagedAttentionImpl { const infinicore::Tensor slot_mapping) const; private: + void _initialize_preallocated_workspace(); + size_t num_heads_; size_t head_size_; float scale_; size_t num_kv_heads_; size_t layer_idx_; size_t head_dim_; // Note: head_dim equals to head_size + infinicore::Device device_; + infinicore::DataType dtype_; + + // preallocated workspace for PagedAttentionImpl + infinicore::Tensor max_attn_output_; }; + } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/static_attn.cpp b/csrc/layers/attention/backends/static_attn.cpp index 2d1b7e11a..668f4c218 100644 --- a/csrc/layers/attention/backends/static_attn.cpp +++ b/csrc/layers/attention/backends/static_attn.cpp @@ -11,7 +11,8 @@ StaticAttentionImpl::StaticAttentionImpl(size_t num_heads, size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx) + size_t layer_idx, + const infinicore::Device & /*device*/) : num_heads_(num_heads), head_size_(head_size), scale_(scale), diff --git a/csrc/layers/attention/backends/static_attn.hpp b/csrc/layers/attention/backends/static_attn.hpp 
index 849d87928..00af4391e 100644 --- a/csrc/layers/attention/backends/static_attn.hpp +++ b/csrc/layers/attention/backends/static_attn.hpp @@ -18,7 +18,8 @@ class StaticAttentionImpl { size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx); + size_t layer_idx, + const infinicore::Device &device); infinicore::Tensor forward(const AttentionLayer &layer, infinicore::Tensor &q_reshaped, // query diff --git a/csrc/layers/causal_lm_templates/text_causal_lm.hpp b/csrc/layers/causal_lm_templates/text_causal_lm.hpp index eb4f2b47f..0bdb296c4 100644 --- a/csrc/layers/causal_lm_templates/text_causal_lm.hpp +++ b/csrc/layers/causal_lm_templates/text_causal_lm.hpp @@ -1,8 +1,12 @@ #pragma once +#include "../../global_state/global_state.hpp" #include "../../models/infinilm_model.hpp" +#include "../../utils.hpp" #include "../linear/linear.hpp" #include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" +#include namespace infinilm::layers::causal_lm_templates { @@ -28,15 +32,18 @@ class TextCausalLM : public InfinilmModel { * @param device: Device to create tensors on */ TextCausalLM(std::shared_ptr model_config, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { model_config_ = model_config; size_t hidden_size = model_config->get("hidden_size"); - size_t vocab_size = model_config->get("vocab_size"); - const auto &dtype{model_config->get_dtype()}; + vocab_size_ = model_config->get("vocab_size"); model_ = this->register_module("model", model_config, device); - lm_head_ = this->register_module("lm_head", hidden_size, vocab_size, false, dtype, device); + lm_head_ = this->register_module("lm_head", hidden_size, vocab_size_, false, dtype_, device_); + + this->_initialize_preallocated_workspace(); } /** @@ -44,7 +51,13 @@ class TextCausalLM : public InfinilmModel { */ Output forward(const Input &input) const override { auto hidden_states = model_->forward(input); - auto logits = 
lm_head_->forward(hidden_states); + + const auto shape = hidden_states->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + + auto logits = max_logits_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, vocab_size_}); + lm_head_->forward_(logits, hidden_states); return {logits}; } @@ -55,8 +68,39 @@ class TextCausalLM : public InfinilmModel { Model &model() { return *model_; } protected: + size_t vocab_size_; + infinicore::Device device_; + infinicore::DataType dtype_; + INFINICORE_NN_MODULE(Model, model); INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); + +private: + void _initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string text_causal_lm_cache_key = std::string("TextCausalLM_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_vocab_size_" + + std::to_string(vocab_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + if (preallocated_workspace.find(text_causal_lm_cache_key) == preallocated_workspace.end()) { + auto logits_buffer = infinicore::Tensor::empty({max_num_batched_tokens, vocab_size_}, dtype_, device_); + preallocated_workspace[text_causal_lm_cache_key] = logits_buffer; + } + + auto logits_buffer = preallocated_workspace.at(text_causal_lm_cache_key); + const auto logits_buffer_shape = logits_buffer->shape(); + ASSERT(logits_buffer_shape[0] == max_num_batched_tokens && logits_buffer_shape[1] == vocab_size_); + + max_logits_ = logits_buffer; + } + + // preallocated workspace for TextCausalLM + infinicore::Tensor max_logits_; }; } // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/causal_lm_templates/text_model.hpp 
b/csrc/layers/causal_lm_templates/text_model.hpp index 62a52798b..23a9f13a6 100644 --- a/csrc/layers/causal_lm_templates/text_model.hpp +++ b/csrc/layers/causal_lm_templates/text_model.hpp @@ -1,11 +1,16 @@ #pragma once #include "../../config/model_config.hpp" +#include "../../global_state/global_state.hpp" #include "../../models/infinilm_model.hpp" +#include "../../utils.hpp" +#include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/nn/embedding.hpp" #include "infinicore/nn/rmsnorm.hpp" #include "infinicore/tensor.hpp" #include +#include #include namespace infinilm::layers::causal_lm_templates { @@ -24,30 +29,37 @@ template class TextModel : public infinicore::nn::Module { public: TextModel(std::shared_ptr model_config, - const infinicore::Device &device) { - const auto &dtype{model_config->get_dtype()}; - size_t vocab_size = model_config->get("vocab_size"); - size_t hidden_size = model_config->get("hidden_size"); + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { + vocab_size_ = model_config->get("vocab_size"); + hidden_size_ = model_config->get("hidden_size"); size_t max_position_embeddings = model_config->get("max_position_embeddings"); size_t num_hidden_layers = model_config->get("num_hidden_layers"); double rope_theta = model_config->get("rope_theta"); double rms_norm_eps = model_config->get("rms_norm_eps"); - embed_tokens_ = this->register_module("embed_tokens", vocab_size, hidden_size, std::nullopt, dtype, device); + embed_tokens_ = this->register_module("embed_tokens", vocab_size_, hidden_size_, std::nullopt, dtype_, device_); layers_.reserve(num_hidden_layers); for (size_t i = 0; i < num_hidden_layers; ++i) { - layers_.push_back(this->register_module("layers." + std::to_string(i), model_config, i, device)); + layers_.push_back(this->register_module("layers." 
+ std::to_string(i), model_config, i, device_)); } - norm_ = this->register_module("norm", hidden_size, rms_norm_eps, dtype, device); + norm_ = this->register_module("norm", hidden_size_, rms_norm_eps, dtype_, device_); + + this->_initialize_preallocated_workspace(); } infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input) const { auto input_ids = input.input_ids.value(); auto positions = input.position_ids.value(); // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size] - auto hidden_states = embed_tokens_->forward(input_ids); + const auto shape = input_ids->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + auto hidden_states = max_hidden_states_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, hidden_size_}); + embed_tokens_->forward_(hidden_states, input_ids); // 2. Process through all decoder layers size_t num_layers = layers_.size(); @@ -64,6 +76,7 @@ class TextModel : public infinicore::nn::Module { } infinicore::Tensor forward_naive(const infinilm::InfinilmModel::Input &input) const { + // Don't use preallocated workspace in forward_naive function. auto input_ids = input.input_ids.value(); auto positions = input.position_ids.value(); auto hidden_states = embed_tokens_->forward(input_ids); @@ -78,6 +91,7 @@ class TextModel : public infinicore::nn::Module { infinicore::Tensor forward_embeds(const infinicore::Tensor &inputs_embeds, const infinicore::Tensor &position_ids) const { + // Don't use preallocated workspace in forward_embeds function. 
auto hidden_states = inputs_embeds; // Process through all decoder layers @@ -98,10 +112,39 @@ class TextModel : public infinicore::nn::Module { return embed_tokens_->forward(input_ids); } +private: + void _initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string text_model_cache_key = std::string("TextModel_max_num_batched_tokens_") + std::to_string(max_num_batched_tokens) + "_hidden_size_" + std::to_string(hidden_size_) + "_dtype_" + infinicore::toString(dtype_) + "_device_" + device_.toString(); + + if (preallocated_workspace.find(text_model_cache_key) == preallocated_workspace.end()) { + auto text_model_buffer = infinicore::Tensor::empty({max_num_batched_tokens, hidden_size_}, dtype_, device_); + preallocated_workspace[text_model_cache_key] = text_model_buffer; + } + + auto text_model_buffer = preallocated_workspace.at(text_model_cache_key); + const auto text_model_buffer_shape = text_model_buffer->shape(); + ASSERT(text_model_buffer_shape[0] == max_num_batched_tokens && text_model_buffer_shape[1] == hidden_size_); + + max_hidden_states_ = text_model_buffer; + } + protected: + size_t vocab_size_; + size_t hidden_size_; + infinicore::Device device_; + infinicore::DataType dtype_; + INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); INFINICORE_NN_MODULE_VEC(DecoderLayer, layers); INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm); + +private: + // preallocated workspace for TextModel + infinicore::Tensor max_hidden_states_; }; } // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/linear/base_linear.cpp b/csrc/layers/linear/base_linear.cpp index eebc482c4..25b292600 100644 --- a/csrc/layers/linear/base_linear.cpp +++ b/csrc/layers/linear/base_linear.cpp @@ -43,10 
+43,25 @@ infinicore::Tensor BaseLinear::compute_linear(infinicore::Tensor &input) const { return quantization_->forward(params, input, has_bias_, alpha_); } +void BaseLinear::compute_linear_(infinicore::Tensor &output, infinicore::Tensor &input) const { + // Build params map from direct parameters only (not state_dict which uses a + // static local and is not thread-safe across RankWorker threads). + infinilm::quantization::ParamsMap params; + for (const auto &[name, param] : parameters_) { + params[name] = static_cast(param); + } + + quantization_->forward_(output, params, input, has_bias_, alpha_); +} + infinicore::Tensor BaseLinear::forward(infinicore::Tensor &input) const { return compute_linear(input); } +void BaseLinear::forward_(infinicore::Tensor &output, infinicore::Tensor &input) const { + compute_linear_(output, input); +} + infinicore::Tensor BaseLinear::forward(infinicore::Tensor &input, infinicore::Tensor &residual) const { auto output = compute_linear(input); infinicore::op::add_(output, output, residual); @@ -60,7 +75,9 @@ void BaseLinear::process_weights_after_loading() { } auto new_quant = quantization_->process_weights_after_loading(params, device_); - if (!new_quant) return; + if (!new_quant) { + return; + } for (auto &[name, param] : parameters_) { param = infinicore::nn::Parameter(); @@ -68,7 +85,9 @@ void BaseLinear::process_weights_after_loading() { for (const auto &[name, tensor] : params) { auto it = parameters_.find(name); - if (it == parameters_.end()) continue; + if (it == parameters_.end()) { + continue; + } it->second = infinicore::nn::Parameter(tensor); } @@ -79,43 +98,61 @@ void BaseLinear::process_weights_after_loading() { infinicore::Tensor BaseLinear::weight() const { auto it = parameters_.find("weight"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } it = parameters_.find("qweight"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + 
return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::bias() const { auto it = parameters_.find("bias"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::weight_scale() const { auto it = parameters_.find("weight_scale"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } it = parameters_.find("scales"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::weight_zeros() const { auto it = parameters_.find("weight_zeros"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } it = parameters_.find("qzeros"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::gidx() const { auto it = parameters_.find("g_idx"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::get_param(const std::string &name) const { auto it = parameters_.find(name); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } diff --git a/csrc/layers/linear/base_linear.hpp b/csrc/layers/linear/base_linear.hpp index a304f4573..52ac27c40 100644 --- a/csrc/layers/linear/base_linear.hpp +++ b/csrc/layers/linear/base_linear.hpp @@ -1,8 +1,8 @@ #pragma once -#include "infinicore/ops.hpp" #include "../quantization/quantization.hpp" #include "infinicore/nn/module.hpp" +#include "infinicore/ops.hpp" #include #include @@ -23,6 +23,8 @@ class BaseLinear : public infinicore::nn::Module { // Forward pass: output = input @ 
weight.T + bias infinicore::Tensor forward(infinicore::Tensor &input) const; + void forward_(infinicore::Tensor &output, infinicore::Tensor &input) const; + // Forward pass with residual connection infinicore::Tensor forward(infinicore::Tensor &input, infinicore::Tensor &residual) const; @@ -57,6 +59,7 @@ class BaseLinear : public infinicore::nn::Module { protected: infinicore::Tensor compute_linear(infinicore::Tensor &input) const; + void compute_linear_(infinicore::Tensor &output, infinicore::Tensor &input) const; size_t in_features_; size_t out_features_; diff --git a/csrc/layers/linear/fused_linear.cpp b/csrc/layers/linear/fused_linear.cpp index 2c0520901..dd0da59dc 100644 --- a/csrc/layers/linear/fused_linear.cpp +++ b/csrc/layers/linear/fused_linear.cpp @@ -70,6 +70,17 @@ QKVParallelLinear::forward_split(infinicore::Tensor &input) { return std::make_tuple(q_out, k_out, v_out); } +std::tuple +QKVParallelLinear::forward_split_(infinicore::Tensor &output, infinicore::Tensor &input) { + this->forward_(output, input); + + auto q_out = output->narrow({{2, 0, q_out_size_}}); + auto k_out = output->narrow({{2, q_out_size_, k_out_size_}}); + auto v_out = output->narrow({{2, q_out_size_ + k_out_size_, v_out_size_}}); + + return std::make_tuple(q_out, k_out, v_out); +} + bool QKVParallelLinear::has_q_bias() const { return q_bias_; } bool QKVParallelLinear::has_k_bias() const { return k_bias_; } bool QKVParallelLinear::has_v_bias() const { return v_bias_; } @@ -144,6 +155,15 @@ std::tuple GateUpParallelLinear::forward return std::make_tuple(gate_output, up_output); } +std::tuple +GateUpParallelLinear::forward_split_(infinicore::Tensor &output, infinicore::Tensor &input) { + this->forward_(output, input); + auto cols = output->shape()[2]; + auto gate_output = output->narrow({{2, 0, cols / 2}}); + auto up_output = output->narrow({{2, cols / 2, cols / 2}}); + return std::make_tuple(gate_output, up_output); +} + bool GateUpParallelLinear::has_gate_bias() const { return 
gate_bias_; } bool GateUpParallelLinear::has_up_bias() const { return up_bias_; } diff --git a/csrc/layers/linear/fused_linear.hpp b/csrc/layers/linear/fused_linear.hpp index 6e4a34856..b85aa3dda 100644 --- a/csrc/layers/linear/fused_linear.hpp +++ b/csrc/layers/linear/fused_linear.hpp @@ -43,6 +43,9 @@ class QKVParallelLinear : public infinilm::nn::ColumnParallelLinear { std::tuple forward_split(infinicore::Tensor &input); + std::tuple + forward_split_(infinicore::Tensor &output, infinicore::Tensor &input); + bool has_q_bias() const; bool has_k_bias() const; bool has_v_bias() const; @@ -109,6 +112,9 @@ class GateUpParallelLinear : public infinilm::nn::ColumnParallelLinear { std::tuple forward_split(infinicore::Tensor &input); + std::tuple + forward_split_(infinicore::Tensor &output, infinicore::Tensor &input); + bool has_gate_bias() const; bool has_up_bias() const; diff --git a/csrc/layers/linear/linear.cpp b/csrc/layers/linear/linear.cpp index 84982409f..f73abbb29 100644 --- a/csrc/layers/linear/linear.cpp +++ b/csrc/layers/linear/linear.cpp @@ -93,6 +93,14 @@ infinicore::Tensor RowParallelLinear::forward(infinicore::Tensor &input) const { return output; } +void RowParallelLinear::forward_(infinicore::Tensor &output, infinicore::Tensor &input) const { + BaseLinear::forward_(output, input); + + if ((tp_size_ > 1) && (communicator_ != nullptr)) { + infinicore::op::distributed::allreduce_(output, output, INFINICCL_SUM, communicator_); + } +} + std::string RowParallelLinear::extra_repr() const { return "RowParallelLinear(in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? 
"true" : "false") + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; } diff --git a/csrc/layers/linear/linear.hpp b/csrc/layers/linear/linear.hpp index 566cee77c..abae08f66 100644 --- a/csrc/layers/linear/linear.hpp +++ b/csrc/layers/linear/linear.hpp @@ -70,6 +70,9 @@ class RowParallelLinear : public BaseLinear { infinicclComm_t communicator = nullptr); infinicore::Tensor forward(infinicore::Tensor &input) const; + + void forward_(infinicore::Tensor &output, infinicore::Tensor &input) const; + std::string extra_repr() const; protected: diff --git a/csrc/layers/mlp/mlp.cpp b/csrc/layers/mlp/mlp.cpp index f7604c505..00e756d3d 100644 --- a/csrc/layers/mlp/mlp.cpp +++ b/csrc/layers/mlp/mlp.cpp @@ -1,13 +1,16 @@ #include "mlp.hpp" #include "../../global_state/global_state.hpp" +#include "../../utils.hpp" #include "infinicore/ops.hpp" +#include namespace infinilm::layers::mlp { MLP::MLP(std::shared_ptr model_config, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { - const auto &dtype{model_config->get_dtype()}; hidden_size_ = model_config->get("hidden_size"); intermediate_size_ = model_config->get("intermediate_size"); use_bias_ = model_config->get_or("mlp_bias", false); @@ -20,20 +23,65 @@ MLP::MLP(std::shared_ptr model_config, auto register_fn = [this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }; gate_up_proj_ = std::make_shared( hidden_size_, intermediate_size_, "gate_proj", "up_proj", register_fn, - quantization_method, use_bias_, dtype, device, rank_info); + quantization_method, use_bias_, dtype_, device_, rank_info); down_proj_ = this->register_module( "down_proj", intermediate_size_, hidden_size_, quantization_method, - use_bias_, dtype, device, tp_rank, tp_size, rank_info.comm); + use_bias_, dtype_, device_, tp_rank, tp_size, rank_info.comm); + + rank_gate_up_output_size_ = gate_up_proj_->out_features() / 
static_cast(tp_size); + rank_intermediate_size_ = rank_gate_up_output_size_ / 2; + this->_initialize_preallocated_workspace(); } infinicore::Tensor MLP::forward(const infinicore::Tensor &hidden_states) const { + const auto shape = hidden_states->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + // 1. Project to gate and up auto hidden_states_mutable = hidden_states; - auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable); + auto gate_up_output = max_gate_up_output_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, rank_gate_up_output_size_}); + auto [gate, up] = gate_up_proj_->forward_split_(gate_up_output, hidden_states_mutable); + // 2. Apply SwiGLU: silu(gate) * up - auto intermediate = infinicore::op::swiglu(up, gate); + auto intermediate = max_intermediate_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, rank_intermediate_size_}); + infinicore::op::swiglu_(intermediate, up, gate); + // 3. Project down - auto output = down_proj_->forward(intermediate); - return output; + auto down_output = max_down_output_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, hidden_size_}); + down_proj_->forward_(down_output, intermediate); + return down_output; } + +void MLP::_initialize_preallocated_workspace() { + // Carve the gate/up, SwiGLU and down-projection outputs of forward() out of one flat buffer kept in the forward context under mlp_cache_key. + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string mlp_cache_key = std::string("MLP_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_rank_gate_up_output_size_" + + std::to_string(rank_gate_up_output_size_) + "_rank_intermediate_size_" + + std::to_string(rank_intermediate_size_) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + size_t gate_up_intermediate_size = 
rank_gate_up_output_size_ + rank_intermediate_size_; + size_t max_output_size = gate_up_intermediate_size + hidden_size_; + // gate_up and intermediate get disjoint regions: swiglu_ writes intermediate while reading gate/up, and the two views use different row strides, so sharing offset 0 could clobber inputs that have not been read yet. + if (preallocated_workspace.find(mlp_cache_key) == preallocated_workspace.end()) { + auto mlp_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_); + preallocated_workspace[mlp_cache_key] = mlp_buffer; + } + + auto mlp_buffer = preallocated_workspace.at(mlp_cache_key); + const auto buffer_shape = mlp_buffer->shape(); + ASSERT(buffer_shape[0] == max_num_batched_tokens * max_output_size); + // Buffer layout: [gate_up | intermediate | down], each region sized for max_num_batched_tokens rows. + max_gate_up_output_ = mlp_buffer->narrow({{0, 0, max_num_batched_tokens * rank_gate_up_output_size_}})->view({max_num_batched_tokens, rank_gate_up_output_size_}); + max_intermediate_ = mlp_buffer->narrow({{0, max_num_batched_tokens * rank_gate_up_output_size_, max_num_batched_tokens * rank_intermediate_size_}})->view({max_num_batched_tokens, rank_intermediate_size_}); + max_down_output_ = mlp_buffer->narrow({{0, max_num_batched_tokens * gate_up_intermediate_size, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_}); +} + } // namespace infinilm::layers::mlp diff --git a/csrc/layers/mlp/mlp.hpp b/csrc/layers/mlp/mlp.hpp index 91349fe9b..73e5da02e 100644 --- a/csrc/layers/mlp/mlp.hpp +++ b/csrc/layers/mlp/mlp.hpp @@ -2,7 +2,10 @@ #include "../../config/model_config.hpp" #include "../linear/linear.hpp" +#include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/nn/module.hpp" +#include "infinicore/tensor.hpp" namespace infinilm::layers::mlp { @@ -51,6 +54,19 @@ class MLP : public infinicore::nn::Module { size_t hidden_size_; size_t intermediate_size_; bool use_bias_; + infinicore::Device device_; + infinicore::DataType dtype_; + +private: + void _initialize_preallocated_workspace(); + + size_t rank_gate_up_output_size_; + size_t rank_intermediate_size_; + + // preallocated workspace for MLP + infinicore::Tensor max_gate_up_output_; + infinicore::Tensor max_intermediate_; + 
infinicore::Tensor max_down_output_; }; } // namespace infinilm::layers::mlp diff --git a/csrc/layers/quantization/awq.cpp b/csrc/layers/quantization/awq.cpp index 50e830f44..1c07c6dbe 100644 --- a/csrc/layers/quantization/awq.cpp +++ b/csrc/layers/quantization/awq.cpp @@ -52,6 +52,26 @@ infinicore::Tensor AWQ::forward( return infinicore::op::linear_w4a16_awq(input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt); } +void AWQ::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float /*alpha*/) const { + + auto input_contiguous = input->is_contiguous() ? input : input->contiguous(); + auto qweight = params.at("qweight"); + auto scales = params.at("scales"); + auto qzeros = params.at("qzeros"); + + std::optional bias_opt; + if (has_bias) { + bias_opt = params.at("bias"); + } + + infinicore::op::linear_w4a16_awq_(output, input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt); +} + std::vector AWQ::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/awq.hpp b/csrc/layers/quantization/awq.hpp index 383e574aa..797092cb4 100644 --- a/csrc/layers/quantization/awq.hpp +++ b/csrc/layers/quantization/awq.hpp @@ -38,6 +38,13 @@ class AWQ : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/base_quantization.hpp b/csrc/layers/quantization/base_quantization.hpp index 1fd261bdf..0ed82cb35 100644 --- a/csrc/layers/quantization/base_quantization.hpp +++ b/csrc/layers/quantization/base_quantization.hpp @@ -59,6 +59,14 @@ class BaseQuantization : public std::enable_shared_from_this { bool has_bias, float alpha = 1.0f) const = 0; + // 
In-place forward pass. + virtual void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const = 0; + // Dimension for fused-split (gate/up, q/k/v) of a column-parallel weight. // For NoneQuantization weight [out, in], split is on dim0. // For AWQ qweight [in, out/pack], split is on dim1. diff --git a/csrc/layers/quantization/compressed_tensors.cpp b/csrc/layers/quantization/compressed_tensors.cpp index 66a4a3ef6..45c46fb3e 100644 --- a/csrc/layers/quantization/compressed_tensors.cpp +++ b/csrc/layers/quantization/compressed_tensors.cpp @@ -43,6 +43,25 @@ infinicore::Tensor CompressedTensors::forward( return infinicore::op::linear_w8a8i8(input_contiguous->contiguous(), weight, weight_scale, bias_opt); } +void CompressedTensors::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float /*alpha*/) const { + + auto input_contiguous = input->is_contiguous() ? 
input : input->contiguous(); + auto weight = params.at("weight"); + auto weight_scale = params.at("weight_scale"); + + std::optional bias_opt; + if (has_bias) { + bias_opt = params.at("bias"); + } + + infinicore::op::linear_w8a8i8_(output, input_contiguous->contiguous(), weight, weight_scale, bias_opt); +} + std::vector CompressedTensors::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/compressed_tensors.hpp b/csrc/layers/quantization/compressed_tensors.hpp index dcf65c2e0..2bac396aa 100644 --- a/csrc/layers/quantization/compressed_tensors.hpp +++ b/csrc/layers/quantization/compressed_tensors.hpp @@ -25,6 +25,13 @@ class CompressedTensors : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/gptq.cpp b/csrc/layers/quantization/gptq.cpp index e7688be50..972aa130c 100644 --- a/csrc/layers/quantization/gptq.cpp +++ b/csrc/layers/quantization/gptq.cpp @@ -36,6 +36,16 @@ infinicore::Tensor GPTQ::forward( "Call process_weights_after_loading() first."); } +void GPTQ::forward_( + infinicore::Tensor & /*output*/, + const ParamsMap & /*params*/, + const infinicore::Tensor & /*input*/, + bool /*has_bias*/, + float /*alpha*/) const { + throw std::runtime_error("GPTQ_W4A16 must be converted to GPTQ_QY before forward pass. 
" + "Call process_weights_after_loading() first."); +} + std::shared_ptr GPTQ::process_weights_after_loading( ParamsMap ¶ms, const infinicore::Device &device) const { diff --git a/csrc/layers/quantization/gptq.hpp b/csrc/layers/quantization/gptq.hpp index 455dde2cc..598be78fb 100644 --- a/csrc/layers/quantization/gptq.hpp +++ b/csrc/layers/quantization/gptq.hpp @@ -34,6 +34,13 @@ class GPTQ : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/gptq_qy.cpp b/csrc/layers/quantization/gptq_qy.cpp index 4098e452d..42ac9b26e 100644 --- a/csrc/layers/quantization/gptq_qy.cpp +++ b/csrc/layers/quantization/gptq_qy.cpp @@ -48,6 +48,25 @@ infinicore::Tensor GPTQ_QY::forward( return output; } +void GPTQ_QY::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float /*alpha*/) const { + auto input_contiguous = input->is_contiguous() ? 
input : input->contiguous(); + auto qweight = params.at("qweight"); + auto qzeros = params.at("qzeros"); + auto scales = params.at("scales"); + + infinicore::op::linear_w4a16_gptq_qy_(output, input_contiguous->contiguous(), qweight, scales, qzeros, 0, 4); + + if (has_bias) { + auto bias = params.at("bias"); + infinicore::op::add_(output, output, bias->as_strided(output->shape(), {0, 0, 1})); + } +} + std::vector GPTQ_QY::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/gptq_qy.hpp b/csrc/layers/quantization/gptq_qy.hpp index 634b4aaf7..22635fb4d 100644 --- a/csrc/layers/quantization/gptq_qy.hpp +++ b/csrc/layers/quantization/gptq_qy.hpp @@ -112,6 +112,13 @@ class GPTQ_QY : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + // Split fused linear parameters into named sub-parameters std::vector split_params( const std::unordered_map ¶ms, diff --git a/csrc/layers/quantization/none_quantization.cpp b/csrc/layers/quantization/none_quantization.cpp index e1f67a7d1..3525b4e13 100644 --- a/csrc/layers/quantization/none_quantization.cpp +++ b/csrc/layers/quantization/none_quantization.cpp @@ -38,6 +38,24 @@ infinicore::Tensor NoneQuantization::forward( return infinicore::op::linear(input_contiguous->contiguous(), weight->contiguous(), bias_opt, alpha); } +void NoneQuantization::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha) const { + + auto input_contiguous = input->is_contiguous() ? 
input : input->contiguous(); + auto weight = params.at("weight"); + + std::optional bias_opt; + if (has_bias) { + bias_opt = params.at("bias"); + } + + infinicore::op::linear_(output, input_contiguous->contiguous(), weight->contiguous(), bias_opt, alpha); +} + std::vector NoneQuantization::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/none_quantization.hpp b/csrc/layers/quantization/none_quantization.hpp index 18fe4bf17..baee47c6a 100644 --- a/csrc/layers/quantization/none_quantization.hpp +++ b/csrc/layers/quantization/none_quantization.hpp @@ -27,6 +27,13 @@ class NoneQuantization : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/models/glm4/glm4_attention.cpp b/csrc/models/glm4/glm4_attention.cpp index c5851cc26..579b0e89f 100644 --- a/csrc/models/glm4/glm4_attention.cpp +++ b/csrc/models/glm4/glm4_attention.cpp @@ -57,7 +57,7 @@ Glm4Attention::Glm4Attention(std::shared_ptr mode attn_ = std::make_shared( num_attention_heads_, head_dim_, scaling_, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device); // KV Cache quantization scale initialization infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); diff --git a/csrc/models/minicpm_sala/minicpm_sala_attention.cpp b/csrc/models/minicpm_sala/minicpm_sala_attention.cpp index 60fe4ee47..0413f274b 100644 --- a/csrc/models/minicpm_sala/minicpm_sala_attention.cpp +++ b/csrc/models/minicpm_sala/minicpm_sala_attention.cpp @@ -49,7 +49,7 @@ AttentionBase::AttentionBase(std::shared_ptr mode float scaling = 
1.0f / std::sqrt(static_cast(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device); infinilm::layers::attention::init_kv_cache_quant_params([this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }, device, kv_cache_k_scale_, kv_cache_v_scale_); diff --git a/csrc/models/qwen3/qwen3_attention.cpp b/csrc/models/qwen3/qwen3_attention.cpp index bff9c8d73..676d2567d 100644 --- a/csrc/models/qwen3/qwen3_attention.cpp +++ b/csrc/models/qwen3/qwen3_attention.cpp @@ -2,17 +2,19 @@ #include "../../global_state/global_state.hpp" #include "../../layers/attention/attention.hpp" #include "../../utils.hpp" +#include namespace infinilm::models::qwen3 { Qwen3Attention::Qwen3Attention(std::shared_ptr model_config, size_t layer_idx, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { layer_idx_ = layer_idx; hidden_size_ = model_config->get("hidden_size"); head_dim_ = model_config->get("head_dim"); - const auto &dtype{model_config->get_dtype()}; size_t total_num_heads = model_config->get("num_attention_heads"); size_t total_num_kv_heads = model_config->get("num_key_value_heads"); bool use_bias = model_config->get_or("attention_bias", true); @@ -35,21 +37,24 @@ Qwen3Attention::Qwen3Attention(std::shared_ptr mo qkv_proj_ = std::make_shared( hidden_size_, head_dim_, total_num_heads, total_num_kv_heads, "q_proj", "k_proj", "v_proj", register_fn, - quantization_method, use_bias, dtype, device, rank_info); + quantization_method, use_bias, dtype_, device_, rank_info); o_proj_ = this->register_module( "o_proj", total_num_heads * head_dim_, hidden_size_, quantization_method, - use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + use_output_bias, dtype_, device_, 
tp_rank, tp_size, rank_info.comm); - rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device); + rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device_); float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device_); - INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device); - INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype_, device_); + INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype_, device_); - infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); + infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_); + + rank_qkv_output_size_ = qkv_proj_->out_features() / static_cast(tp_size); + this->_initialize_preallocated_workspace(); } infinicore::Tensor Qwen3Attention::forward(const infinicore::Tensor &positions, @@ -70,7 +75,8 @@ infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &pos size_t seq_len = shape[1]; // 1. Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_}); + auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_})); k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_})); @@ -100,9 +106,10 @@ infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &pos // 6. 
Attn Backend calculate auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped); - // 7. Project output - auto output = o_proj_->forward(attn_output); - return output; + // 7. Project output + auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; } infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &position_ids, @@ -118,7 +125,8 @@ infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &posi ASSERT_EQ(batch_size, 1); // 1. Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_}); + auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); // 2. Reshape for multi-head attention auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); @@ -147,6 +155,35 @@ infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &posi auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); // 6. 
Project output - return o_proj_->forward(attn_output); + auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; +} + +void Qwen3Attention::_initialize_preallocated_workspace() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string attention_cache_key = std::string("Qwen3Attention_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_rank_qkv_output_size_" + + std::to_string(rank_qkv_output_size_) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_); + if (preallocated_workspace.find(attention_cache_key) == preallocated_workspace.end()) { + auto attention_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_); + preallocated_workspace[attention_cache_key] = attention_buffer; + } + + auto attention_buffer = preallocated_workspace.at(attention_cache_key); + const auto attention_buffer_shape = attention_buffer->shape(); + ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size); + + max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}})->view({max_num_batched_tokens, rank_qkv_output_size_}); + max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_}); } + } // namespace infinilm::models::qwen3 diff --git a/csrc/models/qwen3/qwen3_attention.hpp b/csrc/models/qwen3/qwen3_attention.hpp index 44b69f386..0133a9339 100644 --- a/csrc/models/qwen3/qwen3_attention.hpp +++ 
b/csrc/models/qwen3/qwen3_attention.hpp @@ -25,6 +25,8 @@ class Qwen3Attention : public infinicore::nn::Module { infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, const infinicore::Tensor &hidden_states) const; + void _initialize_preallocated_workspace(); + protected: std::shared_ptr qkv_proj_; std::shared_ptr o_proj_; @@ -39,9 +41,17 @@ class Qwen3Attention : public infinicore::nn::Module { size_t num_key_value_heads_; size_t hidden_size_; size_t head_dim_; + infinicore::Device device_; + infinicore::DataType dtype_; // For off-line kv cache quantization INFINICORE_NN_PARAMETER(kv_cache_k_scale); INFINICORE_NN_PARAMETER(kv_cache_v_scale); + + size_t rank_qkv_output_size_; + + // preallocated workspace for Attention + infinicore::Tensor max_qkv_output_; + infinicore::Tensor max_o_output_; }; } // namespace infinilm::models::qwen3 diff --git a/csrc/models/qwen3_next/qwen3_next_attention.cpp b/csrc/models/qwen3_next/qwen3_next_attention.cpp index 67fd38082..5cf469d5a 100644 --- a/csrc/models/qwen3_next/qwen3_next_attention.cpp +++ b/csrc/models/qwen3_next/qwen3_next_attention.cpp @@ -49,7 +49,7 @@ Qwen3NextAttention::Qwen3NextAttention(std::shared_ptr(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device); INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device); INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device); diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 8e470984e..d3c0476f3 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -1,5 +1,6 @@ #include "../../engine/infer_engine.hpp" #include "infinicore/tensor.hpp" +#include #include #include @@ -39,7 +40,8 @@ inline void bind_infer_engine(py::module &m) { std::shared_ptr cache_cfg, bool 
enable_graph_compiling, const std::string &attention_backend, - std::optional kv_cache_dtype) { + std::optional kv_cache_dtype, + size_t max_num_batched_tokens) { return std::make_shared( model_path, dist, @@ -47,7 +49,8 @@ inline void bind_infer_engine(py::module &m) { cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, infinilm::backends::parse_attention_backend(attention_backend), - kv_cache_dtype); + kv_cache_dtype, + max_num_batched_tokens); }), py::arg("model_path") = "", py::arg("distributed_config") = distributed::DistConfig(), @@ -55,7 +58,8 @@ inline void bind_infer_engine(py::module &m) { py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, py::arg("attention_backend") = "default", - py::arg("kv_cache_dtype") = py::none()) + py::arg("kv_cache_dtype") = py::none(), + py::arg("max_num_batched_tokens") = 2048) .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") @@ -83,8 +87,8 @@ inline void bind_infer_engine(py::module &m) { .def( "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { - auto cfg = self.get_cache_config(); - return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) + auto cfg = self.get_cache_config(); + return cfg ? 
std::shared_ptr(cfg->unique_copy()) : nullptr; }) .def("__repr__", [](const InferEngine &self) { return ""; }); py::class_(infer_engine, "Input") diff --git a/examples/bench.py b/examples/bench.py index 6672d6d7d..29d15211c 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -169,6 +169,7 @@ def __init__( cache_config=None, enable_graph=False, attn_backend="default", + max_num_batched_tokens: int = None, ) -> None: model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # @@ -182,6 +183,7 @@ def __init__( enable_graph_compiling=enable_graph, attention_backend=attn_backend, kv_cache_dtype=cfg.kv_cache_dtype, + max_num_batched_tokens=max_num_batched_tokens, ) # ---------------------------------------------------------------------------- # @@ -281,6 +283,7 @@ def run( enable_paged_attn = cfg.enable_paged_attn enable_graph = cfg.enable_graph attn_backend = cfg.attn + max_num_batched_tokens = cfg.max_num_batched_tokens if isinstance(batch_size, int): batch_size = [batch_size] @@ -322,6 +325,7 @@ def run( cache_config=cache_config, enable_graph=enable_graph, attn_backend=attn_backend, + max_num_batched_tokens=max_num_batched_tokens, ) # ---------------------------------------------------------------------------- # diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index aab5dd459..c0fd306b7 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -44,7 +44,6 @@ class BaseConfig: """InfiniLM Unified Config - Command line argument parser""" def __init__(self): - self.parser = argparse.ArgumentParser(description="InfiniLM Unified Config") self._add_common_args() self.args, self.extra = self.parser.parse_known_args() @@ -70,6 +69,7 @@ def __init__(self): self.batch_size = self.args.batch_size self.max_batch_size = self.args.max_batch_size + self.max_num_batched_tokens = self.args.max_num_batched_tokens self.input_len = self.args.input_len 
self.output_len = self.args.output_len self.max_new_tokens = self.args.max_new_tokens @@ -155,6 +155,12 @@ def _add_common_args(self): default=8, help="maximum batch size for server", ) + self.parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=None, + help="maximum number of batched tokens for paged attention", + ) self.parser.add_argument( "--input-len", type=parse_list, default=10, help="input sequence length" ) diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 844989f43..5de13a580 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -30,6 +30,7 @@ def read_hf_config(model_path): ) return config_dict + # config.json (required) defines model architecture, while generation_config.json # (optional) defines generation behavior. They are kept as separate readers # because: 1) config.json must exist and requires model_type validation, @@ -43,6 +44,7 @@ def read_hf_generation_config(model_path): return json.load(f) return {} + @dataclass class GenerationConfig: max_new_tokens: int | None = None @@ -65,6 +67,7 @@ def __init__( enable_graph_compiling=False, attention_backend="default", kv_cache_dtype=None, + max_num_batched_tokens: int | None = None, ): self.hf_config = read_hf_config(model_path) self.hf_generation_config = read_hf_generation_config(model_path) @@ -72,6 +75,12 @@ def __init__( if device is None: device = infinicore.device() + max_position_embeddings = self.hf_config["max_position_embeddings"] + if max_num_batched_tokens is None: + max_num_batched_tokens = max_position_embeddings + assert 512 <= max_num_batched_tokens <= max_position_embeddings + self.max_num_batched_tokens = max_num_batched_tokens + hf_config_str = json.dumps(self.hf_config) super().__init__( hf_config_str, @@ -85,6 +94,7 @@ def __init__( if kv_cache_dtype is not None else None ), + max_num_batched_tokens, ) self.use_cache = False @@ -375,10 +385,14 @@ def reset_cache(self, cache_config): 
super().reset_cache(cache_config) def state_dict_keyname(self): - return sorted({name for state_dict in super().state_dict() for name in state_dict.keys()}) + return sorted( + {name for state_dict in super().state_dict() for name in state_dict.keys()} + ) def load_state_dict(self, state_dict, strict=None): - super().load_params({name: param._underlying for name, param in state_dict.items()}) + super().load_params( + {name: param._underlying for name, param in state_dict.items()} + ) def process_weights_after_loading(self): super().process_weights_after_loading() diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 953b35e59..3cecec1b9 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -6,6 +6,7 @@ - AsyncLLM class for asynchronous streaming (server use) """ +import os import asyncio import time import uuid @@ -74,6 +75,7 @@ class EngineConfig: enable_graph: bool = False attn_backend: str = "default" skip_load: bool = False + max_num_batched_tokens: int | None = None class LLMEngine: @@ -92,7 +94,9 @@ def __init__(self, config: EngineConfig): distributed_config=DistConfig(config.tensor_parallel_size), enable_graph_compiling=config.enable_graph, attention_backend=config.attn_backend, + max_num_batched_tokens=config.max_num_batched_tokens, ) + self.max_num_batched_tokens = self.model_engine.max_num_batched_tokens # Load model weights if not self.config.skip_load: @@ -117,10 +121,12 @@ def __init__(self, config: EngineConfig): cache_config = PagedKVCacheConfig( num_blocks=config.num_blocks, block_size=config.block_size ) + self.scheduler = Scheduler( max_batch_size=config.max_batch_size, num_blocks=config.num_blocks, block_size=config.block_size, + max_num_batched_tokens=self.max_num_batched_tokens, ) logger.info(f"Using Paged KV Cache with num_blocks={config.num_blocks}") else: diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index c5f4921a9..9ba7edcd6 100644 --- 
a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -39,6 +39,7 @@ def __init__( max_batch_size: int = 16, num_blocks: int = 512, block_size: int = 256, + max_num_batched_tokens: int = 1024, ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() @@ -47,6 +48,8 @@ def __init__( self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) self.block_size = block_size + self.max_num_batched_tokens = max_num_batched_tokens + def add_request(self, request: InferenceRequest): if request is not None: request.status = RequestStatus.WAITING @@ -56,9 +59,13 @@ def schedule(self) -> Optional[SchedulerOutput]: """Schedule and return batch of requests to execute.""" scheduled_requests = [] is_prefill = False + current_num_batched_tokens = 0 # Process Waiting queue (prefill phase) - while len(scheduled_requests) < self.max_batch_size: + while ( + len(scheduled_requests) < self.max_batch_size + and current_num_batched_tokens < self.max_num_batched_tokens + ): try: req = self.waiting_queue.sync_q.get_nowait() except queue.Empty: @@ -95,6 +102,11 @@ def schedule(self) -> Optional[SchedulerOutput]: req.status = RequestStatus.RUNNING scheduled_requests.append(req) + # TODO + # num_tokens_this_step = req.get_prompt_length() - req.num_cached_tokens + # current_num_batched_tokens += num_tokens_this_step + assert False + # Return prefill batch if any waiting requests were scheduled if scheduled_requests: is_prefill = True diff --git a/test/bench/test_benchmark.py b/test/bench/test_benchmark.py index c15c950fe..476e45de5 100644 --- a/test/bench/test_benchmark.py +++ b/test/bench/test_benchmark.py @@ -55,6 +55,7 @@ def __init__( enable_paged_attn=False, enable_graph=False, attn_backend="default", + max_num_batched_tokens: int | None = None, ): import transformers import infinicore @@ -119,6 +120,7 @@ def __init__( ), enable_graph_compiling=enable_graph, attention_backend=attn_backend, + 
max_num_batched_tokens=max_num_batched_tokens, ) # Enable KV cache for generation @@ -1125,6 +1127,7 @@ def main(): cfg.bench, cfg.enable_paged_attn, cfg.enable_graph, + cfg.max_num_batched_tokens, cfg.attn, ) From e1b65cc785766d17ef635a75e8bff94e471a83fe Mon Sep 17 00:00:00 2001 From: MaYuhang <2902139028@qq.com> Date: Thu, 4 Jun 2026 08:04:53 +0000 Subject: [PATCH 2/6] issue/407 - pybind: release GIL in forward() to avoid blocking other Python threads --- csrc/pybind11/engine/engine.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index d3c0476f3..74fb22c95 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -80,6 +80,10 @@ inline void bind_infer_engine(py::module &m) { .def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)") .def( "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { + // IMPORTANT: Release the GIL before calling forward() to allow other Python threads + // to run concurrently during inference (which may block for a long time). + // Do NOT remove this — without it, the GIL is held throughout inference and will + // deadlock or stall any other Python thread (e.g., request handling, scheduling). 
py::gil_scoped_release release; return self.forward(input); }, From ed1c8c033e0ecd98c5466ea10e81391b9998070d Mon Sep 17 00:00:00 2001 From: MaYuhang <2902139028@qq.com> Date: Thu, 4 Jun 2026 08:06:30 +0000 Subject: [PATCH 3/6] issue/407 - fix: early token budget check --- python/infinilm/llm/scheduler.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index 9ba7edcd6..aa61d07a7 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -98,15 +98,27 @@ def schedule(self) -> Optional[SchedulerOutput]: ) ) + num_tokens_this_step = ( + req.get_prompt_length() - req.num_cached_tokens + ) + if ( + current_num_batched_tokens + num_tokens_this_step + >= self.max_num_batched_tokens + ): + if req.num_cached_tokens > 0: + self.cache_manager.free_blocks(req.block_table) + req.block_table = [] + req.slot_mapping = [] + req.num_cached_tokens = 0 + + self.waiting_queue.sync_q.put(req) + break + + current_num_batched_tokens += num_tokens_this_step req.num_blocks = len(req.block_table) req.status = RequestStatus.RUNNING scheduled_requests.append(req) - # TODO - # num_tokens_this_step = req.get_prompt_length() - req.num_cached_tokens - # current_num_batched_tokens += num_tokens_this_step - assert False - # Return prefill batch if any waiting requests were scheduled if scheduled_requests: is_prefill = True From 6bb2040884e09bd36f6cff3179fa287812a32ba5 Mon Sep 17 00:00:00 2001 From: wangpengcheng Date: Tue, 9 Jun 2026 08:06:59 +0000 Subject: [PATCH 4/6] refactor with register_inference_buffer. 
--- csrc/engine/rank_worker.cpp | 8 + csrc/global_state/forward_context.hpp | 8 +- csrc/global_state/infinilm_config.hpp | 9 +- csrc/global_state/workspace_manager.hpp | 161 ++++++++++++++++++ csrc/layers/attention/attention.cpp | 78 ++++++--- csrc/layers/attention/attention.hpp | 13 +- csrc/layers/attention/backends/flash_attn.cpp | 36 ++-- csrc/layers/attention/backends/flash_attn.hpp | 9 +- csrc/layers/attention/backends/paged_attn.cpp | 38 +++-- csrc/layers/attention/backends/paged_attn.hpp | 10 +- .../causal_lm_templates/text_causal_lm.hpp | 55 +++--- .../layers/causal_lm_templates/text_model.hpp | 66 ++++--- csrc/layers/mlp/mlp.cpp | 71 +++++--- csrc/layers/mlp/mlp.hpp | 14 +- csrc/models/qwen3/qwen3_attention.cpp | 76 ++++++--- csrc/models/qwen3/qwen3_attention.hpp | 7 +- python/infinilm/llm/llm.py | 1 - python/infinilm/llm/scheduler.py | 4 +- 18 files changed, 474 insertions(+), 190 deletions(-) create mode 100644 csrc/global_state/workspace_manager.hpp diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 87568fd6a..e24475d72 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -278,6 +278,14 @@ void RankWorker::thread_loop() { if (!model_) { throw std::runtime_error("Failed to create model"); } + + infinicore::context::syncStream(); + + if (infinilm_config_->enable_workspace_manager) { + forward_context_.workspace_manager.finalize_and_bind(rank_info_.device); + } + infinicore::context::syncStream(); + if (enable_graph_compiling_) { compiler_ = std::make_unique(model_, barrier_); } diff --git a/csrc/global_state/forward_context.hpp b/csrc/global_state/forward_context.hpp index 2cd8eba11..b56548002 100644 --- a/csrc/global_state/forward_context.hpp +++ b/csrc/global_state/forward_context.hpp @@ -1,7 +1,9 @@ #pragma once #include "../models/infinilm_model.hpp" -#include +#include "../utils.hpp" +#include "workspace_manager.hpp" +#include namespace infinilm::global_state { @@ -49,9 +51,7 @@ struct ForwardContext 
{ AttentionMetadata attn_metadata; MultiModalMetadata mm_metadata; std::vector kv_cache_vec; - - // preallocated workspace for some modules - std::unordered_map preallocated_workspace; + WorkspaceManager workspace_manager; }; void initialize_forward_context(ForwardContext &forward_context); diff --git a/csrc/global_state/infinilm_config.hpp b/csrc/global_state/infinilm_config.hpp index be8da9f8c..7e39c3f26 100644 --- a/csrc/global_state/infinilm_config.hpp +++ b/csrc/global_state/infinilm_config.hpp @@ -19,14 +19,19 @@ struct InfinilmConfig { : attention_backend(backend), model_config(model_config), max_num_batched_tokens(max_num_batched_tokens) { - const size_t max_position_embeddings = model_config->get("max_position_embeddings"); - ASSERT(max_num_batched_tokens >= 512 && max_num_batched_tokens <= max_position_embeddings); + + if (max_num_batched_tokens > 0) { + const size_t max_position_embeddings = model_config->get("max_position_embeddings"); + ASSERT(max_num_batched_tokens >= 512 && max_num_batched_tokens <= max_position_embeddings); + enable_workspace_manager = true; + } } public: infinilm::backends::AttentionBackend attention_backend; std::shared_ptr model_config; size_t max_num_batched_tokens = 0; + bool enable_workspace_manager{false}; }; /** diff --git a/csrc/global_state/workspace_manager.hpp b/csrc/global_state/workspace_manager.hpp new file mode 100644 index 000000000..628ba5cf7 --- /dev/null +++ b/csrc/global_state/workspace_manager.hpp @@ -0,0 +1,161 @@ +#pragma once + +#include "../models/infinilm_model.hpp" +#include "../utils.hpp" +#include +#include +#include +#include +#include +#include + +namespace infinilm::global_state { + +// /** +// * @brief Unified GPU inference workspace manager. +// * +// * Phase 1: modules register buffer layouts via ``register_buffer``. +// * Phase 2/3: ``finalize_and_bind`` allocates ``scratch_buffer_`` and binds views. 
+// */ +// class WorkspaceManager { +// public: +// using BindFn = std::function; + +// WorkspaceManager() = default; +// ~WorkspaceManager() = default; + +// /** +// * @brief Register a buffer appended at the current scratch_buffer tail. +// * +// * @param name Unique cache key; duplicate keys share one slot. +// * @param shape Tensor shape for the bound view. +// * @param dtype Element type of the bound view. +// * @param device Device on which scratch_buffer is allocated. +// * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view. +// */ +// void register_buffer(const std::string &name, +// const infinicore::Shape &shape, +// const infinicore::DataType &dtype, +// const infinicore::Device &device, +// BindFn bind_fn) { +// register_buffer_impl(name, total_bytes_, shape, dtype, device, std::move(bind_fn), true); +// } + +// /** +// * @brief Register a buffer pinned at a fixed byte offset. +// * +// * @param name Unique cache key; duplicate keys share one slot. +// * @param offset Byte offset in scratch_buffer (currently only 0 is supported). +// * @param shape Tensor shape for the bound view. +// * @param dtype Element type of the bound view. +// * @param device Device on which scratch_buffer is allocated. +// * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view. +// */ +// void register_buffer(const std::string &name, +// size_t offset, +// const infinicore::Shape &shape, +// const infinicore::DataType &dtype, +// const infinicore::Device &device, +// BindFn bind_fn) { +// ASSERT(0 == offset); +// register_buffer_impl(name, offset, shape, dtype, device, std::move(bind_fn), false); +// } + +// /** +// * @brief Allocate scratch_buffer and run all registered bind callbacks. +// * +// * @param device Device on which scratch_buffer is allocated. 
+// */ +// void finalize_and_bind(const infinicore::Device &device) { +// ASSERT(!finalized_); +// if (total_bytes_ == 0) { +// finalized_ = true; +// return; +// } + +// ASSERT(device.getType() != infinicore::Device::Type::CPU); + +// scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, device); + +// spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0); + +// for (auto &[name, reg] : registrations_) { +// auto *base_ptr = scratch_buffer_->data() + reg.offset; +// auto view = infinicore::Tensor::from_blob(static_cast(base_ptr), reg.shape, reg.dtype, device); +// inference_buffers_[name] = view; +// for (auto &bind_fn : reg.bind_callbacks) { +// bind_fn(view); +// } +// } + +// finalized_ = true; +// } + +// private: +// /** @brief Metadata for one registered region in scratch_buffer. */ +// struct BufferRegistration { +// size_t offset{0}; +// size_t aligned_bytes{0}; +// infinicore::Shape shape; +// infinicore::DataType dtype; +// infinicore::Device device; +// std::vector bind_callbacks; +// }; + +// void register_buffer_impl(const std::string &name, +// size_t offset, +// const infinicore::Shape &shape, +// const infinicore::DataType &dtype, +// const infinicore::Device &device, +// BindFn bind_fn, +// bool bump_tail) { +// ASSERT(!finalized_); +// ASSERT(device.getType() != infinicore::Device::Type::CPU); + +// auto compute_numel = [](const infinicore::Shape &shape) { +// size_t numel = 1; +// for (const auto dim : shape) { +// numel *= dim; +// } +// return numel; +// }; + +// auto align_up = [](size_t n, size_t alignment = 512) { +// return (n + alignment - 1) & ~(alignment - 1); +// }; + +// const size_t actual_bytes = compute_numel(shape) * infinicore::dsize(dtype); +// const size_t aligned_bytes = align_up(actual_bytes); + +// if (registrations_.find(name) == registrations_.end()) { +// BufferRegistration reg; +// reg.offset = offset; +// reg.aligned_bytes = aligned_bytes; +// 
reg.shape = shape; +// reg.dtype = dtype; +// reg.device = device; + +// if (bump_tail) { +// total_bytes_ += aligned_bytes; +// } else { +// total_bytes_ = std::max(total_bytes_, offset + aligned_bytes); +// } +// registrations_.emplace(name, std::move(reg)); +// } + +// auto &reg = registrations_.at(name); +// ASSERT(reg.aligned_bytes == aligned_bytes); +// ASSERT(reg.shape == shape); +// ASSERT(reg.dtype == dtype); +// ASSERT(reg.device == device); +// reg.bind_callbacks.push_back(std::move(bind_fn)); +// } + +// size_t total_bytes_{0}; +// bool finalized_{false}; +// infinicore::Tensor scratch_buffer_; +// std::unordered_map registrations_; +// std::unordered_map inference_buffers_; +// }; + +}; // namespace infinilm::global_state \ No newline at end of file diff --git a/csrc/layers/attention/attention.cpp b/csrc/layers/attention/attention.cpp index 2b4abd7b5..747e980d3 100644 --- a/csrc/layers/attention/attention.cpp +++ b/csrc/layers/attention/attention.cpp @@ -3,6 +3,7 @@ #include "../../utils.hpp" #include "../rotary_embedding/rotary_embedding.hpp" #include +#include namespace infinilm::layers::attention { @@ -48,7 +49,10 @@ Attention::Attention(std::shared_ptr model_config init_kv_cache_quant_params(register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_); rank_qkv_output_size_ = qkv_proj_->out_features() / static_cast(tp_size); - this->_initialize_preallocated_workspace(); + enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor Attention::forward(const infinicore::Tensor &positions, @@ -68,8 +72,13 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position size_t seq_len = shape[1]; // 1. 
Project Q, K, V - auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_}); - auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + infinicore::Tensor q, k, v; + if (enable_workspace_manager_) { + auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_}); + std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + } else { + std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); + } // 2. Reshape for multi-head attention auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_}); @@ -96,10 +105,13 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position // 5. Attn Backend calculate auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped); - // 7. Project output - auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_}); - o_proj_->forward_(o_output, attn_output); - return o_output; + // 6. Project output + if (enable_workspace_manager_) { + auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; + } + return o_proj_->forward(attn_output); } infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ids, @@ -114,8 +126,13 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ ASSERT_EQ(batch_size, 1); // 1. 
Project Q, K, V - auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_}); - auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + infinicore::Tensor q, k, v; + if (enable_workspace_manager_) { + auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_}); + std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + } else { + std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); + } // 2. Reshape for multi-head attention auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); @@ -142,16 +159,21 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); // 6. Project output - auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_}); - o_proj_->forward_(o_output, attn_output); - return o_output; + if (enable_workspace_manager_) { + auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; + } + return o_proj_->forward(attn_output); } -void Attention::_initialize_preallocated_workspace() { +void Attention::_register_inference_buffer() { const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); - auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + ASSERT(rank_qkv_output_size_ > 0 && hidden_size_ > 0); + const std::string attention_cache_key = std::string("Attention_max_num_batched_tokens_") + std::to_string(max_num_batched_tokens) + "_rank_qkv_output_size_" + std::to_string(rank_qkv_output_size_) + "_hidden_size_" @@ -159,18 +181,22 @@ 
void Attention::_initialize_preallocated_workspace() { + infinicore::toString(dtype_) + "_device_" + device_.toString(); - size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_); - if (preallocated_workspace.find(attention_cache_key) == preallocated_workspace.end()) { - auto attention_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_); - preallocated_workspace[attention_cache_key] = attention_buffer; - } - - auto attention_buffer = preallocated_workspace.at(attention_cache_key); - const auto attention_buffer_shape = attention_buffer->shape(); - ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size); - - max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}})->view({max_num_batched_tokens, rank_qkv_output_size_}); - max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_}); + const size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_); + const infinicore::Shape attention_buffer_shape = {max_num_batched_tokens * max_output_size}; + workspace_manager.register_buffer( + attention_cache_key, + attention_buffer_shape, + dtype_, + device_, + [this, max_num_batched_tokens, max_output_size](const infinicore::Tensor &attention_buffer) { + const auto attention_buffer_shape = attention_buffer->shape(); + ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size); + + max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}}) + ->view({max_num_batched_tokens, rank_qkv_output_size_}); + max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}}) + ->view({max_num_batched_tokens, hidden_size_}); + }); } void init_kv_cache_quant_params(std::function register_fn, diff --git a/csrc/layers/attention/attention.hpp b/csrc/layers/attention/attention.hpp index 
b08016ea3..00cab342b 100644 --- a/csrc/layers/attention/attention.hpp +++ b/csrc/layers/attention/attention.hpp @@ -5,8 +5,6 @@ #include "../../global_state/global_state.hpp" #include "../linear/linear.hpp" #include "backends/attention_layer.hpp" -#include "infinicore/device.hpp" -#include "infinicore/dtype.hpp" #include "infinicore/nn/module.hpp" #include "infinicore/nn/rope.hpp" #include "infinicore/tensor.hpp" @@ -39,7 +37,7 @@ class Attention : public infinicore::nn::Module { infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, const infinicore::Tensor &hidden_states) const; - void _initialize_preallocated_workspace(); + void _register_inference_buffer(); protected: std::shared_ptr qkv_proj_; @@ -61,11 +59,10 @@ class Attention : public infinicore::nn::Module { INFINICORE_NN_PARAMETER(kv_cache_v_scale); private: - size_t rank_qkv_output_size_; - - // preallocated workspace for Attention - infinicore::Tensor max_qkv_output_; - infinicore::Tensor max_o_output_; + bool enable_workspace_manager_{false}; + size_t rank_qkv_output_size_{0}; + infinicore::Tensor max_qkv_output_; // inference buffer for Attention + infinicore::Tensor max_o_output_; // inference buffer for Attention }; void init_kv_cache_quant_params(std::function register_fn, const infinicore::Device &device, diff --git a/csrc/layers/attention/backends/flash_attn.cpp b/csrc/layers/attention/backends/flash_attn.cpp index c48e5187c..12a9e4116 100644 --- a/csrc/layers/attention/backends/flash_attn.cpp +++ b/csrc/layers/attention/backends/flash_attn.cpp @@ -32,7 +32,10 @@ FlashAttentionImpl::FlashAttentionImpl(size_t num_heads, dtype_ = model_config->get_dtype(); max_position_embeddings_ = model_config->get("max_position_embeddings"); - this->_initialize_preallocated_workspace(); + enable_workspace_manager_ = infinilm_config.enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor FlashAttentionImpl::forward(const 
AttentionLayer &layer, @@ -59,7 +62,11 @@ infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer, // 2. Compute attention infinicore::Tensor attn_output; if (is_prefill) { - attn_output = max_attn_output_->narrow({{0, 0, seq_len}}); + if (enable_workspace_manager_) { + attn_output = max_attn_output_->narrow({{0, 0, seq_len}}); + } else { + attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, dtype_, device_); + } infinicore::op::mha_varlen_( attn_output, query, @@ -109,9 +116,9 @@ std::tuple FlashAttentionImpl::do_kv_cac return {k_cache_layer, v_cache_layer}; } -void FlashAttentionImpl::_initialize_preallocated_workspace() { +void FlashAttentionImpl::_register_inference_buffer() { const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); - auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; const std::string cache_key = std::string("FlashAttentionImpl_max_num_batched_tokens_") @@ -121,15 +128,16 @@ void FlashAttentionImpl::_initialize_preallocated_workspace() { + infinicore::toString(dtype_) + "_device_" + device_.toString(); - if (preallocated_workspace.find(cache_key) == preallocated_workspace.end()) { - auto flash_attention_impl_buffer = infinicore::Tensor::empty({max_num_batched_tokens, num_heads_, head_dim_}, dtype_, device_); - preallocated_workspace[cache_key] = flash_attention_impl_buffer; - } - - auto flash_attention_impl_buffer = preallocated_workspace.at(cache_key); - const auto buffer_shape = flash_attention_impl_buffer->shape(); - ASSERT(buffer_shape[0] == max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_); - - max_attn_output_ = flash_attention_impl_buffer; + const infinicore::Shape flash_attn_buffer_shape = {max_num_batched_tokens, 
num_heads_, head_dim_}; + workspace_manager.register_buffer( + cache_key, + flash_attn_buffer_shape, + dtype_, + device_, + [this, max_num_batched_tokens](const infinicore::Tensor &flash_attention_impl_buffer) { + const auto buffer_shape = flash_attention_impl_buffer->shape(); + ASSERT(buffer_shape[0] == max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_); + max_attn_output_ = flash_attention_impl_buffer; + }); } } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/flash_attn.hpp b/csrc/layers/attention/backends/flash_attn.hpp index 7a480168b..334e39a47 100644 --- a/csrc/layers/attention/backends/flash_attn.hpp +++ b/csrc/layers/attention/backends/flash_attn.hpp @@ -1,8 +1,6 @@ #pragma once #include "../../../global_state/global_state.hpp" -#include "infinicore/device.hpp" -#include "infinicore/dtype.hpp" #include "infinicore/tensor.hpp" #include @@ -46,7 +44,9 @@ class FlashAttentionImpl { const infinicore::Tensor slot_mapping) const; private: - void _initialize_preallocated_workspace(); + void _register_inference_buffer(); + bool enable_workspace_manager_{false}; + infinicore::Tensor max_attn_output_; // inference buffer for FlashAttentionImpl size_t num_heads_; size_t head_size_; @@ -57,9 +57,6 @@ class FlashAttentionImpl { size_t max_position_embeddings_; infinicore::Device device_; infinicore::DataType dtype_; - - // preallocated workspace for FlashAttentionImpl - infinicore::Tensor max_attn_output_; }; } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.cpp b/csrc/layers/attention/backends/paged_attn.cpp index 0836accc9..b97788881 100644 --- a/csrc/layers/attention/backends/paged_attn.cpp +++ b/csrc/layers/attention/backends/paged_attn.cpp @@ -2,7 +2,6 @@ #include "../../../global_state/global_state.hpp" #include "../../../utils.hpp" -#include "attention_layer.hpp" #include "infinicore/ops.hpp" #include @@ -30,7 +29,10 @@ 
PagedAttentionImpl::PagedAttentionImpl(size_t num_heads, dtype_ = infinilm_config.model_config->get_dtype(); - this->_initialize_preallocated_workspace(); + enable_workspace_manager_ = infinilm_config.enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor PagedAttentionImpl::forward(const AttentionLayer &layer, @@ -53,7 +55,12 @@ infinicore::Tensor PagedAttentionImpl::forward(const AttentionLayer &layer, bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); // 2. Compute attention - infinicore::Tensor attn_output = max_attn_output_->narrow({{0, 0, seq_len}}); + infinicore::Tensor attn_output; + if (enable_workspace_manager_) { + attn_output = max_attn_output_->narrow({{0, 0, seq_len}}); + } else { + attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, dtype_, device_); + } if (is_prefill) { infinicore::op::paged_attention_prefill_( attn_output, @@ -97,9 +104,9 @@ std::tuple PagedAttentionImpl::do_kv_cac return {k_cache_layer, v_cache_layer}; } -void PagedAttentionImpl::_initialize_preallocated_workspace() { +void PagedAttentionImpl::_register_inference_buffer() { const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); - auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; const std::string cache_key = std::string("PagedAttentionImpl_max_num_batched_tokens_") @@ -109,16 +116,17 @@ void PagedAttentionImpl::_initialize_preallocated_workspace() { + infinicore::toString(dtype_) + "_device_" + device_.toString(); - if (preallocated_workspace.find(cache_key) == preallocated_workspace.end()) { - auto paged_attention_impl_buffer = infinicore::Tensor::empty({max_num_batched_tokens, num_heads_, head_dim_}, dtype_, device_); - 
preallocated_workspace[cache_key] = paged_attention_impl_buffer; - } - - auto paged_attention_impl_buffer = preallocated_workspace.at(cache_key); - const auto buffer_shape = paged_attention_impl_buffer->shape(); - ASSERT(buffer_shape[0] == max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_); - - max_attn_output_ = paged_attention_impl_buffer; + const infinicore::Shape paged_attn_buffer_shape = {max_num_batched_tokens, num_heads_, head_dim_}; + workspace_manager.register_buffer( + cache_key, + paged_attn_buffer_shape, + dtype_, + device_, + [this, max_num_batched_tokens](const infinicore::Tensor &paged_attention_impl_buffer) { + const auto buffer_shape = paged_attention_impl_buffer->shape(); + ASSERT(buffer_shape[0] == max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_); + max_attn_output_ = paged_attention_impl_buffer; + }); } } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.hpp b/csrc/layers/attention/backends/paged_attn.hpp index d4408fe3e..f5c625b45 100644 --- a/csrc/layers/attention/backends/paged_attn.hpp +++ b/csrc/layers/attention/backends/paged_attn.hpp @@ -1,8 +1,6 @@ #pragma once #include "../../../global_state/global_state.hpp" -#include "infinicore/device.hpp" -#include "infinicore/dtype.hpp" #include "infinicore/tensor.hpp" #include @@ -46,7 +44,9 @@ class PagedAttentionImpl { const infinicore::Tensor slot_mapping) const; private: - void _initialize_preallocated_workspace(); + void _register_inference_buffer(); + bool enable_workspace_manager_{false}; + infinicore::Tensor max_attn_output_; // inference buffer for PagedAttentionImpl size_t num_heads_; size_t head_size_; @@ -56,9 +56,5 @@ class PagedAttentionImpl { size_t head_dim_; // Note: head_dim equals to head_size infinicore::Device device_; infinicore::DataType dtype_; - - // preallocated workspace for PagedAttentionImpl - infinicore::Tensor max_attn_output_; }; - } 
// namespace infinilm::layers::attention::backends diff --git a/csrc/layers/causal_lm_templates/text_causal_lm.hpp b/csrc/layers/causal_lm_templates/text_causal_lm.hpp index 0bdb296c4..0f4a578b4 100644 --- a/csrc/layers/causal_lm_templates/text_causal_lm.hpp +++ b/csrc/layers/causal_lm_templates/text_causal_lm.hpp @@ -43,7 +43,10 @@ class TextCausalLM : public InfinilmModel { model_ = this->register_module("model", model_config, device); lm_head_ = this->register_module("lm_head", hidden_size, vocab_size_, false, dtype_, device_); - this->_initialize_preallocated_workspace(); + enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } /** @@ -51,13 +54,18 @@ class TextCausalLM : public InfinilmModel { */ Output forward(const Input &input) const override { auto hidden_states = model_->forward(input); + infinicore::Tensor logits; + + if (enable_workspace_manager_) { + const auto shape = hidden_states->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + logits = max_logits_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, vocab_size_}); + lm_head_->forward_(logits, hidden_states); + } else { + logits = lm_head_->forward(hidden_states); + } - const auto shape = hidden_states->shape(); - const size_t bs = shape[0]; - const size_t seq_len = shape[1]; - - auto logits = max_logits_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, vocab_size_}); - lm_head_->forward_(logits, hidden_states); return {logits}; } @@ -68,7 +76,7 @@ class TextCausalLM : public InfinilmModel { Model &model() { return *model_; } protected: - size_t vocab_size_; + size_t vocab_size_{0}; infinicore::Device device_; infinicore::DataType dtype_; @@ -76,10 +84,11 @@ class TextCausalLM : public InfinilmModel { INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); private: - void _initialize_preallocated_workspace() { + void 
_register_inference_buffer() { const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); - auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + ASSERT(vocab_size_ > 0); const std::string text_causal_lm_cache_key = std::string("TextCausalLM_max_num_batched_tokens_") + std::to_string(max_num_batched_tokens) + "_vocab_size_" @@ -87,20 +96,22 @@ class TextCausalLM : public InfinilmModel { + infinicore::toString(dtype_) + "_device_" + device_.toString(); - if (preallocated_workspace.find(text_causal_lm_cache_key) == preallocated_workspace.end()) { - auto logits_buffer = infinicore::Tensor::empty({max_num_batched_tokens, vocab_size_}, dtype_, device_); - preallocated_workspace[text_causal_lm_cache_key] = logits_buffer; - } - - auto logits_buffer = preallocated_workspace.at(text_causal_lm_cache_key); - const auto logits_buffer_shape = logits_buffer->shape(); - ASSERT(logits_buffer_shape[0] == max_num_batched_tokens && logits_buffer_shape[1] == vocab_size_); - - max_logits_ = logits_buffer; + const infinicore::Shape logits_shape = {max_num_batched_tokens, vocab_size_}; + workspace_manager.register_buffer( + text_causal_lm_cache_key, + 0, + logits_shape, + dtype_, + device_, + [this, max_num_batched_tokens](const infinicore::Tensor &logits_buffer) { + const auto logits_buffer_shape = logits_buffer->shape(); + ASSERT(logits_buffer_shape[0] == max_num_batched_tokens && logits_buffer_shape[1] == vocab_size_); + max_logits_ = logits_buffer; + }); } - // preallocated workspace for TextCausalLM - infinicore::Tensor max_logits_; + bool enable_workspace_manager_{false}; + infinicore::Tensor max_logits_; // inference buffer for TextCausalLM }; } // namespace infinilm::layers::causal_lm_templates diff --git 
a/csrc/layers/causal_lm_templates/text_model.hpp b/csrc/layers/causal_lm_templates/text_model.hpp index 23a9f13a6..d615920df 100644 --- a/csrc/layers/causal_lm_templates/text_model.hpp +++ b/csrc/layers/causal_lm_templates/text_model.hpp @@ -4,8 +4,6 @@ #include "../../global_state/global_state.hpp" #include "../../models/infinilm_model.hpp" #include "../../utils.hpp" -#include "infinicore/device.hpp" -#include "infinicore/dtype.hpp" #include "infinicore/nn/embedding.hpp" #include "infinicore/nn/rmsnorm.hpp" #include "infinicore/tensor.hpp" @@ -48,18 +46,26 @@ class TextModel : public infinicore::nn::Module { norm_ = this->register_module("norm", hidden_size_, rms_norm_eps, dtype_, device_); - this->_initialize_preallocated_workspace(); + enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input) const { auto input_ids = input.input_ids.value(); auto positions = input.position_ids.value(); // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size] - const auto shape = input_ids->shape(); - const size_t bs = shape[0]; - const size_t seq_len = shape[1]; - auto hidden_states = max_hidden_states_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, hidden_size_}); - embed_tokens_->forward_(hidden_states, input_ids); + infinicore::Tensor hidden_states; + if (enable_workspace_manager_) { + const auto shape = input_ids->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + hidden_states = max_hidden_states_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, hidden_size_}); + embed_tokens_->forward_(hidden_states, input_ids); + } else { + hidden_states = embed_tokens_->forward(input_ids); + } // 2. 
Process through all decoder layers size_t num_layers = layers_.size(); @@ -113,28 +119,34 @@ class TextModel : public infinicore::nn::Module { } private: - void _initialize_preallocated_workspace() { + void _register_inference_buffer() { const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); - auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; - - const std::string text_model_cache_key = std::string("TextModel_max_num_batched_tokens_") + std::to_string(max_num_batched_tokens) + "_hidden_size_" + std::to_string(hidden_size_) + "_dtype_" + infinicore::toString(dtype_) + "_device_" + device_.toString(); - - if (preallocated_workspace.find(text_model_cache_key) == preallocated_workspace.end()) { - auto text_model_buffer = infinicore::Tensor::empty({max_num_batched_tokens, hidden_size_}, dtype_, device_); - preallocated_workspace[text_model_cache_key] = text_model_buffer; - } - - auto text_model_buffer = preallocated_workspace.at(text_model_cache_key); - const auto text_model_buffer_shape = text_model_buffer->shape(); - ASSERT(text_model_buffer_shape[0] == max_num_batched_tokens && text_model_buffer_shape[1] == hidden_size_); - - max_hidden_states_ = text_model_buffer; + ASSERT((hidden_size_ > 0) && (vocab_size_ > 0)); + + const std::string text_model_cache_key = std::string("TextModel_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + const infinicore::Shape hidden_states_shape = {max_num_batched_tokens, hidden_size_}; + workspace_manager.register_buffer( + text_model_cache_key, + hidden_states_shape, + dtype_, + device_, + [this, max_num_batched_tokens](const infinicore::Tensor 
&hidden_states_buffer) { + const auto hidden_states_buffer_shape = hidden_states_buffer->shape(); + ASSERT(hidden_states_buffer_shape[0] == max_num_batched_tokens && hidden_states_buffer_shape[1] == hidden_size_); + max_hidden_states_ = hidden_states_buffer; + }); } protected: - size_t vocab_size_; - size_t hidden_size_; + size_t vocab_size_{0}; + size_t hidden_size_{0}; infinicore::Device device_; infinicore::DataType dtype_; @@ -143,8 +155,8 @@ class TextModel : public infinicore::nn::Module { INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm); private: - // preallocated workspace for TextModel - infinicore::Tensor max_hidden_states_; + bool enable_workspace_manager_{false}; + infinicore::Tensor max_hidden_states_; // inference buffer for TextModel }; } // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/mlp/mlp.cpp b/csrc/layers/mlp/mlp.cpp index 00e756d3d..13d0ac581 100644 --- a/csrc/layers/mlp/mlp.cpp +++ b/csrc/layers/mlp/mlp.cpp @@ -30,10 +30,29 @@ MLP::MLP(std::shared_ptr model_config, rank_gate_up_output_size_ = gate_up_proj_->out_features() / static_cast(tp_size); rank_intermediate_size_ = rank_gate_up_output_size_ / 2; - this->_initialize_preallocated_workspace(); + + enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor MLP::forward(const infinicore::Tensor &hidden_states) const { + if (enable_workspace_manager_) { + return this->_forward_with_inference_buffer(hidden_states); + } + + // 1. Project to gate and up + auto hidden_states_mutable = hidden_states; + auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable); + // 2. Apply SwiGLU: silu(gate) * up + auto intermediate = infinicore::op::swiglu(up, gate); + // 3. 
Project down + auto output = down_proj_->forward(intermediate); + return output; +} + +infinicore::Tensor MLP::_forward_with_inference_buffer(const infinicore::Tensor &hidden_states) const { const auto shape = hidden_states->shape(); const size_t bs = shape[0]; const size_t seq_len = shape[1]; @@ -53,12 +72,13 @@ infinicore::Tensor MLP::forward(const infinicore::Tensor &hidden_states) const { return down_output; } -void MLP::_initialize_preallocated_workspace() { - +void MLP::_register_inference_buffer() { const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); - auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + ASSERT(rank_gate_up_output_size_ > 0 && rank_intermediate_size_ > 0 && hidden_size_ > 0 && intermediate_size_ > 0); + const std::string mlp_cache_key = std::string("MLP_max_num_batched_tokens_") + std::to_string(max_num_batched_tokens) + "_rank_gate_up_output_size_" + std::to_string(rank_gate_up_output_size_) + "_rank_intermediate_size_" @@ -67,21 +87,34 @@ void MLP::_initialize_preallocated_workspace() { + infinicore::toString(dtype_) + "_device_" + device_.toString(); - size_t max_gate_up_intermediate_size = std::max(rank_gate_up_output_size_, rank_intermediate_size_); - size_t max_output_size = max_gate_up_intermediate_size + hidden_size_; - - if (preallocated_workspace.find(mlp_cache_key) == preallocated_workspace.end()) { - auto mlp_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_); - preallocated_workspace[mlp_cache_key] = mlp_buffer; - } - - auto mlp_buffer = preallocated_workspace.at(mlp_cache_key); - const auto buffer_shape = mlp_buffer->shape(); - ASSERT(buffer_shape[0] == max_num_batched_tokens * max_output_size); - - max_gate_up_output_ = 
mlp_buffer->narrow({{0, 0, max_num_batched_tokens * rank_gate_up_output_size_}})->view({max_num_batched_tokens, rank_gate_up_output_size_}); - max_intermediate_ = mlp_buffer->narrow({{0, 0, max_num_batched_tokens * rank_intermediate_size_}})->view({max_num_batched_tokens, rank_intermediate_size_}); - max_down_output_ = mlp_buffer->narrow({{0, max_num_batched_tokens * max_gate_up_intermediate_size, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_}); + auto align_up = [](size_t n, size_t alignment = 256) { + return (n + alignment - 1) & ~(alignment - 1); + }; + + const size_t rank_gate_up_output_size_aligned = align_up(rank_gate_up_output_size_); + const size_t rank_intermediate_size_aligned = align_up(rank_gate_up_output_size_aligned + rank_intermediate_size_); + const size_t max_output_size = rank_intermediate_size_aligned + hidden_size_; + + const infinicore::Shape mlp_buffer_shape = {max_num_batched_tokens * max_output_size}; + workspace_manager.register_buffer( + mlp_cache_key, + mlp_buffer_shape, + dtype_, + device_, + [this, max_num_batched_tokens, rank_gate_up_output_size_aligned, rank_intermediate_size_aligned, + max_output_size](const infinicore::Tensor &mlp_buffer) { + const auto buffer_shape = mlp_buffer->shape(); + ASSERT(buffer_shape[0] == max_num_batched_tokens * max_output_size); + + max_gate_up_output_ = mlp_buffer->narrow({{0, 0, max_num_batched_tokens * rank_gate_up_output_size_}}) + ->view({max_num_batched_tokens, rank_gate_up_output_size_}); + max_intermediate_ = mlp_buffer->narrow({{0, max_num_batched_tokens * rank_gate_up_output_size_aligned, + max_num_batched_tokens * rank_intermediate_size_}}) + ->view({max_num_batched_tokens, rank_intermediate_size_}); + max_down_output_ = mlp_buffer->narrow({{0, max_num_batched_tokens * rank_intermediate_size_aligned, + max_num_batched_tokens * hidden_size_}}) + ->view({max_num_batched_tokens, hidden_size_}); + }); } } // namespace infinilm::layers::mlp diff --git 
a/csrc/layers/mlp/mlp.hpp b/csrc/layers/mlp/mlp.hpp index 73e5da02e..3887a2f23 100644 --- a/csrc/layers/mlp/mlp.hpp +++ b/csrc/layers/mlp/mlp.hpp @@ -2,8 +2,6 @@ #include "../../config/model_config.hpp" #include "../linear/linear.hpp" -#include "infinicore/device.hpp" -#include "infinicore/dtype.hpp" #include "infinicore/nn/module.hpp" #include "infinicore/tensor.hpp" @@ -58,13 +56,15 @@ class MLP : public infinicore::nn::Module { infinicore::DataType dtype_; private: - void _initialize_preallocated_workspace(); + infinicore::Tensor _forward_with_inference_buffer(const infinicore::Tensor &hidden_states) const; - size_t rank_gate_up_output_size_; - size_t rank_intermediate_size_; + void _register_inference_buffer(); - // preallocated workspace for MLP - infinicore::Tensor max_gate_up_output_; + bool enable_workspace_manager_{false}; + size_t rank_gate_up_output_size_{0}; + size_t rank_intermediate_size_{0}; + + infinicore::Tensor max_gate_up_output_; // inference buffer for MLP infinicore::Tensor max_intermediate_; infinicore::Tensor max_down_output_; }; diff --git a/csrc/models/qwen3/qwen3_attention.cpp b/csrc/models/qwen3/qwen3_attention.cpp index 676d2567d..9f806d967 100644 --- a/csrc/models/qwen3/qwen3_attention.cpp +++ b/csrc/models/qwen3/qwen3_attention.cpp @@ -3,6 +3,7 @@ #include "../../layers/attention/attention.hpp" #include "../../utils.hpp" #include +#include namespace infinilm::models::qwen3 { @@ -54,7 +55,10 @@ Qwen3Attention::Qwen3Attention(std::shared_ptr mo infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_); rank_qkv_output_size_ = qkv_proj_->out_features() / static_cast(tp_size); - this->_initialize_preallocated_workspace(); + enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor Qwen3Attention::forward(const infinicore::Tensor &positions, @@ 
-75,8 +79,13 @@ infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &pos size_t seq_len = shape[1]; // 1. Project Q, K, V - auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_}); - auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + infinicore::Tensor q, k, v; + if (enable_workspace_manager_) { + auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_}); + std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + } else { + std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); + } q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_})); k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_})); @@ -107,9 +116,12 @@ infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &pos auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped); // 6. Project output - auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_}); - o_proj_->forward_(o_output, attn_output); - return o_output; + if (enable_workspace_manager_) { + auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; + } + return o_proj_->forward(attn_output); } infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &position_ids, @@ -125,8 +137,13 @@ infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &posi ASSERT_EQ(batch_size, 1); // 1. 
Project Q, K, V - auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_}); - auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + infinicore::Tensor q, k, v; + if (enable_workspace_manager_) { + auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_}); + std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + } else { + std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); + } // 2. Reshape for multi-head attention auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); @@ -155,16 +172,21 @@ infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &posi auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); // 6. Project output - auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_}); - o_proj_->forward_(o_output, attn_output); - return o_output; + if (enable_workspace_manager_) { + auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_}); + o_proj_->forward_(o_output, attn_output); + return o_output; + } + return o_proj_->forward(attn_output); } -void Qwen3Attention::_initialize_preallocated_workspace() { +void Qwen3Attention::_register_inference_buffer() { const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); - auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace; + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + ASSERT(rank_qkv_output_size_ > 0 && hidden_size_ > 0); + const std::string attention_cache_key = std::string("Qwen3Attention_max_num_batched_tokens_") + std::to_string(max_num_batched_tokens) + "_rank_qkv_output_size_" + std::to_string(rank_qkv_output_size_) + "_hidden_size_" @@ 
-172,18 +194,22 @@ void Qwen3Attention::_initialize_preallocated_workspace() { + infinicore::toString(dtype_) + "_device_" + device_.toString(); - size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_); - if (preallocated_workspace.find(attention_cache_key) == preallocated_workspace.end()) { - auto attention_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_); - preallocated_workspace[attention_cache_key] = attention_buffer; - } - - auto attention_buffer = preallocated_workspace.at(attention_cache_key); - const auto attention_buffer_shape = attention_buffer->shape(); - ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size); - - max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}})->view({max_num_batched_tokens, rank_qkv_output_size_}); - max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_}); + const size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_); + const infinicore::Shape attention_buffer_shape = {max_num_batched_tokens * max_output_size}; + workspace_manager.register_buffer( + attention_cache_key, + attention_buffer_shape, + dtype_, + device_, + [this, max_num_batched_tokens, max_output_size](const infinicore::Tensor &attention_buffer) { + const auto attention_buffer_shape = attention_buffer->shape(); + ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size); + + max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}}) + ->view({max_num_batched_tokens, rank_qkv_output_size_}); + max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}}) + ->view({max_num_batched_tokens, hidden_size_}); + }); } } // namespace infinilm::models::qwen3 diff --git a/csrc/models/qwen3/qwen3_attention.hpp b/csrc/models/qwen3/qwen3_attention.hpp index 
0133a9339..f26f4c34e 100644 --- a/csrc/models/qwen3/qwen3_attention.hpp +++ b/csrc/models/qwen3/qwen3_attention.hpp @@ -25,7 +25,7 @@ class Qwen3Attention : public infinicore::nn::Module { infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, const infinicore::Tensor &hidden_states) const; - void _initialize_preallocated_workspace(); + void _register_inference_buffer(); protected: std::shared_ptr qkv_proj_; @@ -49,9 +49,8 @@ class Qwen3Attention : public infinicore::nn::Module { INFINICORE_NN_PARAMETER(kv_cache_v_scale); size_t rank_qkv_output_size_; - - // preallocated workspace for Attention - infinicore::Tensor max_qkv_output_; + bool enable_workspace_manager_{false}; + infinicore::Tensor max_qkv_output_; // inference buffer for Attention infinicore::Tensor max_o_output_; }; } // namespace infinilm::models::qwen3 diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 3cecec1b9..dbe9baa6c 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -121,7 +121,6 @@ def __init__(self, config: EngineConfig): cache_config = PagedKVCacheConfig( num_blocks=config.num_blocks, block_size=config.block_size ) - self.scheduler = Scheduler( max_batch_size=config.max_batch_size, num_blocks=config.num_blocks, diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index aa61d07a7..b3852073c 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -98,9 +98,7 @@ def schedule(self) -> Optional[SchedulerOutput]: ) ) - num_tokens_this_step = ( - req.get_prompt_length() - req.num_cached_tokens - ) + num_tokens_this_step = req.get_prompt_length() - req.num_cached_tokens if ( current_num_batched_tokens + num_tokens_this_step >= self.max_num_batched_tokens From 2aea7dfaad6427cb0040a30cb472100752dd0873 Mon Sep 17 00:00:00 2001 From: MaYuhang <2902139028@qq.com> Date: Tue, 9 Jun 2026 09:56:45 +0000 Subject: [PATCH 5/6] refactor: improve WorkspaceManager buffer registration --- 
csrc/global_state/workspace_manager.hpp | 296 ++++++++++++------------ 1 file changed, 148 insertions(+), 148 deletions(-) diff --git a/csrc/global_state/workspace_manager.hpp b/csrc/global_state/workspace_manager.hpp index 628ba5cf7..e9daba6c8 100644 --- a/csrc/global_state/workspace_manager.hpp +++ b/csrc/global_state/workspace_manager.hpp @@ -11,151 +11,151 @@ namespace infinilm::global_state { -// /** -// * @brief Unified GPU inference workspace manager. -// * -// * Phase 1: modules register buffer layouts via ``register_buffer``. -// * Phase 2/3: ``finalize_and_bind`` allocates ``scratch_buffer_`` and binds views. -// */ -// class WorkspaceManager { -// public: -// using BindFn = std::function; - -// WorkspaceManager() = default; -// ~WorkspaceManager() = default; - -// /** -// * @brief Register a buffer appended at the current scratch_buffer tail. -// * -// * @param name Unique cache key; duplicate keys share one slot. -// * @param shape Tensor shape for the bound view. -// * @param dtype Element type of the bound view. -// * @param device Device on which scratch_buffer is allocated. -// * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view. -// */ -// void register_buffer(const std::string &name, -// const infinicore::Shape &shape, -// const infinicore::DataType &dtype, -// const infinicore::Device &device, -// BindFn bind_fn) { -// register_buffer_impl(name, total_bytes_, shape, dtype, device, std::move(bind_fn), true); -// } - -// /** -// * @brief Register a buffer pinned at a fixed byte offset. -// * -// * @param name Unique cache key; duplicate keys share one slot. -// * @param offset Byte offset in scratch_buffer (currently only 0 is supported). -// * @param shape Tensor shape for the bound view. -// * @param dtype Element type of the bound view. -// * @param device Device on which scratch_buffer is allocated. -// * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view. 
-// */ -// void register_buffer(const std::string &name, -// size_t offset, -// const infinicore::Shape &shape, -// const infinicore::DataType &dtype, -// const infinicore::Device &device, -// BindFn bind_fn) { -// ASSERT(0 == offset); -// register_buffer_impl(name, offset, shape, dtype, device, std::move(bind_fn), false); -// } - -// /** -// * @brief Allocate scratch_buffer and run all registered bind callbacks. -// * -// * @param device Device on which scratch_buffer is allocated. -// */ -// void finalize_and_bind(const infinicore::Device &device) { -// ASSERT(!finalized_); -// if (total_bytes_ == 0) { -// finalized_ = true; -// return; -// } - -// ASSERT(device.getType() != infinicore::Device::Type::CPU); - -// scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, device); - -// spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0); - -// for (auto &[name, reg] : registrations_) { -// auto *base_ptr = scratch_buffer_->data() + reg.offset; -// auto view = infinicore::Tensor::from_blob(static_cast(base_ptr), reg.shape, reg.dtype, device); -// inference_buffers_[name] = view; -// for (auto &bind_fn : reg.bind_callbacks) { -// bind_fn(view); -// } -// } - -// finalized_ = true; -// } - -// private: -// /** @brief Metadata for one registered region in scratch_buffer. 
*/ -// struct BufferRegistration { -// size_t offset{0}; -// size_t aligned_bytes{0}; -// infinicore::Shape shape; -// infinicore::DataType dtype; -// infinicore::Device device; -// std::vector bind_callbacks; -// }; - -// void register_buffer_impl(const std::string &name, -// size_t offset, -// const infinicore::Shape &shape, -// const infinicore::DataType &dtype, -// const infinicore::Device &device, -// BindFn bind_fn, -// bool bump_tail) { -// ASSERT(!finalized_); -// ASSERT(device.getType() != infinicore::Device::Type::CPU); - -// auto compute_numel = [](const infinicore::Shape &shape) { -// size_t numel = 1; -// for (const auto dim : shape) { -// numel *= dim; -// } -// return numel; -// }; - -// auto align_up = [](size_t n, size_t alignment = 512) { -// return (n + alignment - 1) & ~(alignment - 1); -// }; - -// const size_t actual_bytes = compute_numel(shape) * infinicore::dsize(dtype); -// const size_t aligned_bytes = align_up(actual_bytes); - -// if (registrations_.find(name) == registrations_.end()) { -// BufferRegistration reg; -// reg.offset = offset; -// reg.aligned_bytes = aligned_bytes; -// reg.shape = shape; -// reg.dtype = dtype; -// reg.device = device; - -// if (bump_tail) { -// total_bytes_ += aligned_bytes; -// } else { -// total_bytes_ = std::max(total_bytes_, offset + aligned_bytes); -// } -// registrations_.emplace(name, std::move(reg)); -// } - -// auto &reg = registrations_.at(name); -// ASSERT(reg.aligned_bytes == aligned_bytes); -// ASSERT(reg.shape == shape); -// ASSERT(reg.dtype == dtype); -// ASSERT(reg.device == device); -// reg.bind_callbacks.push_back(std::move(bind_fn)); -// } - -// size_t total_bytes_{0}; -// bool finalized_{false}; -// infinicore::Tensor scratch_buffer_; -// std::unordered_map registrations_; -// std::unordered_map inference_buffers_; -// }; - -}; // namespace infinilm::global_state \ No newline at end of file +/** + * @brief Unified GPU inference workspace manager.
+ * + * Phase 1: modules register buffer layouts via ``register_buffer``. + * Phase 2/3: ``finalize_and_bind`` allocates ``scratch_buffer_`` and binds views. + */ +class WorkspaceManager { +public: + using BindFn = std::function; + + WorkspaceManager() = default; + ~WorkspaceManager() = default; + + /** + * @brief Register a buffer appended at the current scratch_buffer tail. + * + * @param name Unique cache key; duplicate keys share one slot. + * @param shape Tensor shape for the bound view. + * @param dtype Element type of the bound view. + * @param device Device on which scratch_buffer is allocated. + * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view. + */ + void register_buffer(const std::string &name, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device, + BindFn bind_fn) { + register_buffer_impl(name, total_bytes_, shape, dtype, device, std::move(bind_fn), true); + } + + /** + * @brief Register a buffer pinned at a fixed byte offset. + * + * @param name Unique cache key; duplicate keys share one slot. + * @param offset Byte offset in scratch_buffer (currently only 0 is supported). + * @param shape Tensor shape for the bound view. + * @param dtype Element type of the bound view. + * @param device Device on which scratch_buffer is allocated. + * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view. + */ + void register_buffer(const std::string &name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device, + BindFn bind_fn) { + ASSERT(0 == offset); + register_buffer_impl(name, offset, shape, dtype, device, std::move(bind_fn), false); + } + + /** + * @brief Allocate scratch_buffer and run all registered bind callbacks. + * + * @param device Device on which scratch_buffer is allocated. 
+ */ + void finalize_and_bind(const infinicore::Device &device) { + ASSERT(!finalized_); + if (total_bytes_ == 0) { + finalized_ = true; + return; + } + + ASSERT(device.getType() != infinicore::Device::Type::CPU); + + scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, device); + + spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0); + + for (auto &[name, reg] : registrations_) { + auto *base_ptr = scratch_buffer_->data() + reg.offset; + auto view = infinicore::Tensor::from_blob(static_cast(base_ptr), reg.shape, reg.dtype, device); + inference_buffers_[name] = view; + for (auto &bind_fn : reg.bind_callbacks) { + bind_fn(view); + } + } + + finalized_ = true; + } + +private: + /** @brief Metadata for one registered region in scratch_buffer. */ + struct BufferRegistration { + size_t offset{0}; + size_t aligned_bytes{0}; + infinicore::Shape shape; + infinicore::DataType dtype; + infinicore::Device device; + std::vector bind_callbacks; + }; + + void register_buffer_impl(const std::string &name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device, + BindFn bind_fn, + bool bump_tail) { + ASSERT(!finalized_); + ASSERT(device.getType() != infinicore::Device::Type::CPU); + + auto compute_numel = [](const infinicore::Shape &shape) { + size_t numel = 1; + for (const auto dim : shape) { + numel *= dim; + } + return numel; + }; + + auto align_up = [](size_t n, size_t alignment = 512) { + return (n + alignment - 1) & ~(alignment - 1); + }; + + const size_t actual_bytes = compute_numel(shape) * infinicore::dsize(dtype); + const size_t aligned_bytes = align_up(actual_bytes); + + if (registrations_.find(name) == registrations_.end()) { + BufferRegistration reg; + reg.offset = offset; + reg.aligned_bytes = aligned_bytes; + reg.shape = shape; + reg.dtype = dtype; + reg.device = device; + + if (bump_tail) { + total_bytes_ += aligned_bytes; + 
} else { + total_bytes_ = std::max(total_bytes_, offset + aligned_bytes); + } + registrations_.emplace(name, std::move(reg)); + } + + auto &reg = registrations_.at(name); + ASSERT(reg.aligned_bytes == aligned_bytes); + ASSERT(reg.shape == shape); + ASSERT(reg.dtype == dtype); + ASSERT(reg.device == device); + reg.bind_callbacks.push_back(std::move(bind_fn)); + } + + size_t total_bytes_{0}; + bool finalized_{false}; + infinicore::Tensor scratch_buffer_; + std::unordered_map registrations_; + std::unordered_map inference_buffers_; +}; + +}; // namespace infinilm::global_state From 6d1fa23f9d852bfc66ea84b036ecc1372102bbc3 Mon Sep 17 00:00:00 2001 From: wangpengcheng Date: Wed, 10 Jun 2026 08:11:04 +0000 Subject: [PATCH 6/6] issue/407 - refine the code --- csrc/engine/rank_worker.cpp | 4 +- csrc/global_state/workspace_manager.cpp | 278 ++++++++++++++++++ csrc/global_state/workspace_manager.hpp | 175 ++++------- csrc/layers/attention/attention.cpp | 23 +- csrc/layers/attention/attention.hpp | 2 - csrc/layers/attention/backends/flash_attn.cpp | 10 +- csrc/layers/attention/backends/flash_attn.hpp | 1 - csrc/layers/attention/backends/paged_attn.cpp | 10 +- csrc/layers/attention/backends/paged_attn.hpp | 1 - .../causal_lm_templates/text_causal_lm.hpp | 11 +- .../layers/causal_lm_templates/text_model.hpp | 11 +- csrc/layers/mlp/mlp.cpp | 29 +- csrc/layers/mlp/mlp.hpp | 4 - csrc/models/qwen3/qwen3_attention.cpp | 23 +- 14 files changed, 372 insertions(+), 210 deletions(-) create mode 100644 csrc/global_state/workspace_manager.cpp diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index e24475d72..1c52666f6 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -282,7 +282,8 @@ void RankWorker::thread_loop() { infinicore::context::syncStream(); if (infinilm_config_->enable_workspace_manager) { - forward_context_.workspace_manager.finalize_and_bind(rank_info_.device); + forward_context_.workspace_manager.finalize_and_bind(); + //
forward_context_.workspace_manager.log_registrations(); } infinicore::context::syncStream(); @@ -402,6 +403,7 @@ void RankWorker::thread_loop() { try { { std::lock_guard lk(mutex_); + infinilm::global_state::get_forward_context().workspace_manager.reset_runtime_buffers(); infinicore::Tensor logits; // Try to get compiled graph diff --git a/csrc/global_state/workspace_manager.cpp b/csrc/global_state/workspace_manager.cpp new file mode 100644 index 000000000..b4ce84daf --- /dev/null +++ b/csrc/global_state/workspace_manager.cpp @@ -0,0 +1,278 @@ +#include "workspace_manager.hpp" + +#include "../utils.hpp" +#include "parallel_state.hpp" + +#include +#include +#include + +namespace infinilm::global_state { + +namespace { + +constexpr size_t k_scratch_align_bytes = 512; + +size_t compute_numel(const infinicore::Shape &shape) { + size_t numel = 1; + for (const auto dim : shape) { + numel *= dim; + } + return numel; +} + +size_t align_up(size_t n, size_t alignment = k_scratch_align_bytes) { + return (n + alignment - 1) & ~(alignment - 1); +} + +size_t compute_aligned_bytes(const infinicore::Shape &shape, const infinicore::DataType &dtype) { + return align_up(compute_numel(shape) * infinicore::dsize(dtype)); +} + +} // namespace + +void WorkspaceManager::register_buffer(const std::string &name, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + _register_buffer_impl(name, total_bytes_, shape, dtype, device, true); +} + +void WorkspaceManager::register_buffer(const std::string &name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + ASSERT(0 == offset); + _register_buffer_impl(name, offset, shape, dtype, device, false); +} + +infinicore::Tensor WorkspaceManager::_make_runtime_view(size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + auto *base_ptr = 
scratch_buffer_->data() + offset; + return infinicore::Tensor::from_blob(static_cast(base_ptr), shape, dtype, device); +} + +infinicore::Tensor WorkspaceManager::get_buffer(const std::string &buffer_name, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + ASSERT(finalized_); + ASSERT(!scratch_buffer_.empty()); + + auto cached = runtime_buffers_.find(buffer_name); + if (cached != runtime_buffers_.end()) { + return cached->second; + } + + auto &rank_device = get_tensor_model_parallel_rank_info().device; + const size_t aligned_bytes = compute_aligned_bytes(shape, dtype); + + auto registered = registrations_.find(buffer_name); + if (registered != registrations_.end()) { + const auto &reg = registered->second; + auto tensor = _make_runtime_view(reg.offset, shape, dtype, rank_device); + runtime_buffers_.emplace(buffer_name, tensor); + return tensor; + } + + const size_t offset = scratch_buffer_offset_; + ASSERT(offset + aligned_bytes <= total_bytes_); + + auto tensor = _make_runtime_view(offset, shape, dtype, rank_device); + runtime_buffers_.emplace(buffer_name, tensor); + scratch_buffer_offset_ += aligned_bytes; + return tensor; +} + +infinicore::Tensor WorkspaceManager::get_buffer(const std::string &buffer_name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + ASSERT(finalized_); + ASSERT(!scratch_buffer_.empty()); + + auto cached = runtime_buffers_.find(buffer_name); + if (cached != runtime_buffers_.end()) { + return cached->second; + } + + auto &rank_device = get_tensor_model_parallel_rank_info().device; + const size_t aligned_bytes = compute_aligned_bytes(shape, dtype); + + auto registered = registrations_.find(buffer_name); + if (registered != registrations_.end()) { + const auto &reg = registered->second; + auto tensor = _make_runtime_view(reg.offset, shape, dtype, rank_device); + runtime_buffers_.emplace(buffer_name, tensor); +
return tensor; + } + + ASSERT(offset + aligned_bytes <= total_bytes_); + + auto tensor = _make_runtime_view(offset, shape, dtype, rank_device); + runtime_buffers_.emplace(buffer_name, tensor); + return tensor; +} + +void WorkspaceManager::reset_runtime_buffers() { + ASSERT(finalized_); + scratch_buffer_offset_ = 0; + runtime_buffers_.clear(); +} + +void WorkspaceManager::finalize_and_bind() { + ASSERT(!finalized_); + runtime_buffers_.clear(); + scratch_buffer_offset_ = 0; + + if (total_bytes_ == 0) { + finalized_ = true; + return; + } + + auto &rank_device = get_tensor_model_parallel_rank_info().device; + + scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, rank_device); + + spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0); + + for (auto &entry : registrations_) { + auto &reg = entry.second; + auto *base_ptr = scratch_buffer_->data() + reg.offset; + ASSERT(rank_device == reg.device); + reg.bound_view = infinicore::Tensor::from_blob(static_cast(base_ptr), reg.shape, reg.dtype, rank_device); + } + + scratch_buffer_offset_ = 0; + finalized_ = true; +} + +void WorkspaceManager::log_registrations() const { + std::vector names; + names.reserve(registrations_.size()); + for (const auto &entry : registrations_) { + names.push_back(entry.first); + } + std::sort(names.begin(), names.end(), [this](const std::string &a, const std::string &b) { + return registrations_.at(a).offset < registrations_.at(b).offset; + }); + + std::ostringstream oss; + oss << std::fixed << std::setprecision(3); + oss << "\n========== WorkspaceManager registrations ==========\n"; + oss << " " << std::setw(16) << std::left << "finalized:" << finalized_ << "\n"; + oss << " " << std::setw(16) << std::left << "slots:" << registrations_.size() << "\n"; + oss << " " << std::setw(16) << std::left << "runtime_buffers:" << runtime_buffers_.size() << "\n"; + oss << " " << std::setw(16) << std::left << "scratch_bytes:" + << total_bytes_
<< " (" << (total_bytes_ / 1024.0 / 1024.0) << " MB)\n"; + oss << " " << std::setw(16) << std::left << "scratch_buffer_offset_:" + << scratch_buffer_offset_ << " (" << (scratch_buffer_offset_ / 1024.0 / 1024.0) << " MB)\n"; + oss << " note: scratch_bytes=max span; registered slots may overlap.\n"; + oss << "----------------------------------------------------\n"; + + auto memory_end = [](const BufferRegistration &reg) { + return reg.offset + reg.aligned_bytes; + }; + auto ranges_overlap = [](size_t a_start, size_t a_end, size_t b_start, size_t b_end) { + return a_start < b_end && b_start < a_end; + }; + + for (size_t slot_idx = 0; slot_idx < names.size(); ++slot_idx) { + const auto &name = names[slot_idx]; + const auto &reg = registrations_.at(name); + const size_t mem_start = reg.offset; + const size_t mem_end = memory_end(reg); + + std::string shape_str = "["; + for (size_t i = 0; i < reg.shape.size(); ++i) { + if (i > 0) { + shape_str += ", "; + } + shape_str += std::to_string(reg.shape[i]); + } + shape_str += "]"; + + std::string overlap_str = "none"; + { + std::ostringstream overlap_oss; + bool first = true; + for (size_t other_idx = 0; other_idx < names.size(); ++other_idx) { + if (other_idx == slot_idx) { + continue; + } + const auto &other = registrations_.at(names[other_idx]); + if (ranges_overlap(mem_start, mem_end, other.offset, memory_end(other))) { + if (!first) { + overlap_oss << ", "; + } + overlap_oss << "slot " << other_idx; + first = false; + } + } + if (!first) { + overlap_str = overlap_oss.str(); + } + } + + oss << " [slot " << slot_idx << "]\n"; + oss << " " << std::setw(16) << std::left << "layout:" + << (reg.is_bump_tail ?
"bump" : "pinned@0") << "\n"; + oss << " " << std::setw(16) << std::left << "memory:" + << "[" << mem_start << ", " << mem_end << ") " + << "(" << (reg.aligned_bytes / 1024.0 / 1024.0) << " MB)\n"; + oss << " " << std::setw(16) << std::left << "overlaps:" << overlap_str << "\n"; + oss << " " << std::setw(16) << std::left << "name:" << name << "\n"; + oss << " " << std::setw(16) << std::left << "shape:" << shape_str << "\n"; + oss << " " << std::setw(16) << std::left << "dtype:" << infinicore::toString(reg.dtype) << "\n"; + oss << " " << std::setw(16) << std::left << "device:" << reg.device.toString() << "\n"; + oss << " " << std::setw(16) << std::left << "bound:" << finalized_ << "\n"; + if (slot_idx + 1 < names.size()) { + oss << "\n"; + } + } + oss << "====================================================\n"; + + spdlog::info("{}", oss.str()); +} + +void WorkspaceManager::_register_buffer_impl(const std::string &name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device, + bool bump_tail) { + ASSERT(!finalized_); + ASSERT(device == get_tensor_model_parallel_rank_info().device); + + const size_t aligned_bytes = compute_aligned_bytes(shape, dtype); + + if (registrations_.find(name) == registrations_.end()) { + BufferRegistration reg; + reg.offset = offset; + reg.aligned_bytes = aligned_bytes; + reg.is_bump_tail = bump_tail; + reg.shape = shape; + reg.dtype = dtype; + reg.device = device; + + if (bump_tail) { + total_bytes_ += aligned_bytes; + } else { + total_bytes_ = std::max(total_bytes_, offset + aligned_bytes); + } + registrations_.emplace(name, std::move(reg)); + } + + auto &reg = registrations_.at(name); + ASSERT(reg.is_bump_tail == bump_tail); + ASSERT(reg.aligned_bytes == aligned_bytes); + ASSERT(reg.shape == shape); + ASSERT(reg.dtype == dtype); + ASSERT(reg.device == device); +} + +} // namespace infinilm::global_state diff --git a/csrc/global_state/workspace_manager.hpp
b/csrc/global_state/workspace_manager.hpp index e9daba6c8..301a63fb9 100644 --- a/csrc/global_state/workspace_manager.hpp +++ b/csrc/global_state/workspace_manager.hpp @@ -1,10 +1,7 @@ #pragma once #include "../models/infinilm_model.hpp" -#include "../utils.hpp" -#include -#include -#include +#include #include #include #include @@ -12,150 +9,82 @@ namespace infinilm::global_state { /** - * @brief Unified GPU inference workspace manager. + * @brief Unified GPU inference scratch buffer. * - * Phase 1: modules register buffer layouts via ``register_buffer``. - * Phase 2/3: ``finalize_and_bind`` allocates ``scratch_buffer_`` and binds views. + * Flow: register_buffer -> finalize_and_bind -> get_buffer (named cache) -> log_registrations. + * get_buffer looks up buffer_name; on miss bump-allocates and caches for reuse across layers. */ class WorkspaceManager { public: - using BindFn = std::function; - WorkspaceManager() = default; ~WorkspaceManager() = default; - /** - * @brief Register a buffer appended at the current scratch_buffer tail. - * - * @param name Unique cache key; duplicate keys share one slot. - * @param shape Tensor shape for the bound view. - * @param dtype Element type of the bound view. - * @param device Device on which scratch_buffer is allocated. - * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view. - */ + /** @brief Register a bump slot at current total_bytes_. Same name reuses one slot. */ void register_buffer(const std::string &name, const infinicore::Shape &shape, const infinicore::DataType &dtype, - const infinicore::Device &device, - BindFn bind_fn) { - register_buffer_impl(name, total_bytes_, shape, dtype, device, std::move(bind_fn), true); - } - - /** - * @brief Register a buffer pinned at a fixed byte offset. - * - * @param name Unique cache key; duplicate keys share one slot. - * @param offset Byte offset in scratch_buffer (currently only 0 is supported). - * @param shape Tensor shape for the bound view. 
- * @param dtype Element type of the bound view. - * @param device Device on which scratch_buffer is allocated. - * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view. - */ + const infinicore::Device &device); + + /** @brief Register a pinned@0 slot (only offset==0). May overlap bump slots. */ void register_buffer(const std::string &name, size_t offset, const infinicore::Shape &shape, const infinicore::DataType &dtype, - const infinicore::Device &device, - BindFn bind_fn) { - ASSERT(0 == offset); - register_buffer_impl(name, offset, shape, dtype, device, std::move(bind_fn), false); - } - - /** - * @brief Allocate scratch_buffer and run all registered bind callbacks. - * - * @param device Device on which scratch_buffer is allocated. - */ - void finalize_and_bind(const infinicore::Device &device) { - ASSERT(!finalized_); - if (total_bytes_ == 0) { - finalized_ = true; - return; - } - - ASSERT(device.getType() != infinicore::Device::Type::CPU); - - scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, device); - - spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0); - - for (auto &[name, reg] : registrations_) { - auto *base_ptr = scratch_buffer_->data() + reg.offset; - auto view = infinicore::Tensor::from_blob(static_cast(base_ptr), reg.shape, reg.dtype, device); - inference_buffers_[name] = view; - for (auto &bind_fn : reg.bind_callbacks) { - bind_fn(view); - } - } - - finalized_ = true; - } + const infinicore::Device &device); + + /** @brief Return a cached view by name, or bump-allocate and cache on first use. */ + infinicore::Tensor get_buffer(const std::string &buffer_name, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + /** @brief Return a cached view by name, or bind at offset and cache on first use. 
*/ + infinicore::Tensor get_buffer(const std::string &buffer_name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + /** @brief Allocate scratch_buffer_. */ + void finalize_and_bind(); + + /** @brief Reset runtime bump offset to 0. Call at the start of each forward. */ + void reset_runtime_buffers(); + + /** @brief Log slot layout with memory ranges and overlap info. */ + void log_registrations() const; private: - /** @brief Metadata for one registered region in scratch_buffer. */ struct BufferRegistration { - size_t offset{0}; - size_t aligned_bytes{0}; + size_t offset{0}; // view start in scratch_buffer_ + size_t aligned_bytes{0}; // aligned byte span + bool is_bump_tail{true}; // bump vs pinned@0 infinicore::Shape shape; infinicore::DataType dtype; infinicore::Device device; - std::vector bind_callbacks; + infinicore::Tensor bound_view; // set in finalize_and_bind }; - void register_buffer_impl(const std::string &name, - size_t offset, - const infinicore::Shape &shape, - const infinicore::DataType &dtype, - const infinicore::Device &device, - BindFn bind_fn, - bool bump_tail) { - ASSERT(!finalized_); - ASSERT(device.getType() != infinicore::Device::Type::CPU); - - auto compute_numel = [](const infinicore::Shape &shape) { - size_t numel = 1; - for (const auto dim : shape) { - numel *= dim; - } - return numel; - }; - - auto align_up = [](size_t n, size_t alignment = 512) { - return (n + alignment - 1) & ~(alignment - 1); - }; - - const size_t actual_bytes = compute_numel(shape) * infinicore::dsize(dtype); - const size_t aligned_bytes = align_up(actual_bytes); - - if (registrations_.find(name) == registrations_.end()) { - BufferRegistration reg; - reg.offset = offset; - reg.aligned_bytes = aligned_bytes; - reg.shape = shape; - reg.dtype = dtype; - reg.device = device; - - if (bump_tail) { - total_bytes_ += aligned_bytes; - } else { - total_bytes_ = std::max(total_bytes_, offset + 
aligned_bytes); - } - registrations_.emplace(name, std::move(reg)); - } - - auto &reg = registrations_.at(name); - ASSERT(reg.aligned_bytes == aligned_bytes); - ASSERT(reg.shape == shape); - ASSERT(reg.dtype == dtype); - ASSERT(reg.device == device); - reg.bind_callbacks.push_back(std::move(bind_fn)); - } + void _register_buffer_impl(const std::string &name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device, + bool bump_tail); + + infinicore::Tensor _make_runtime_view(size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device); - size_t total_bytes_{0}; bool finalized_{false}; + infinicore::Tensor scratch_buffer_; + size_t total_bytes_{0}; + size_t scratch_buffer_offset_{0}; + std::unordered_map registrations_; - std::unordered_map inference_buffers_; + std::unordered_map runtime_buffers_; }; -}; // namespace infinilm::global_state +} // namespace infinilm::global_state diff --git a/csrc/layers/attention/attention.cpp b/csrc/layers/attention/attention.cpp index 747e980d3..7dc9a5c25 100644 --- a/csrc/layers/attention/attention.cpp +++ b/csrc/layers/attention/attention.cpp @@ -74,7 +74,8 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position // 1.
Project Q, K, V infinicore::Tensor q, k, v; if (enable_workspace_manager_) { - auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto qkv_output = workspace_manager.get_buffer("Attention_qkv_output", {batch_size, seq_len, rank_qkv_output_size_}, dtype_, device_); std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); } else { std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); @@ -107,7 +108,8 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position // 6. Project output if (enable_workspace_manager_) { - auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto o_output = workspace_manager.get_buffer("Attention_o_output", {batch_size, seq_len, hidden_size_}, dtype_, device_); o_proj_->forward_(o_output, attn_output); return o_output; } @@ -128,7 +130,8 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ // 1. Project Q, K, V infinicore::Tensor q, k, v; if (enable_workspace_manager_) { - auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto qkv_output = workspace_manager.get_buffer("Attention_qkv_output", {1, seq_len, rank_qkv_output_size_}, dtype_, device_); std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); } else { std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); @@ -160,7 +163,8 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ // 6. 
Project output if (enable_workspace_manager_) { - auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto o_output = workspace_manager.get_buffer("Attention_o_output", {1, seq_len, hidden_size_}, dtype_, device_); o_proj_->forward_(o_output, attn_output); return o_output; } @@ -187,16 +191,7 @@ void Attention::_register_inference_buffer() { attention_cache_key, attention_buffer_shape, dtype_, - device_, - [this, max_num_batched_tokens, max_output_size](const infinicore::Tensor &attention_buffer) { - const auto attention_buffer_shape = attention_buffer->shape(); - ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size); - - max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}}) - ->view({max_num_batched_tokens, rank_qkv_output_size_}); - max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}}) - ->view({max_num_batched_tokens, hidden_size_}); - }); + device_); } void init_kv_cache_quant_params(std::function register_fn, diff --git a/csrc/layers/attention/attention.hpp b/csrc/layers/attention/attention.hpp index 00cab342b..7814f7307 100644 --- a/csrc/layers/attention/attention.hpp +++ b/csrc/layers/attention/attention.hpp @@ -61,8 +61,6 @@ class Attention : public infinicore::nn::Module { private: bool enable_workspace_manager_{false}; size_t rank_qkv_output_size_{0}; - infinicore::Tensor max_qkv_output_; // inference buffer for Attention - infinicore::Tensor max_o_output_; // inference buffer for Attention }; void init_kv_cache_quant_params(std::function register_fn, const infinicore::Device &device, diff --git a/csrc/layers/attention/backends/flash_attn.cpp b/csrc/layers/attention/backends/flash_attn.cpp index 12a9e4116..14b94b548 100644 --- a/csrc/layers/attention/backends/flash_attn.cpp +++ 
b/csrc/layers/attention/backends/flash_attn.cpp @@ -63,7 +63,8 @@ infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer, infinicore::Tensor attn_output; if (is_prefill) { if (enable_workspace_manager_) { - attn_output = max_attn_output_->narrow({{0, 0, seq_len}}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + attn_output = workspace_manager.get_buffer("FlashAttention_attn_output", {seq_len, num_heads_, head_dim_}, dtype_, device_); } else { attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, dtype_, device_); } @@ -133,11 +134,6 @@ void FlashAttentionImpl::_register_inference_buffer() { cache_key, flash_attn_buffer_shape, dtype_, - device_, - [this, max_num_batched_tokens](const infinicore::Tensor &flash_attention_impl_buffer) { - const auto buffer_shape = flash_attention_impl_buffer->shape(); - ASSERT(buffer_shape[0] == max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_); - max_attn_output_ = flash_attention_impl_buffer; - }); + device_); } } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/flash_attn.hpp b/csrc/layers/attention/backends/flash_attn.hpp index 334e39a47..1f5130158 100644 --- a/csrc/layers/attention/backends/flash_attn.hpp +++ b/csrc/layers/attention/backends/flash_attn.hpp @@ -46,7 +46,6 @@ class FlashAttentionImpl { private: void _register_inference_buffer(); bool enable_workspace_manager_{false}; - infinicore::Tensor max_attn_output_; // inference buffer for FlashAttentionImpl size_t num_heads_; size_t head_size_; diff --git a/csrc/layers/attention/backends/paged_attn.cpp b/csrc/layers/attention/backends/paged_attn.cpp index b97788881..7d6eec5d5 100644 --- a/csrc/layers/attention/backends/paged_attn.cpp +++ b/csrc/layers/attention/backends/paged_attn.cpp @@ -57,7 +57,8 @@ infinicore::Tensor PagedAttentionImpl::forward(const AttentionLayer &layer, // 2. 
Compute attention infinicore::Tensor attn_output; if (enable_workspace_manager_) { - attn_output = max_attn_output_->narrow({{0, 0, seq_len}}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + attn_output = workspace_manager.get_buffer("PagedAttention_attn_output", {seq_len, num_heads_, head_dim_}, dtype_, device_); } else { attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, dtype_, device_); } @@ -121,12 +122,7 @@ void PagedAttentionImpl::_register_inference_buffer() { cache_key, paged_attn_buffer_shape, dtype_, - device_, - [this, max_num_batched_tokens](const infinicore::Tensor &paged_attention_impl_buffer) { - const auto buffer_shape = paged_attention_impl_buffer->shape(); - ASSERT(buffer_shape[0] == max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_); - max_attn_output_ = paged_attention_impl_buffer; - }); + device_); } } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.hpp b/csrc/layers/attention/backends/paged_attn.hpp index f5c625b45..e1ab7688e 100644 --- a/csrc/layers/attention/backends/paged_attn.hpp +++ b/csrc/layers/attention/backends/paged_attn.hpp @@ -46,7 +46,6 @@ class PagedAttentionImpl { private: void _register_inference_buffer(); bool enable_workspace_manager_{false}; - infinicore::Tensor max_attn_output_; // inference buffer for PagedAttentionImpl size_t num_heads_; size_t head_size_; diff --git a/csrc/layers/causal_lm_templates/text_causal_lm.hpp b/csrc/layers/causal_lm_templates/text_causal_lm.hpp index 0f4a578b4..86ce8b1c6 100644 --- a/csrc/layers/causal_lm_templates/text_causal_lm.hpp +++ b/csrc/layers/causal_lm_templates/text_causal_lm.hpp @@ -60,7 +60,8 @@ class TextCausalLM : public InfinilmModel { const auto shape = hidden_states->shape(); const size_t bs = shape[0]; const size_t seq_len = shape[1]; - logits = max_logits_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, 
vocab_size_}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + logits = workspace_manager.get_buffer("TextCausalLM_logits", 0, {bs, seq_len, vocab_size_}, dtype_, device_); lm_head_->forward_(logits, hidden_states); } else { logits = lm_head_->forward(hidden_states); @@ -102,16 +103,10 @@ class TextCausalLM : public InfinilmModel { 0, logits_shape, dtype_, - device_, - [this, max_num_batched_tokens](const infinicore::Tensor &logits_buffer) { - const auto logits_buffer_shape = logits_buffer->shape(); - ASSERT(logits_buffer_shape[0] == max_num_batched_tokens && logits_buffer_shape[1] == vocab_size_); - max_logits_ = logits_buffer; - }); + device_); } bool enable_workspace_manager_{false}; - infinicore::Tensor max_logits_; // inference buffer for TextCausalLM }; } // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/causal_lm_templates/text_model.hpp b/csrc/layers/causal_lm_templates/text_model.hpp index d615920df..376a7f0c5 100644 --- a/csrc/layers/causal_lm_templates/text_model.hpp +++ b/csrc/layers/causal_lm_templates/text_model.hpp @@ -61,7 +61,8 @@ class TextModel : public infinicore::nn::Module { const auto shape = input_ids->shape(); const size_t bs = shape[0]; const size_t seq_len = shape[1]; - hidden_states = max_hidden_states_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, hidden_size_}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + hidden_states = workspace_manager.get_buffer("TextModel_hidden_states", {bs, seq_len, hidden_size_}, dtype_, device_); embed_tokens_->forward_(hidden_states, input_ids); } else { hidden_states = embed_tokens_->forward(input_ids); @@ -136,12 +137,7 @@ class TextModel : public infinicore::nn::Module { text_model_cache_key, hidden_states_shape, dtype_, - device_, - [this, max_num_batched_tokens](const infinicore::Tensor &hidden_states_buffer) { - const auto hidden_states_buffer_shape = 
hidden_states_buffer->shape(); - ASSERT(hidden_states_buffer_shape[0] == max_num_batched_tokens && hidden_states_buffer_shape[1] == hidden_size_); - max_hidden_states_ = hidden_states_buffer; - }); + device_); } protected: @@ -156,7 +152,6 @@ class TextModel : public infinicore::nn::Module { private: bool enable_workspace_manager_{false}; - infinicore::Tensor max_hidden_states_; // inference buffer for TextModel }; } // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/mlp/mlp.cpp b/csrc/layers/mlp/mlp.cpp index 13d0ac581..095bd62a1 100644 --- a/csrc/layers/mlp/mlp.cpp +++ b/csrc/layers/mlp/mlp.cpp @@ -53,21 +53,24 @@ infinicore::Tensor MLP::forward(const infinicore::Tensor &hidden_states) const { } infinicore::Tensor MLP::_forward_with_inference_buffer(const infinicore::Tensor &hidden_states) const { + + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + const auto shape = hidden_states->shape(); const size_t bs = shape[0]; const size_t seq_len = shape[1]; - // 1. Project to gate and up auto hidden_states_mutable = hidden_states; - auto gate_up_output = max_gate_up_output_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, rank_gate_up_output_size_}); + // 1. Project to gate and up + auto gate_up_output = workspace_manager.get_buffer("MLP_gate_up_output", {bs, seq_len, rank_gate_up_output_size_}, dtype_, device_); auto [gate, up] = gate_up_proj_->forward_split_(gate_up_output, hidden_states_mutable); // 2. Apply SwiGLU: silu(gate) * up - auto intermediate = max_intermediate_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, rank_intermediate_size_}); + auto intermediate = workspace_manager.get_buffer("MLP_intermediate", {bs, seq_len, rank_intermediate_size_}, dtype_, device_); infinicore::op::swiglu_(intermediate, up, gate); // 3. 
Project down - auto down_output = max_down_output_->narrow({{0, 0, bs * seq_len}})->view({bs, seq_len, hidden_size_}); + auto down_output = workspace_manager.get_buffer("MLP_down_output", {bs, seq_len, hidden_size_}, dtype_, device_); down_proj_->forward_(down_output, intermediate); return down_output; } @@ -87,7 +90,7 @@ void MLP::_register_inference_buffer() { + infinicore::toString(dtype_) + "_device_" + device_.toString(); - auto align_up = [](size_t n, size_t alignment = 256) { + auto align_up = [](size_t n, size_t alignment = 512) { return (n + alignment - 1) & ~(alignment - 1); }; @@ -100,21 +103,7 @@ void MLP::_register_inference_buffer() { mlp_cache_key, mlp_buffer_shape, dtype_, - device_, - [this, max_num_batched_tokens, rank_gate_up_output_size_aligned, rank_intermediate_size_aligned, - max_output_size](const infinicore::Tensor &mlp_buffer) { - const auto buffer_shape = mlp_buffer->shape(); - ASSERT(buffer_shape[0] == max_num_batched_tokens * max_output_size); - - max_gate_up_output_ = mlp_buffer->narrow({{0, 0, max_num_batched_tokens * rank_gate_up_output_size_}}) - ->view({max_num_batched_tokens, rank_gate_up_output_size_}); - max_intermediate_ = mlp_buffer->narrow({{0, max_num_batched_tokens * rank_gate_up_output_size_aligned, - max_num_batched_tokens * rank_intermediate_size_}}) - ->view({max_num_batched_tokens, rank_intermediate_size_}); - max_down_output_ = mlp_buffer->narrow({{0, max_num_batched_tokens * rank_intermediate_size_aligned, - max_num_batched_tokens * hidden_size_}}) - ->view({max_num_batched_tokens, hidden_size_}); - }); + device_); } } // namespace infinilm::layers::mlp diff --git a/csrc/layers/mlp/mlp.hpp b/csrc/layers/mlp/mlp.hpp index 3887a2f23..fb15c9ecb 100644 --- a/csrc/layers/mlp/mlp.hpp +++ b/csrc/layers/mlp/mlp.hpp @@ -63,10 +63,6 @@ class MLP : public infinicore::nn::Module { bool enable_workspace_manager_{false}; size_t rank_gate_up_output_size_{0}; size_t rank_intermediate_size_{0}; - - infinicore::Tensor 
max_gate_up_output_; // inference buffer for MLP - infinicore::Tensor max_intermediate_; - infinicore::Tensor max_down_output_; }; } // namespace infinilm::layers::mlp diff --git a/csrc/models/qwen3/qwen3_attention.cpp b/csrc/models/qwen3/qwen3_attention.cpp index 9f806d967..745d877ea 100644 --- a/csrc/models/qwen3/qwen3_attention.cpp +++ b/csrc/models/qwen3/qwen3_attention.cpp @@ -81,7 +81,8 @@ infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &pos // 1. Project Q, K, V infinicore::Tensor q, k, v; if (enable_workspace_manager_) { - auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto qkv_output = workspace_manager.get_buffer("Qwen3Attention_qkv_output", {batch_size, seq_len, rank_qkv_output_size_}, dtype_, device_); std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); } else { std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); @@ -117,7 +118,8 @@ infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &pos // 6. Project output if (enable_workspace_manager_) { - auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto o_output = workspace_manager.get_buffer("Qwen3Attention_o_output", {batch_size, seq_len, hidden_size_}, dtype_, device_); o_proj_->forward_(o_output, attn_output); return o_output; } @@ -139,7 +141,8 @@ infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &posi // 1. 
Project Q, K, V infinicore::Tensor q, k, v; if (enable_workspace_manager_) { - auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto qkv_output = workspace_manager.get_buffer("Qwen3Attention_qkv_output", {1, seq_len, rank_qkv_output_size_}, dtype_, device_); std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); } else { std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); @@ -173,7 +176,8 @@ infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &posi // 6. Project output if (enable_workspace_manager_) { - auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_}); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto o_output = workspace_manager.get_buffer("Qwen3Attention_o_output", {1, seq_len, hidden_size_}, dtype_, device_); o_proj_->forward_(o_output, attn_output); return o_output; } @@ -200,16 +204,7 @@ void Qwen3Attention::_register_inference_buffer() { attention_cache_key, attention_buffer_shape, dtype_, - device_, - [this, max_num_batched_tokens, max_output_size](const infinicore::Tensor &attention_buffer) { - const auto attention_buffer_shape = attention_buffer->shape(); - ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size); - - max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}}) - ->view({max_num_batched_tokens, rank_qkv_output_size_}); - max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}}) - ->view({max_num_batched_tokens, hidden_size_}); - }); + device_); } } // namespace infinilm::models::qwen3