diff --git a/csrc/engine/compiler/paged_compiler.cpp b/csrc/engine/compiler/paged_compiler.cpp index de6ec5d1..392ca5e6 100644 --- a/csrc/engine/compiler/paged_compiler.cpp +++ b/csrc/engine/compiler/paged_compiler.cpp @@ -109,7 +109,11 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input & graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); auto graph = std::get<0>(result->second.compiled); - auto shared_output = std::shared_ptr(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); + // Reuse the GraphTensor output captured at compile time. + // Do not call resume_from_blob_() on workspace-backed logits: + // that registers a second deleter on the same GPU block and + // triggers double free in PinnableBlockAllocator. + auto shared_output = std::get<1>(result->second.compiled); return std::make_tuple(graph, shared_output); } diff --git a/csrc/engine/compiler/static_batching_compiler.cpp b/csrc/engine/compiler/static_batching_compiler.cpp index dcd7f714..637b3e79 100644 --- a/csrc/engine/compiler/static_batching_compiler.cpp +++ b/csrc/engine/compiler/static_batching_compiler.cpp @@ -56,7 +56,11 @@ StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled( graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); auto graph = std::get<0>(result->second.compiled); - auto shared_output = std::shared_ptr(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); + // Reuse the GraphTensor output captured at compile time. + // Do not call resume_from_blob_() on workspace-backed logits: + // that registers a second deleter on the same GPU block and + // triggers double free in PinnableBlockAllocator. 
+ auto shared_output = std::get<1>(result->second.compiled); return std::make_tuple(graph, shared_output); } } else { diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 77d07df4..44196de3 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -15,7 +15,8 @@ InferEngine::InferEngine( const cache::CacheConfig *cache_config, bool enable_graph_compiling, backends::AttentionBackend attention_backend, - std::optional kv_cache_dtype) // Changed parameter + std::optional kv_cache_dtype, // Changed parameter + size_t max_num_batched_tokens) : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); @@ -23,7 +24,7 @@ InferEngine::InferEngine( // Load model config if model_path is provided, model_path must be valid, and config.json exists this->model_config_ = infinilm::config::ConfigFactory::createConfig(config_str); - auto infinilm_config = std::make_shared(attention_backend, this->model_config_); + auto infinilm_config = std::make_shared(attention_backend, this->model_config_, max_num_batched_tokens); // Only support offline int8 kv cache quantization in this version if (kv_cache_dtype.has_value()) { diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 70c3c164..5aa72252 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -28,7 +28,8 @@ class InferEngine { const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, - std::optional kv_cache_dtype = std::nullopt); + std::optional kv_cache_dtype = std::nullopt, + size_t max_num_batched_tokens = 2048); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor ¶m); diff --git a/csrc/engine/rank_worker.cpp 
b/csrc/engine/rank_worker.cpp index 87568fd6..1c52666f 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -278,6 +278,15 @@ void RankWorker::thread_loop() { if (!model_) { throw std::runtime_error("Failed to create model"); } + + infinicore::context::syncStream(); + + if (infinilm_config_->enable_workspace_manager) { + forward_context_.workspace_manager.finalize_and_bind(); + // forward_context_.workspace_manager.log_registrations(); + } + infinicore::context::syncStream(); + if (enable_graph_compiling_) { compiler_ = std::make_unique(model_, barrier_); } @@ -394,6 +403,7 @@ void RankWorker::thread_loop() { try { { std::lock_guard lk(mutex_); + infinilm::global_state::get_forward_context().workspace_manager.reset_runtime_buffers(); infinicore::Tensor logits; // Try to get compiled graph diff --git a/csrc/global_state/forward_context.hpp b/csrc/global_state/forward_context.hpp index 2568fc7e..b5654800 100644 --- a/csrc/global_state/forward_context.hpp +++ b/csrc/global_state/forward_context.hpp @@ -1,6 +1,9 @@ #pragma once #include "../models/infinilm_model.hpp" +#include "../utils.hpp" +#include "workspace_manager.hpp" +#include namespace infinilm::global_state { @@ -48,6 +51,7 @@ struct ForwardContext { AttentionMetadata attn_metadata; MultiModalMetadata mm_metadata; std::vector kv_cache_vec; + WorkspaceManager workspace_manager; }; void initialize_forward_context(ForwardContext &forward_context); diff --git a/csrc/global_state/infinilm_config.hpp b/csrc/global_state/infinilm_config.hpp index 9b80706c..7e39c3f2 100644 --- a/csrc/global_state/infinilm_config.hpp +++ b/csrc/global_state/infinilm_config.hpp @@ -14,13 +14,24 @@ struct InfinilmConfig { public: InfinilmConfig() = default; InfinilmConfig(const infinilm::backends::AttentionBackend &backend, - const std::shared_ptr &model_config) + const std::shared_ptr &model_config, + size_t max_num_batched_tokens) : attention_backend(backend), - model_config(model_config) {} + 
model_config(model_config), + max_num_batched_tokens(max_num_batched_tokens) { + + if (max_num_batched_tokens > 0) { + const size_t max_position_embeddings = model_config->get("max_position_embeddings"); + ASSERT(max_num_batched_tokens >= 512 && max_num_batched_tokens <= max_position_embeddings); + enable_workspace_manager = true; + } + } public: infinilm::backends::AttentionBackend attention_backend; std::shared_ptr model_config; + size_t max_num_batched_tokens = 0; + bool enable_workspace_manager{false}; }; /** diff --git a/csrc/global_state/workspace_manager.cpp b/csrc/global_state/workspace_manager.cpp new file mode 100644 index 00000000..b4ce84da --- /dev/null +++ b/csrc/global_state/workspace_manager.cpp @@ -0,0 +1,278 @@ +#include "workspace_manager.hpp" + +#include "../utils.hpp" +#include "parallel_state.hpp" + +#include +#include +#include + +namespace infinilm::global_state { + +namespace { + +constexpr size_t k_scratch_align_bytes = 512; + +size_t compute_numel(const infinicore::Shape &shape) { + size_t numel = 1; + for (const auto dim : shape) { + numel *= dim; + } + return numel; +} + +size_t align_up(size_t n, size_t alignment = k_scratch_align_bytes) { + return (n + alignment - 1) & ~(alignment - 1); +} + +size_t compute_aligned_bytes(const infinicore::Shape &shape, const infinicore::DataType &dtype) { + return align_up(compute_numel(shape) * infinicore::dsize(dtype)); +} + +} // namespace + +void WorkspaceManager::register_buffer(const std::string &name, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + _register_buffer_impl(name, total_bytes_, shape, dtype, device, true); +} + +void WorkspaceManager::register_buffer(const std::string &name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + ASSERT(0 == offset); + _register_buffer_impl(name, offset, shape, dtype, device, false); +} + +infinicore::Tensor 
WorkspaceManager::_make_runtime_view(size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + auto *base_ptr = scratch_buffer_->data() + offset; + return infinicore::Tensor::from_blob(static_cast(base_ptr), shape, dtype, device); +} + +infinicore::Tensor WorkspaceManager::get_buffer(const std::string &buffer_name, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + ASSERT(finalized_); + ASSERT(!scratch_buffer_.empty()); + + auto cached = runtime_buffers_.find(buffer_name); + if (cached != runtime_buffers_.end()) { + return cached->second; + } + + auto &rank_device = get_tensor_model_parallel_rank_info().device; + const size_t aligned_bytes = compute_aligned_bytes(shape, dtype); + + auto registered = registrations_.find(buffer_name); + if (registered != registrations_.end()) { + const auto ® = registered->second; + auto tensor = _make_runtime_view(reg.offset, shape, dtype, rank_device); + runtime_buffers_.emplace(buffer_name, tensor); + return tensor; + } + + const size_t offset = scratch_buffer_offset_; + ASSERT(offset + aligned_bytes <= total_bytes_); + + auto tensor = _make_runtime_view(offset, shape, dtype, rank_device); + runtime_buffers_.emplace(buffer_name, tensor); + scratch_buffer_offset_ += aligned_bytes; + return tensor; +} + +infinicore::Tensor WorkspaceManager::get_buffer(const std::string &buffer_name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + ASSERT(finalized_); + ASSERT(!scratch_buffer_.empty()); + + auto cached = runtime_buffers_.find(buffer_name); + if (cached != runtime_buffers_.end()) { + return cached->second; + } + + auto &rank_device = get_tensor_model_parallel_rank_info().device; + const size_t aligned_bytes = compute_aligned_bytes(shape, dtype); + + auto registered = registrations_.find(buffer_name); + if (registered != 
registrations_.end()) { + const auto ® = registered->second; + auto tensor = _make_runtime_view(reg.offset, shape, dtype, rank_device); + runtime_buffers_.emplace(buffer_name, tensor); + return tensor; + } + + ASSERT(offset + aligned_bytes <= total_bytes_); + + auto tensor = _make_runtime_view(offset, shape, dtype, rank_device); + runtime_buffers_.emplace(buffer_name, tensor); + return tensor; +} + +void WorkspaceManager::reset_runtime_buffers() { + ASSERT(finalized_); + scratch_buffer_offset_ = 0; + runtime_buffers_.clear(); +} + +void WorkspaceManager::finalize_and_bind() { + ASSERT(!finalized_); + runtime_buffers_.clear(); + scratch_buffer_offset_ = 0; + + if (total_bytes_ == 0) { + finalized_ = true; + return; + } + + auto &rank_device = get_tensor_model_parallel_rank_info().device; + + scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, rank_device); + + spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0); + + for (auto &entry : registrations_) { + auto ® = entry.second; + auto *base_ptr = scratch_buffer_->data() + reg.offset; + ASSERT(rank_device == reg.device); + reg.bound_view = infinicore::Tensor::from_blob(static_cast(base_ptr), reg.shape, reg.dtype, rank_device); + } + + scratch_buffer_offset_ = 0; + finalized_ = true; +} + +void WorkspaceManager::log_registrations() const { + std::vector names; + names.reserve(registrations_.size()); + for (const auto &entry : registrations_) { + names.push_back(entry.first); + } + std::sort(names.begin(), names.end(), [this](const std::string &a, const std::string &b) { + return registrations_.at(a).offset < registrations_.at(b).offset; + }); + + std::ostringstream oss; + oss << std::fixed << std::setprecision(3); + oss << "\n========== WorkspaceManager registrations ==========\n"; + oss << " " << std::setw(16) << std::left << "finalized:" << finalized_ << "\n"; + oss << " " << std::setw(16) << std::left << "slots:" << registrations_.size() << 
"\n"; + oss << " " << std::setw(16) << std::left << "runtime_buffers:" << runtime_buffers_.size() << "\n"; + oss << " " << std::setw(16) << std::left << "scratch_bytes:" + << total_bytes_ << " (" << (total_bytes_ / 1024.0 / 1024.0) << " MB)\n"; + oss << " " << std::setw(16) << std::left << "scratch_buffer_offset_:" + << scratch_buffer_offset_ << " (" << (scratch_buffer_offset_ / 1024.0 / 1024.0) << " MB)\n"; + oss << " note: scratch_bytes=max span; registered slots may overlap.\n"; + oss << "----------------------------------------------------\n"; + + auto memory_end = [](const BufferRegistration ®) { + return reg.offset + reg.aligned_bytes; + }; + auto ranges_overlap = [](size_t a_start, size_t a_end, size_t b_start, size_t b_end) { + return a_start < b_end && b_start < a_end; + }; + + for (size_t slot_idx = 0; slot_idx < names.size(); ++slot_idx) { + const auto &name = names[slot_idx]; + const auto ® = registrations_.at(name); + const size_t mem_start = reg.offset; + const size_t mem_end = memory_end(reg); + + std::string shape_str = "["; + for (size_t i = 0; i < reg.shape.size(); ++i) { + if (i > 0) { + shape_str += ", "; + } + shape_str += std::to_string(reg.shape[i]); + } + shape_str += "]"; + + std::string overlap_str = "none"; + { + std::ostringstream overlap_oss; + bool first = true; + for (size_t other_idx = 0; other_idx < names.size(); ++other_idx) { + if (other_idx == slot_idx) { + continue; + } + const auto &other = registrations_.at(names[other_idx]); + if (ranges_overlap(mem_start, mem_end, other.offset, memory_end(other))) { + if (!first) { + overlap_oss << ", "; + } + overlap_oss << "slot " << other_idx; + first = false; + } + } + if (!first) { + overlap_str = overlap_oss.str(); + } + } + + oss << " [slot " << slot_idx << "]\n"; + oss << " " << std::setw(16) << std::left << "layout:" + << (reg.is_bump_tail ? 
"bump" : "pinned@0") << "\n"; + oss << " " << std::setw(16) << std::left << "memory:" + << "[" << mem_start << ", " << mem_end << ") " + << "(" << (reg.aligned_bytes / 1024.0 / 1024.0) << " MB)\n"; + oss << " " << std::setw(16) << std::left << "overlaps:" << overlap_str << "\n"; + oss << " " << std::setw(16) << std::left << "name:" << name << "\n"; + oss << " " << std::setw(16) << std::left << "shape:" << shape_str << "\n"; + oss << " " << std::setw(16) << std::left << "dtype:" << infinicore::toString(reg.dtype) << "\n"; + oss << " " << std::setw(16) << std::left << "device:" << reg.device.toString() << "\n"; + oss << " " << std::setw(16) << std::left << "bound:" << finalized_ << "\n"; + if (slot_idx + 1 < names.size()) { + oss << "\n"; + } + } + oss << "====================================================\n"; + + spdlog::info("{}", oss.str()); +} + +void WorkspaceManager::_register_buffer_impl(const std::string &name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device, + bool bump_tail) { + ASSERT(!finalized_); + ASSERT(device == get_tensor_model_parallel_rank_info().device); + + const size_t aligned_bytes = compute_aligned_bytes(shape, dtype); + + if (registrations_.find(name) == registrations_.end()) { + BufferRegistration reg; + reg.offset = offset; + reg.aligned_bytes = aligned_bytes; + reg.is_bump_tail = bump_tail; + reg.shape = shape; + reg.dtype = dtype; + reg.device = device; + + if (bump_tail) { + total_bytes_ += aligned_bytes; + } else { + total_bytes_ = std::max(total_bytes_, offset + aligned_bytes); + } + registrations_.emplace(name, std::move(reg)); + } + + auto ® = registrations_.at(name); + ASSERT(reg.is_bump_tail == bump_tail); + ASSERT(reg.aligned_bytes == aligned_bytes); + ASSERT(reg.shape == shape); + ASSERT(reg.dtype == dtype); + ASSERT(reg.device == device); +} + +} // namespace infinilm::global_state diff --git a/csrc/global_state/workspace_manager.hpp 
b/csrc/global_state/workspace_manager.hpp new file mode 100644 index 00000000..301a63fb --- /dev/null +++ b/csrc/global_state/workspace_manager.hpp @@ -0,0 +1,90 @@ +#pragma once + +#include "../models/infinilm_model.hpp" +#include +#include +#include +#include + +namespace infinilm::global_state { + +/** + * @brief Unified GPU inference scratch buffer. + * + * Flow: register_buffer -> finalize_and_bind -> get_buffer (named cache) -> log_registrations. + * get_buffer looks up buffer_name; on miss bump-allocates and caches for reuse across layers. + */ +class WorkspaceManager { +public: + WorkspaceManager() = default; + ~WorkspaceManager() = default; + + /** @brief Register a bump slot at current total_bytes_. Same name reuses one slot. */ + void register_buffer(const std::string &name, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + /** @brief Register a pinned@0 slot (only offset==0). May overlap bump slots. */ + void register_buffer(const std::string &name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + /** @brief Return a cached view by name, or bump-allocate and cache on first use. */ + infinicore::Tensor get_buffer(const std::string &buffer_name, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + /** @brief Return a cached view by name, or bind at offset and cache on first use. */ + infinicore::Tensor get_buffer(const std::string &buffer_name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + /** @brief Allocate scratch_buffer_. */ + void finalize_and_bind(); + + /** @brief Reset runtime bump offset to 0. Call at the start of each forward. */ + void reset_runtime_buffers(); + + /** @brief Log slot layout with memory ranges and overlap info. 
*/ + void log_registrations() const; + +private: + struct BufferRegistration { + size_t offset{0}; // view start in scratch_buffer_ + size_t aligned_bytes{0}; // aligned byte span + bool is_bump_tail{true}; // bump vs pinned@0 + infinicore::Shape shape; + infinicore::DataType dtype; + infinicore::Device device; + infinicore::Tensor bound_view; // set in finalize_and_bind + }; + + void _register_buffer_impl(const std::string &name, + size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device, + bool bump_tail); + + infinicore::Tensor _make_runtime_view(size_t offset, + const infinicore::Shape &shape, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + bool finalized_{false}; + + infinicore::Tensor scratch_buffer_; + size_t total_bytes_{0}; + size_t scratch_buffer_offset_{0}; + + std::unordered_map registrations_; + std::unordered_map runtime_buffers_; +}; + +} // namespace infinilm::global_state diff --git a/csrc/layers/attention/attention.cpp b/csrc/layers/attention/attention.cpp index 1b87f6fb..7dc9a5c2 100644 --- a/csrc/layers/attention/attention.cpp +++ b/csrc/layers/attention/attention.cpp @@ -1,17 +1,21 @@ #include "attention.hpp" +#include "../../global_state/global_state.hpp" #include "../../utils.hpp" #include "../rotary_embedding/rotary_embedding.hpp" +#include +#include namespace infinilm::layers::attention { Attention::Attention(std::shared_ptr model_config, size_t layer_idx, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { layer_idx_ = layer_idx; hidden_size_ = model_config->get("hidden_size"); head_dim_ = model_config->get("head_dim"); - const auto &dtype{model_config->get_dtype()}; size_t total_num_heads = model_config->get("num_attention_heads"); size_t total_num_kv_heads = model_config->get("num_key_value_heads"); bool use_bias = model_config->get_or("attention_bias", true); @@ 
-31,18 +35,24 @@ Attention::Attention(std::shared_ptr model_config qkv_proj_ = std::make_shared( hidden_size_, head_dim_, total_num_heads, total_num_kv_heads, "q_proj", "k_proj", "v_proj", register_fn, - quantization_method, use_bias, dtype, device, rank_info); + quantization_method, use_bias, dtype_, device_, rank_info); o_proj_ = this->register_module( "o_proj", total_num_heads * head_dim_, hidden_size_, quantization_method, - use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + use_output_bias, dtype_, device_, tp_rank, tp_size, rank_info.comm); - rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device); + rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device_); float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device_); - init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); + init_kv_cache_quant_params(register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_); + + rank_qkv_output_size_ = qkv_proj_->out_features() / static_cast(tp_size); + enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor Attention::forward(const infinicore::Tensor &positions, @@ -62,7 +72,14 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position size_t seq_len = shape[1]; // 1. 
Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + infinicore::Tensor q, k, v; + if (enable_workspace_manager_) { + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto qkv_output = workspace_manager.get_buffer("Attention_qkv_output", {batch_size, seq_len, rank_qkv_output_size_}, dtype_, device_); + std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + } else { + std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); + } // 2. Reshape for multi-head attention auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_}); @@ -89,9 +106,14 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position // 5. Attn Backend calculate auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped); - // 7. Project output - auto output = o_proj_->forward(attn_output); - return output; + // 6. Project output + if (enable_workspace_manager_) { + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto o_output = workspace_manager.get_buffer("Attention_o_output", {batch_size, seq_len, hidden_size_}, dtype_, device_); + o_proj_->forward_(o_output, attn_output); + return o_output; + } + return o_proj_->forward(attn_output); } infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ids, @@ -106,7 +128,14 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ ASSERT_EQ(batch_size, 1); // 1. 
Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + infinicore::Tensor q, k, v; + if (enable_workspace_manager_) { + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto qkv_output = workspace_manager.get_buffer("Attention_qkv_output", {1, seq_len, rank_qkv_output_size_}, dtype_, device_); + std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + } else { + std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); + } // 2. Reshape for multi-head attention auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); @@ -133,8 +162,36 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); // 6. Project output - auto output = o_proj_->forward(attn_output); - return output; + if (enable_workspace_manager_) { + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto o_output = workspace_manager.get_buffer("Attention_o_output", {1, seq_len, hidden_size_}, dtype_, device_); + o_proj_->forward_(o_output, attn_output); + return o_output; + } + return o_proj_->forward(attn_output); +} + +void Attention::_register_inference_buffer() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + ASSERT(rank_qkv_output_size_ > 0 && hidden_size_ > 0); + + const std::string attention_cache_key = std::string("Attention_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_rank_qkv_output_size_" + + std::to_string(rank_qkv_output_size_) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + const size_t max_output_size = 
std::max(rank_qkv_output_size_, hidden_size_); + const infinicore::Shape attention_buffer_shape = {max_num_batched_tokens * max_output_size}; + workspace_manager.register_buffer( + attention_cache_key, + attention_buffer_shape, + dtype_, + device_); } void init_kv_cache_quant_params(std::function register_fn, diff --git a/csrc/layers/attention/attention.hpp b/csrc/layers/attention/attention.hpp index 31f0d1fa..7814f730 100644 --- a/csrc/layers/attention/attention.hpp +++ b/csrc/layers/attention/attention.hpp @@ -37,6 +37,8 @@ class Attention : public infinicore::nn::Module { infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, const infinicore::Tensor &hidden_states) const; + void _register_inference_buffer(); + protected: std::shared_ptr qkv_proj_; std::shared_ptr o_proj_; @@ -49,13 +51,19 @@ class Attention : public infinicore::nn::Module { size_t num_key_value_heads_; size_t hidden_size_; size_t head_dim_; + infinicore::Device device_; + infinicore::DataType dtype_; // For off-line kv cache quantization INFINICORE_NN_PARAMETER(kv_cache_k_scale); INFINICORE_NN_PARAMETER(kv_cache_v_scale); + +private: + bool enable_workspace_manager_{false}; + size_t rank_qkv_output_size_{0}; }; void init_kv_cache_quant_params(std::function register_fn, - const infinicore::Device &device, - infinicore::nn::Parameter &kv_cache_k_scale, - infinicore::nn::Parameter &kv_cache_v_scale); + const infinicore::Device &device, + infinicore::nn::Parameter &kv_cache_k_scale, + infinicore::nn::Parameter &kv_cache_v_scale); } // namespace infinilm::layers::attention diff --git a/csrc/layers/attention/backends/attention_layer.cpp b/csrc/layers/attention/backends/attention_layer.cpp index fcaefa29..e5e39c10 100644 --- a/csrc/layers/attention/backends/attention_layer.cpp +++ b/csrc/layers/attention/backends/attention_layer.cpp @@ -9,16 +9,17 @@ AttentionLayer::AttentionLayer(size_t num_heads, size_t layer_idx, infinicore::Tensor k_scale, infinicore::Tensor v_scale, - 
::infinilm::backends::AttentionBackend attn_backend) : k_scale_(k_scale), v_scale_(v_scale), layer_idx_(layer_idx), attn_backend_(attn_backend) { + ::infinilm::backends::AttentionBackend attn_backend, + const infinicore::Device &device) : k_scale_(k_scale), v_scale_(v_scale), layer_idx_(layer_idx), attn_backend_(attn_backend) { switch (attn_backend) { case ::infinilm::backends::AttentionBackend::STATIC_ATTN: - attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx, device); break; case ::infinilm::backends::AttentionBackend::PAGED_ATTN: - attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx, device); break; case ::infinilm::backends::AttentionBackend::FLASH_ATTN: - attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx, device); break; default: throw std::runtime_error("infinilm::layers::attention::AttentionLayer: unsupported attention backend"); diff --git a/csrc/layers/attention/backends/attention_layer.hpp b/csrc/layers/attention/backends/attention_layer.hpp index 87411062..d83a7979 100644 --- a/csrc/layers/attention/backends/attention_layer.hpp +++ b/csrc/layers/attention/backends/attention_layer.hpp @@ -31,7 +31,8 @@ class AttentionLayer { size_t layer_idx, infinicore::Tensor k_scale, infinicore::Tensor v_scale, - ::infinilm::backends::AttentionBackend attention_backend); + ::infinilm::backends::AttentionBackend attention_backend, + const infinicore::Device &device); infinicore::Tensor forward(infinicore::Tensor &query, infinicore::Tensor &key, diff --git a/csrc/layers/attention/backends/flash_attn.cpp b/csrc/layers/attention/backends/flash_attn.cpp index 
ec7e3772..14b94b54 100644 --- a/csrc/layers/attention/backends/flash_attn.cpp +++ b/csrc/layers/attention/backends/flash_attn.cpp @@ -1,9 +1,11 @@ #include "flash_attn.hpp" +#include "../../../global_state/global_state.hpp" #include "../../../utils.hpp" #include "infinicore/ops.hpp" #include "infinicore/ops/mha_kvcache.hpp" #include "infinicore/ops/mha_varlen.hpp" +#include namespace infinilm::layers::attention::backends { @@ -11,19 +13,29 @@ FlashAttentionImpl::FlashAttentionImpl(size_t num_heads, size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx) + size_t layer_idx, + const infinicore::Device &device) : num_heads_(num_heads), head_size_(head_size), scale_(scale), num_kv_heads_(num_kv_heads), layer_idx_(layer_idx), - head_dim_(head_size) { + head_dim_(head_size), + device_(device) { const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config(); if (!infinilm_config.model_config) { throw std::runtime_error("infinilm::layers::attention::backends::FlashAttentionImpl: model_config is null"); } - max_position_embeddings_ = infinilm_config.model_config->get("max_position_embeddings"); + + const auto &model_config = infinilm_config.model_config; + dtype_ = model_config->get_dtype(); + max_position_embeddings_ = model_config->get("max_position_embeddings"); + + enable_workspace_manager_ = infinilm_config.enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer, @@ -48,8 +60,14 @@ infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer, bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); // 2. 
Compute attention - infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + infinicore::Tensor attn_output; if (is_prefill) { + if (enable_workspace_manager_) { + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + attn_output = workspace_manager.get_buffer("FlashAttention_attn_output", {seq_len, num_heads_, head_dim_}, dtype_, device_); + } else { + attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, dtype_, device_); + } infinicore::op::mha_varlen_( attn_output, query, @@ -99,4 +117,23 @@ std::tuple FlashAttentionImpl::do_kv_cac return {k_cache_layer, v_cache_layer}; } +void FlashAttentionImpl::_register_inference_buffer() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string cache_key = std::string("FlashAttentionImpl_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_num_heads_" + + std::to_string(num_heads_) + "_head_dim_" + + std::to_string(head_dim_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + const infinicore::Shape flash_attn_buffer_shape = {max_num_batched_tokens, num_heads_, head_dim_}; + workspace_manager.register_buffer( + cache_key, + flash_attn_buffer_shape, + dtype_, + device_); +} } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/flash_attn.hpp b/csrc/layers/attention/backends/flash_attn.hpp index 93f61e8b..1f513015 100644 --- a/csrc/layers/attention/backends/flash_attn.hpp +++ b/csrc/layers/attention/backends/flash_attn.hpp @@ -16,7 +16,8 @@ class FlashAttentionImpl { size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx); + size_t layer_idx, + const infinicore::Device 
&device); /** * @brief Forward pass with FlashAttention. @@ -43,6 +44,9 @@ class FlashAttentionImpl { const infinicore::Tensor slot_mapping) const; private: + void _register_inference_buffer(); + bool enable_workspace_manager_{false}; + size_t num_heads_; size_t head_size_; float scale_; @@ -50,5 +54,8 @@ class FlashAttentionImpl { size_t layer_idx_; size_t head_dim_; // Note: head_dim equals to head_size size_t max_position_embeddings_; + infinicore::Device device_; + infinicore::DataType dtype_; }; + } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.cpp b/csrc/layers/attention/backends/paged_attn.cpp index a0ad70af..7d6eec5d 100644 --- a/csrc/layers/attention/backends/paged_attn.cpp +++ b/csrc/layers/attention/backends/paged_attn.cpp @@ -1,21 +1,39 @@ #include "paged_attn.hpp" +#include "../../../global_state/global_state.hpp" #include "../../../utils.hpp" #include "infinicore/ops.hpp" +#include + namespace infinilm::layers::attention::backends { PagedAttentionImpl::PagedAttentionImpl(size_t num_heads, size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx) + size_t layer_idx, + const infinicore::Device &device) : num_heads_(num_heads), head_size_(head_size), scale_(scale), num_kv_heads_(num_kv_heads), layer_idx_(layer_idx), - head_dim_(head_size) {} + head_dim_(head_size), + device_(device) { + + const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config(); + if (!infinilm_config.model_config) { + throw std::runtime_error("infinilm::layers::attention::backends::PagedAttentionImpl: model_config is null"); + } + + dtype_ = infinilm_config.model_config->get_dtype(); + + enable_workspace_manager_ = infinilm_config.enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } +} infinicore::Tensor PagedAttentionImpl::forward(const AttentionLayer &layer, const infinicore::Tensor &query, @@ -37,7 +55,13 @@ 
infinicore::Tensor PagedAttentionImpl::forward(const AttentionLayer &layer, bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); // 2. Compute attention - infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + infinicore::Tensor attn_output; + if (enable_workspace_manager_) { + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + attn_output = workspace_manager.get_buffer("PagedAttention_attn_output", {seq_len, num_heads_, head_dim_}, dtype_, device_); + } else { + attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, dtype_, device_); + } if (is_prefill) { infinicore::op::paged_attention_prefill_( attn_output, @@ -80,4 +104,25 @@ std::tuple PagedAttentionImpl::do_kv_cac return {k_cache_layer, v_cache_layer}; } + +void PagedAttentionImpl::_register_inference_buffer() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + const std::string cache_key = std::string("PagedAttentionImpl_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_num_heads_" + + std::to_string(num_heads_) + "_head_dim_" + + std::to_string(head_dim_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + const infinicore::Shape paged_attn_buffer_shape = {max_num_batched_tokens, num_heads_, head_dim_}; + workspace_manager.register_buffer( + cache_key, + paged_attn_buffer_shape, + dtype_, + device_); +} + } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.hpp b/csrc/layers/attention/backends/paged_attn.hpp index 4f53ea57..e1ab7688 100644 --- a/csrc/layers/attention/backends/paged_attn.hpp +++ b/csrc/layers/attention/backends/paged_attn.hpp 
@@ -16,7 +16,8 @@ class PagedAttentionImpl { size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx); + size_t layer_idx, + const infinicore::Device &device); /** * @brief Forward pass with PagedAttention. @@ -43,11 +44,16 @@ class PagedAttentionImpl { const infinicore::Tensor slot_mapping) const; private: + void _register_inference_buffer(); + bool enable_workspace_manager_{false}; + size_t num_heads_; size_t head_size_; float scale_; size_t num_kv_heads_; size_t layer_idx_; size_t head_dim_; // Note: head_dim equals to head_size + infinicore::Device device_; + infinicore::DataType dtype_; }; } // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/static_attn.cpp b/csrc/layers/attention/backends/static_attn.cpp index 2d1b7e11..668f4c21 100644 --- a/csrc/layers/attention/backends/static_attn.cpp +++ b/csrc/layers/attention/backends/static_attn.cpp @@ -11,7 +11,8 @@ StaticAttentionImpl::StaticAttentionImpl(size_t num_heads, size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx) + size_t layer_idx, + const infinicore::Device & /*device*/) : num_heads_(num_heads), head_size_(head_size), scale_(scale), diff --git a/csrc/layers/attention/backends/static_attn.hpp b/csrc/layers/attention/backends/static_attn.hpp index 849d8792..00af4391 100644 --- a/csrc/layers/attention/backends/static_attn.hpp +++ b/csrc/layers/attention/backends/static_attn.hpp @@ -18,7 +18,8 @@ class StaticAttentionImpl { size_t head_size, float scale, size_t num_kv_heads, - size_t layer_idx); + size_t layer_idx, + const infinicore::Device &device); infinicore::Tensor forward(const AttentionLayer &layer, infinicore::Tensor &q_reshaped, // query diff --git a/csrc/layers/causal_lm_templates/text_causal_lm.hpp b/csrc/layers/causal_lm_templates/text_causal_lm.hpp index eb4f2b47..86ce8b1c 100644 --- a/csrc/layers/causal_lm_templates/text_causal_lm.hpp +++ b/csrc/layers/causal_lm_templates/text_causal_lm.hpp @@ -1,8 +1,12 @@ #pragma 
once +#include "../../global_state/global_state.hpp" #include "../../models/infinilm_model.hpp" +#include "../../utils.hpp" #include "../linear/linear.hpp" #include "infinicore/device.hpp" +#include "infinicore/dtype.hpp" +#include namespace infinilm::layers::causal_lm_templates { @@ -28,15 +32,21 @@ class TextCausalLM : public InfinilmModel { * @param device: Device to create tensors on */ TextCausalLM(std::shared_ptr model_config, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { model_config_ = model_config; size_t hidden_size = model_config->get("hidden_size"); - size_t vocab_size = model_config->get("vocab_size"); - const auto &dtype{model_config->get_dtype()}; + vocab_size_ = model_config->get("vocab_size"); model_ = this->register_module("model", model_config, device); - lm_head_ = this->register_module("lm_head", hidden_size, vocab_size, false, dtype, device); + lm_head_ = this->register_module("lm_head", hidden_size, vocab_size_, false, dtype_, device_); + + enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } /** @@ -44,7 +54,19 @@ class TextCausalLM : public InfinilmModel { */ Output forward(const Input &input) const override { auto hidden_states = model_->forward(input); - auto logits = lm_head_->forward(hidden_states); + infinicore::Tensor logits; + + if (enable_workspace_manager_) { + const auto shape = hidden_states->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + logits = workspace_manager.get_buffer("TextCausalLM_logits", 0, {bs, seq_len, vocab_size_}, dtype_, device_); + lm_head_->forward_(logits, hidden_states); + } else { + logits = lm_head_->forward(hidden_states); + } + return {logits}; } @@ -55,8 +77,36 @@ class 
TextCausalLM : public InfinilmModel { Model &model() { return *model_; } protected: + size_t vocab_size_{0}; + infinicore::Device device_; + infinicore::DataType dtype_; + INFINICORE_NN_MODULE(Model, model); INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); + +private: + void _register_inference_buffer() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + ASSERT(vocab_size_ > 0); + + const std::string text_causal_lm_cache_key = std::string("TextCausalLM_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_vocab_size_" + + std::to_string(vocab_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + const infinicore::Shape logits_shape = {max_num_batched_tokens, vocab_size_}; + workspace_manager.register_buffer( + text_causal_lm_cache_key, + 0, + logits_shape, + dtype_, + device_); + } + + bool enable_workspace_manager_{false}; }; } // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/causal_lm_templates/text_model.hpp b/csrc/layers/causal_lm_templates/text_model.hpp index 62a52798..376a7f0c 100644 --- a/csrc/layers/causal_lm_templates/text_model.hpp +++ b/csrc/layers/causal_lm_templates/text_model.hpp @@ -1,11 +1,14 @@ #pragma once #include "../../config/model_config.hpp" +#include "../../global_state/global_state.hpp" #include "../../models/infinilm_model.hpp" +#include "../../utils.hpp" #include "infinicore/nn/embedding.hpp" #include "infinicore/nn/rmsnorm.hpp" #include "infinicore/tensor.hpp" #include +#include #include namespace infinilm::layers::causal_lm_templates { @@ -24,30 +27,46 @@ template class TextModel : public infinicore::nn::Module { public: TextModel(std::shared_ptr model_config, - const infinicore::Device &device) { - const auto 
&dtype{model_config->get_dtype()}; - size_t vocab_size = model_config->get("vocab_size"); - size_t hidden_size = model_config->get("hidden_size"); + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { + vocab_size_ = model_config->get("vocab_size"); + hidden_size_ = model_config->get("hidden_size"); size_t max_position_embeddings = model_config->get("max_position_embeddings"); size_t num_hidden_layers = model_config->get("num_hidden_layers"); double rope_theta = model_config->get("rope_theta"); double rms_norm_eps = model_config->get("rms_norm_eps"); - embed_tokens_ = this->register_module("embed_tokens", vocab_size, hidden_size, std::nullopt, dtype, device); + embed_tokens_ = this->register_module("embed_tokens", vocab_size_, hidden_size_, std::nullopt, dtype_, device_); layers_.reserve(num_hidden_layers); for (size_t i = 0; i < num_hidden_layers; ++i) { - layers_.push_back(this->register_module("layers." + std::to_string(i), model_config, i, device)); + layers_.push_back(this->register_module("layers." + std::to_string(i), model_config, i, device_)); } - norm_ = this->register_module("norm", hidden_size, rms_norm_eps, dtype, device); + norm_ = this->register_module("norm", hidden_size_, rms_norm_eps, dtype_, device_); + + enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input) const { auto input_ids = input.input_ids.value(); auto positions = input.position_ids.value(); // 1. 
Embed tokens: input_ids -> [batch, seq_len, hidden_size] - auto hidden_states = embed_tokens_->forward(input_ids); + infinicore::Tensor hidden_states; + if (enable_workspace_manager_) { + const auto shape = input_ids->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + hidden_states = workspace_manager.get_buffer("TextModel_hidden_states", {bs, seq_len, hidden_size_}, dtype_, device_); + embed_tokens_->forward_(hidden_states, input_ids); + } else { + hidden_states = embed_tokens_->forward(input_ids); + } // 2. Process through all decoder layers size_t num_layers = layers_.size(); @@ -64,6 +83,7 @@ class TextModel : public infinicore::nn::Module { } infinicore::Tensor forward_naive(const infinilm::InfinilmModel::Input &input) const { + // Don't use preallocated workspace in forward_naive function. auto input_ids = input.input_ids.value(); auto positions = input.position_ids.value(); auto hidden_states = embed_tokens_->forward(input_ids); @@ -78,6 +98,7 @@ class TextModel : public infinicore::nn::Module { infinicore::Tensor forward_embeds(const infinicore::Tensor &inputs_embeds, const infinicore::Tensor &position_ids) const { + // Don't use preallocated workspace in forward_embeds function. 
auto hidden_states = inputs_embeds; // Process through all decoder layers @@ -98,10 +119,39 @@ class TextModel : public infinicore::nn::Module { return embed_tokens_->forward(input_ids); } +private: + void _register_inference_buffer() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + ASSERT((hidden_size_ > 0) && (vocab_size_ > 0)); + + const std::string text_model_cache_key = std::string("TextModel_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + const infinicore::Shape hidden_states_shape = {max_num_batched_tokens, hidden_size_}; + workspace_manager.register_buffer( + text_model_cache_key, + hidden_states_shape, + dtype_, + device_); + } + protected: + size_t vocab_size_{0}; + size_t hidden_size_{0}; + infinicore::Device device_; + infinicore::DataType dtype_; + INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); INFINICORE_NN_MODULE_VEC(DecoderLayer, layers); INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm); + +private: + bool enable_workspace_manager_{false}; }; } // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/linear/base_linear.cpp b/csrc/layers/linear/base_linear.cpp index eebc482c..25b29260 100644 --- a/csrc/layers/linear/base_linear.cpp +++ b/csrc/layers/linear/base_linear.cpp @@ -43,10 +43,25 @@ infinicore::Tensor BaseLinear::compute_linear(infinicore::Tensor &input) const { return quantization_->forward(params, input, has_bias_, alpha_); } +void BaseLinear::compute_linear_(infinicore::Tensor &output, infinicore::Tensor &input) const { + // Build params map from direct parameters only (not state_dict which uses a + // static local and is not thread-safe 
across RankWorker threads). + infinilm::quantization::ParamsMap params; + for (const auto &[name, param] : parameters_) { + params[name] = static_cast(param); + } + + quantization_->forward_(output, params, input, has_bias_, alpha_); +} + infinicore::Tensor BaseLinear::forward(infinicore::Tensor &input) const { return compute_linear(input); } +void BaseLinear::forward_(infinicore::Tensor &output, infinicore::Tensor &input) const { + compute_linear_(output, input); +} + infinicore::Tensor BaseLinear::forward(infinicore::Tensor &input, infinicore::Tensor &residual) const { auto output = compute_linear(input); infinicore::op::add_(output, output, residual); @@ -60,7 +75,9 @@ void BaseLinear::process_weights_after_loading() { } auto new_quant = quantization_->process_weights_after_loading(params, device_); - if (!new_quant) return; + if (!new_quant) { + return; + } for (auto &[name, param] : parameters_) { param = infinicore::nn::Parameter(); @@ -68,7 +85,9 @@ void BaseLinear::process_weights_after_loading() { for (const auto &[name, tensor] : params) { auto it = parameters_.find(name); - if (it == parameters_.end()) continue; + if (it == parameters_.end()) { + continue; + } it->second = infinicore::nn::Parameter(tensor); } @@ -79,43 +98,61 @@ void BaseLinear::process_weights_after_loading() { infinicore::Tensor BaseLinear::weight() const { auto it = parameters_.find("weight"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } it = parameters_.find("qweight"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::bias() const { auto it = parameters_.find("bias"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::weight_scale() const { auto it = parameters_.find("weight_scale"); - 
if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } it = parameters_.find("scales"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::weight_zeros() const { auto it = parameters_.find("weight_zeros"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } it = parameters_.find("qzeros"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::gidx() const { auto it = parameters_.find("g_idx"); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } infinicore::Tensor BaseLinear::get_param(const std::string &name) const { auto it = parameters_.find(name); - if (it != parameters_.end()) return it->second; + if (it != parameters_.end()) { + return it->second; + } return infinicore::Tensor(); } diff --git a/csrc/layers/linear/base_linear.hpp b/csrc/layers/linear/base_linear.hpp index a304f457..52ac27c4 100644 --- a/csrc/layers/linear/base_linear.hpp +++ b/csrc/layers/linear/base_linear.hpp @@ -1,8 +1,8 @@ #pragma once -#include "infinicore/ops.hpp" #include "../quantization/quantization.hpp" #include "infinicore/nn/module.hpp" +#include "infinicore/ops.hpp" #include #include @@ -23,6 +23,8 @@ class BaseLinear : public infinicore::nn::Module { // Forward pass: output = input @ weight.T + bias infinicore::Tensor forward(infinicore::Tensor &input) const; + void forward_(infinicore::Tensor &output, infinicore::Tensor &input) const; + // Forward pass with residual connection infinicore::Tensor forward(infinicore::Tensor &input, infinicore::Tensor &residual) const; @@ -57,6 +59,7 @@ class BaseLinear : public infinicore::nn::Module { protected: 
infinicore::Tensor compute_linear(infinicore::Tensor &input) const; + void compute_linear_(infinicore::Tensor &output, infinicore::Tensor &input) const; size_t in_features_; size_t out_features_; diff --git a/csrc/layers/linear/fused_linear.cpp b/csrc/layers/linear/fused_linear.cpp index 2c052090..dd0da59d 100644 --- a/csrc/layers/linear/fused_linear.cpp +++ b/csrc/layers/linear/fused_linear.cpp @@ -70,6 +70,17 @@ QKVParallelLinear::forward_split(infinicore::Tensor &input) { return std::make_tuple(q_out, k_out, v_out); } +std::tuple +QKVParallelLinear::forward_split_(infinicore::Tensor &output, infinicore::Tensor &input) { + this->forward_(output, input); + + auto q_out = output->narrow({{2, 0, q_out_size_}}); + auto k_out = output->narrow({{2, q_out_size_, k_out_size_}}); + auto v_out = output->narrow({{2, q_out_size_ + k_out_size_, v_out_size_}}); + + return std::make_tuple(q_out, k_out, v_out); +} + bool QKVParallelLinear::has_q_bias() const { return q_bias_; } bool QKVParallelLinear::has_k_bias() const { return k_bias_; } bool QKVParallelLinear::has_v_bias() const { return v_bias_; } @@ -144,6 +155,15 @@ std::tuple GateUpParallelLinear::forward return std::make_tuple(gate_output, up_output); } +std::tuple +GateUpParallelLinear::forward_split_(infinicore::Tensor &output, infinicore::Tensor &input) { + this->forward_(output, input); + auto cols = output->shape()[2]; + auto gate_output = output->narrow({{2, 0, cols / 2}}); + auto up_output = output->narrow({{2, cols / 2, cols / 2}}); + return std::make_tuple(gate_output, up_output); +} + bool GateUpParallelLinear::has_gate_bias() const { return gate_bias_; } bool GateUpParallelLinear::has_up_bias() const { return up_bias_; } diff --git a/csrc/layers/linear/fused_linear.hpp b/csrc/layers/linear/fused_linear.hpp index 6e4a3485..b85aa3dd 100644 --- a/csrc/layers/linear/fused_linear.hpp +++ b/csrc/layers/linear/fused_linear.hpp @@ -43,6 +43,9 @@ class QKVParallelLinear : public infinilm::nn::ColumnParallelLinear { 
std::tuple forward_split(infinicore::Tensor &input); + std::tuple + forward_split_(infinicore::Tensor &output, infinicore::Tensor &input); + bool has_q_bias() const; bool has_k_bias() const; bool has_v_bias() const; @@ -109,6 +112,9 @@ class GateUpParallelLinear : public infinilm::nn::ColumnParallelLinear { std::tuple forward_split(infinicore::Tensor &input); + std::tuple + forward_split_(infinicore::Tensor &output, infinicore::Tensor &input); + bool has_gate_bias() const; bool has_up_bias() const; diff --git a/csrc/layers/linear/linear.cpp b/csrc/layers/linear/linear.cpp index 84982409..f73abbb2 100644 --- a/csrc/layers/linear/linear.cpp +++ b/csrc/layers/linear/linear.cpp @@ -93,6 +93,14 @@ infinicore::Tensor RowParallelLinear::forward(infinicore::Tensor &input) const { return output; } +void RowParallelLinear::forward_(infinicore::Tensor &output, infinicore::Tensor &input) const { + BaseLinear::forward_(output, input); + + if ((tp_size_ > 1) && (communicator_ != nullptr)) { + infinicore::op::distributed::allreduce_(output, output, INFINICCL_SUM, communicator_); + } +} + std::string RowParallelLinear::extra_repr() const { return "RowParallelLinear(in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? 
"true" : "false") + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; } diff --git a/csrc/layers/linear/linear.hpp b/csrc/layers/linear/linear.hpp index 566cee77..abae08f6 100644 --- a/csrc/layers/linear/linear.hpp +++ b/csrc/layers/linear/linear.hpp @@ -70,6 +70,9 @@ class RowParallelLinear : public BaseLinear { infinicclComm_t communicator = nullptr); infinicore::Tensor forward(infinicore::Tensor &input) const; + + void forward_(infinicore::Tensor &output, infinicore::Tensor &input) const; + std::string extra_repr() const; protected: diff --git a/csrc/layers/mlp/mlp.cpp b/csrc/layers/mlp/mlp.cpp index f7604c50..095bd62a 100644 --- a/csrc/layers/mlp/mlp.cpp +++ b/csrc/layers/mlp/mlp.cpp @@ -1,13 +1,16 @@ #include "mlp.hpp" #include "../../global_state/global_state.hpp" +#include "../../utils.hpp" #include "infinicore/ops.hpp" +#include namespace infinilm::layers::mlp { MLP::MLP(std::shared_ptr model_config, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { - const auto &dtype{model_config->get_dtype()}; hidden_size_ = model_config->get("hidden_size"); intermediate_size_ = model_config->get("intermediate_size"); use_bias_ = model_config->get_or("mlp_bias", false); @@ -20,13 +23,25 @@ MLP::MLP(std::shared_ptr model_config, auto register_fn = [this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }; gate_up_proj_ = std::make_shared( hidden_size_, intermediate_size_, "gate_proj", "up_proj", register_fn, - quantization_method, use_bias_, dtype, device, rank_info); + quantization_method, use_bias_, dtype_, device_, rank_info); down_proj_ = this->register_module( "down_proj", intermediate_size_, hidden_size_, quantization_method, - use_bias_, dtype, device, tp_rank, tp_size, rank_info.comm); + use_bias_, dtype_, device_, tp_rank, tp_size, rank_info.comm); + + rank_gate_up_output_size_ = gate_up_proj_->out_features() / 
static_cast(tp_size); + rank_intermediate_size_ = rank_gate_up_output_size_ / 2; + + enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor MLP::forward(const infinicore::Tensor &hidden_states) const { + if (enable_workspace_manager_) { + return this->_forward_with_inference_buffer(hidden_states); + } + // 1. Project to gate and up auto hidden_states_mutable = hidden_states; auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable); @@ -36,4 +51,59 @@ infinicore::Tensor MLP::forward(const infinicore::Tensor &hidden_states) const { auto output = down_proj_->forward(intermediate); return output; } + +infinicore::Tensor MLP::_forward_with_inference_buffer(const infinicore::Tensor &hidden_states) const { + + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + + const auto shape = hidden_states->shape(); + const size_t bs = shape[0]; + const size_t seq_len = shape[1]; + + auto hidden_states_mutable = hidden_states; + // 1. Project to gate and up + auto gate_up_output = workspace_manager.get_buffer("MLP_gate_up_output", {bs, seq_len, rank_gate_up_output_size_}, dtype_, device_); + auto [gate, up] = gate_up_proj_->forward_split_(gate_up_output, hidden_states_mutable); + + // 2. Apply SwiGLU: silu(gate) * up + auto intermediate = workspace_manager.get_buffer("MLP_intermediate", {bs, seq_len, rank_intermediate_size_}, dtype_, device_); + infinicore::op::swiglu_(intermediate, up, gate); + + // 3. 
Project down + auto down_output = workspace_manager.get_buffer("MLP_down_output", {bs, seq_len, hidden_size_}, dtype_, device_); + down_proj_->forward_(down_output, intermediate); + return down_output; +} + +void MLP::_register_inference_buffer() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + ASSERT(rank_gate_up_output_size_ > 0 && rank_intermediate_size_ > 0 && hidden_size_ > 0 && intermediate_size_ > 0); + + const std::string mlp_cache_key = std::string("MLP_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_rank_gate_up_output_size_" + + std::to_string(rank_gate_up_output_size_) + "_rank_intermediate_size_" + + std::to_string(rank_intermediate_size_) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + auto align_up = [](size_t n, size_t alignment = 512) { + return (n + alignment - 1) & ~(alignment - 1); + }; + + const size_t rank_gate_up_output_size_aligned = align_up(rank_gate_up_output_size_); + const size_t rank_intermediate_size_aligned = align_up(rank_gate_up_output_size_aligned + rank_intermediate_size_); + const size_t max_output_size = rank_intermediate_size_aligned + hidden_size_; + + const infinicore::Shape mlp_buffer_shape = {max_num_batched_tokens * max_output_size}; + workspace_manager.register_buffer( + mlp_cache_key, + mlp_buffer_shape, + dtype_, + device_); +} + } // namespace infinilm::layers::mlp diff --git a/csrc/layers/mlp/mlp.hpp b/csrc/layers/mlp/mlp.hpp index 91349fe9..fb15c9ec 100644 --- a/csrc/layers/mlp/mlp.hpp +++ b/csrc/layers/mlp/mlp.hpp @@ -3,6 +3,7 @@ #include "../../config/model_config.hpp" #include "../linear/linear.hpp" #include "infinicore/nn/module.hpp" +#include "infinicore/tensor.hpp" namespace 
infinilm::layers::mlp { @@ -51,6 +52,17 @@ class MLP : public infinicore::nn::Module { size_t hidden_size_; size_t intermediate_size_; bool use_bias_; + infinicore::Device device_; + infinicore::DataType dtype_; + +private: + infinicore::Tensor _forward_with_inference_buffer(const infinicore::Tensor &hidden_states) const; + + void _register_inference_buffer(); + + bool enable_workspace_manager_{false}; + size_t rank_gate_up_output_size_{0}; + size_t rank_intermediate_size_{0}; }; } // namespace infinilm::layers::mlp diff --git a/csrc/layers/quantization/awq.cpp b/csrc/layers/quantization/awq.cpp index 50e830f4..1c07c6db 100644 --- a/csrc/layers/quantization/awq.cpp +++ b/csrc/layers/quantization/awq.cpp @@ -52,6 +52,26 @@ infinicore::Tensor AWQ::forward( return infinicore::op::linear_w4a16_awq(input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt); } +void AWQ::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float /*alpha*/) const { + + auto input_contiguous = input->is_contiguous() ? 
input : input->contiguous(); + auto qweight = params.at("qweight"); + auto scales = params.at("scales"); + auto qzeros = params.at("qzeros"); + + std::optional bias_opt; + if (has_bias) { + bias_opt = params.at("bias"); + } + + infinicore::op::linear_w4a16_awq_(output, input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt); +} + std::vector AWQ::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/awq.hpp b/csrc/layers/quantization/awq.hpp index 383e574a..797092cb 100644 --- a/csrc/layers/quantization/awq.hpp +++ b/csrc/layers/quantization/awq.hpp @@ -38,6 +38,13 @@ class AWQ : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/base_quantization.hpp b/csrc/layers/quantization/base_quantization.hpp index 1fd261bd..0ed82cb3 100644 --- a/csrc/layers/quantization/base_quantization.hpp +++ b/csrc/layers/quantization/base_quantization.hpp @@ -59,6 +59,14 @@ class BaseQuantization : public std::enable_shared_from_this { bool has_bias, float alpha = 1.0f) const = 0; + // In-place forward pass. + virtual void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const = 0; + // Dimension for fused-split (gate/up, q/k/v) of a column-parallel weight. // For NoneQuantization weight [out, in], split is on dim0. // For AWQ qweight [in, out/pack], split is on dim1. 
diff --git a/csrc/layers/quantization/compressed_tensors.cpp b/csrc/layers/quantization/compressed_tensors.cpp index 66a4a3ef..45c46fb3 100644 --- a/csrc/layers/quantization/compressed_tensors.cpp +++ b/csrc/layers/quantization/compressed_tensors.cpp @@ -43,6 +43,25 @@ infinicore::Tensor CompressedTensors::forward( return infinicore::op::linear_w8a8i8(input_contiguous->contiguous(), weight, weight_scale, bias_opt); } +void CompressedTensors::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float /*alpha*/) const { + + auto input_contiguous = input->is_contiguous() ? input : input->contiguous(); + auto weight = params.at("weight"); + auto weight_scale = params.at("weight_scale"); + + std::optional bias_opt; + if (has_bias) { + bias_opt = params.at("bias"); + } + + infinicore::op::linear_w8a8i8_(output, input_contiguous->contiguous(), weight, weight_scale, bias_opt); +} + std::vector CompressedTensors::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/compressed_tensors.hpp b/csrc/layers/quantization/compressed_tensors.hpp index dcf65c2e..2bac396a 100644 --- a/csrc/layers/quantization/compressed_tensors.hpp +++ b/csrc/layers/quantization/compressed_tensors.hpp @@ -25,6 +25,13 @@ class CompressedTensors : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/gptq.cpp b/csrc/layers/quantization/gptq.cpp index e7688be5..972aa130 100644 --- a/csrc/layers/quantization/gptq.cpp +++ b/csrc/layers/quantization/gptq.cpp @@ -36,6 +36,16 @@ infinicore::Tensor GPTQ::forward( "Call process_weights_after_loading() first."); } +void GPTQ::forward_( + 
infinicore::Tensor & /*output*/, + const ParamsMap & /*params*/, + const infinicore::Tensor & /*input*/, + bool /*has_bias*/, + float /*alpha*/) const { + throw std::runtime_error("GPTQ_W4A16 must be converted to GPTQ_QY before forward pass. " + "Call process_weights_after_loading() first."); +} + std::shared_ptr GPTQ::process_weights_after_loading( ParamsMap ¶ms, const infinicore::Device &device) const { diff --git a/csrc/layers/quantization/gptq.hpp b/csrc/layers/quantization/gptq.hpp index 455dde2c..598be78f 100644 --- a/csrc/layers/quantization/gptq.hpp +++ b/csrc/layers/quantization/gptq.hpp @@ -34,6 +34,13 @@ class GPTQ : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/gptq_qy.cpp b/csrc/layers/quantization/gptq_qy.cpp index 4098e452..42ac9b26 100644 --- a/csrc/layers/quantization/gptq_qy.cpp +++ b/csrc/layers/quantization/gptq_qy.cpp @@ -48,6 +48,25 @@ infinicore::Tensor GPTQ_QY::forward( return output; } +void GPTQ_QY::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float /*alpha*/) const { + auto input_contiguous = input->is_contiguous() ? 
input : input->contiguous(); + auto qweight = params.at("qweight"); + auto qzeros = params.at("qzeros"); + auto scales = params.at("scales"); + + infinicore::op::linear_w4a16_gptq_qy_(output, input_contiguous->contiguous(), qweight, scales, qzeros, 0, 4); + + if (has_bias) { + auto bias = params.at("bias"); + infinicore::op::add_(output, output, bias->as_strided(output->shape(), {0, 0, 1})); + } +} + std::vector GPTQ_QY::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/gptq_qy.hpp b/csrc/layers/quantization/gptq_qy.hpp index 634b4aaf..22635fb4 100644 --- a/csrc/layers/quantization/gptq_qy.hpp +++ b/csrc/layers/quantization/gptq_qy.hpp @@ -112,6 +112,13 @@ class GPTQ_QY : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + // Split fused linear parameters into named sub-parameters std::vector split_params( const std::unordered_map ¶ms, diff --git a/csrc/layers/quantization/none_quantization.cpp b/csrc/layers/quantization/none_quantization.cpp index e1f67a7d..3525b4e1 100644 --- a/csrc/layers/quantization/none_quantization.cpp +++ b/csrc/layers/quantization/none_quantization.cpp @@ -38,6 +38,24 @@ infinicore::Tensor NoneQuantization::forward( return infinicore::op::linear(input_contiguous->contiguous(), weight->contiguous(), bias_opt, alpha); } +void NoneQuantization::forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha) const { + + auto input_contiguous = input->is_contiguous() ? 
input : input->contiguous(); + auto weight = params.at("weight"); + + std::optional bias_opt; + if (has_bias) { + bias_opt = params.at("bias"); + } + + infinicore::op::linear_(output, input_contiguous->contiguous(), weight->contiguous(), bias_opt, alpha); +} + std::vector NoneQuantization::split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/layers/quantization/none_quantization.hpp b/csrc/layers/quantization/none_quantization.hpp index 18fe4bf1..baee47c6 100644 --- a/csrc/layers/quantization/none_quantization.hpp +++ b/csrc/layers/quantization/none_quantization.hpp @@ -27,6 +27,13 @@ class NoneQuantization : public BaseQuantization { bool has_bias, float alpha = 1.0f) const override; + void forward_( + infinicore::Tensor &output, + const ParamsMap ¶ms, + const infinicore::Tensor &input, + bool has_bias, + float alpha = 1.0f) const override; + std::vector split_params( const std::unordered_map ¶ms, const std::vector &splits, diff --git a/csrc/models/glm4/glm4_attention.cpp b/csrc/models/glm4/glm4_attention.cpp index c5851cc2..579b0e89 100644 --- a/csrc/models/glm4/glm4_attention.cpp +++ b/csrc/models/glm4/glm4_attention.cpp @@ -57,7 +57,7 @@ Glm4Attention::Glm4Attention(std::shared_ptr mode attn_ = std::make_shared( num_attention_heads_, head_dim_, scaling_, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device); // KV Cache quantization scale initialization infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); diff --git a/csrc/models/minicpm_sala/minicpm_sala_attention.cpp b/csrc/models/minicpm_sala/minicpm_sala_attention.cpp index 60fe4ee4..0413f274 100644 --- a/csrc/models/minicpm_sala/minicpm_sala_attention.cpp +++ b/csrc/models/minicpm_sala/minicpm_sala_attention.cpp @@ -49,7 +49,7 @@ AttentionBase::AttentionBase(std::shared_ptr mode float scaling = 1.0f / 
std::sqrt(static_cast(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device); infinilm::layers::attention::init_kv_cache_quant_params([this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }, device, kv_cache_k_scale_, kv_cache_v_scale_); diff --git a/csrc/models/qwen3/qwen3_attention.cpp b/csrc/models/qwen3/qwen3_attention.cpp index bff9c8d7..745d877e 100644 --- a/csrc/models/qwen3/qwen3_attention.cpp +++ b/csrc/models/qwen3/qwen3_attention.cpp @@ -2,17 +2,20 @@ #include "../../global_state/global_state.hpp" #include "../../layers/attention/attention.hpp" #include "../../utils.hpp" +#include +#include namespace infinilm::models::qwen3 { Qwen3Attention::Qwen3Attention(std::shared_ptr model_config, size_t layer_idx, - const infinicore::Device &device) { + const infinicore::Device &device) + : device_(device), + dtype_(model_config->get_dtype()) { layer_idx_ = layer_idx; hidden_size_ = model_config->get("hidden_size"); head_dim_ = model_config->get("head_dim"); - const auto &dtype{model_config->get_dtype()}; size_t total_num_heads = model_config->get("num_attention_heads"); size_t total_num_kv_heads = model_config->get("num_key_value_heads"); bool use_bias = model_config->get_or("attention_bias", true); @@ -35,21 +38,27 @@ Qwen3Attention::Qwen3Attention(std::shared_ptr mo qkv_proj_ = std::make_shared( hidden_size_, head_dim_, total_num_heads, total_num_kv_heads, "q_proj", "k_proj", "v_proj", register_fn, - quantization_method, use_bias, dtype, device, rank_info); + quantization_method, use_bias, dtype_, device_, rank_info); o_proj_ = this->register_module( "o_proj", total_num_heads * head_dim_, hidden_size_, quantization_method, - use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + use_output_bias, dtype_, device_, 
tp_rank, tp_size, rank_info.comm); - rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device); + rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device_); float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device_); - INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device); - INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype_, device_); + INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype_, device_); - infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); + infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_); + + rank_qkv_output_size_ = qkv_proj_->out_features() / static_cast(tp_size); + enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager; + if (enable_workspace_manager_) { + this->_register_inference_buffer(); + } } infinicore::Tensor Qwen3Attention::forward(const infinicore::Tensor &positions, @@ -70,7 +79,14 @@ infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &pos size_t seq_len = shape[1]; // 1. 
Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + infinicore::Tensor q, k, v; + if (enable_workspace_manager_) { + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto qkv_output = workspace_manager.get_buffer("Qwen3Attention_qkv_output", {batch_size, seq_len, rank_qkv_output_size_}, dtype_, device_); + std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + } else { + std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); + } q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_})); k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_})); @@ -100,9 +116,14 @@ infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &pos // 6. Attn Backend calculate auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped); - // 7. Project output - auto output = o_proj_->forward(attn_output); - return output; + // 7. Project output + if (enable_workspace_manager_) { + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto o_output = workspace_manager.get_buffer("Qwen3Attention_o_output", {batch_size, seq_len, hidden_size_}, dtype_, device_); + o_proj_->forward_(o_output, attn_output); + return o_output; + } + return o_proj_->forward(attn_output); } infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &position_ids, @@ -118,7 +139,14 @@ infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &posi ASSERT_EQ(batch_size, 1); // 1. 
Project Q, K, V - auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + infinicore::Tensor q, k, v; + if (enable_workspace_manager_) { + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto qkv_output = workspace_manager.get_buffer("Qwen3Attention_qkv_output", {1, seq_len, rank_qkv_output_size_}, dtype_, device_); + std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable); + } else { + std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable); + } // 2. Reshape for multi-head attention auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); @@ -147,6 +175,36 @@ infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &posi auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); // 6. Project output + if (enable_workspace_manager_) { + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + auto o_output = workspace_manager.get_buffer("Qwen3Attention_o_output", {1, seq_len, hidden_size_}, dtype_, device_); + o_proj_->forward_(o_output, attn_output); + return o_output; + } return o_proj_->forward(attn_output); } + +void Qwen3Attention::_register_inference_buffer() { + const auto &infinilm_config = infinilm::global_state::get_infinilm_config(); + auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager; + const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens; + + ASSERT(rank_qkv_output_size_ > 0 && hidden_size_ > 0); + + const std::string attention_cache_key = std::string("Qwen3Attention_max_num_batched_tokens_") + + std::to_string(max_num_batched_tokens) + "_rank_qkv_output_size_" + + std::to_string(rank_qkv_output_size_) + "_hidden_size_" + + std::to_string(hidden_size_) + "_dtype_" + + infinicore::toString(dtype_) + "_device_" + + device_.toString(); + + const size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_); + 
const infinicore::Shape attention_buffer_shape = {max_num_batched_tokens * max_output_size}; + workspace_manager.register_buffer( + attention_cache_key, + attention_buffer_shape, + dtype_, + device_); +} + } // namespace infinilm::models::qwen3 diff --git a/csrc/models/qwen3/qwen3_attention.hpp b/csrc/models/qwen3/qwen3_attention.hpp index 44b69f38..f26f4c34 100644 --- a/csrc/models/qwen3/qwen3_attention.hpp +++ b/csrc/models/qwen3/qwen3_attention.hpp @@ -25,6 +25,8 @@ class Qwen3Attention : public infinicore::nn::Module { infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, const infinicore::Tensor &hidden_states) const; + void _register_inference_buffer(); + protected: std::shared_ptr qkv_proj_; std::shared_ptr o_proj_; @@ -39,9 +41,16 @@ class Qwen3Attention : public infinicore::nn::Module { size_t num_key_value_heads_; size_t hidden_size_; size_t head_dim_; + infinicore::Device device_; + infinicore::DataType dtype_; // For off-line kv cache quantization INFINICORE_NN_PARAMETER(kv_cache_k_scale); INFINICORE_NN_PARAMETER(kv_cache_v_scale); + + size_t rank_qkv_output_size_; + bool enable_workspace_manager_{false}; + infinicore::Tensor max_qkv_output_; // inference buffer for Attention + infinicore::Tensor max_o_output_; }; } // namespace infinilm::models::qwen3 diff --git a/csrc/models/qwen3_next/qwen3_next_attention.cpp b/csrc/models/qwen3_next/qwen3_next_attention.cpp index 67fd3808..5cf469d5 100644 --- a/csrc/models/qwen3_next/qwen3_next_attention.cpp +++ b/csrc/models/qwen3_next/qwen3_next_attention.cpp @@ -49,7 +49,7 @@ Qwen3NextAttention::Qwen3NextAttention(std::shared_ptr(head_dim_)); attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, - kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device); INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device); INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, 
rms_norm_eps, dtype, device); diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 8e470984..74fb22c9 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -1,5 +1,6 @@ #include "../../engine/infer_engine.hpp" #include "infinicore/tensor.hpp" +#include #include #include @@ -39,7 +40,8 @@ inline void bind_infer_engine(py::module &m) { std::shared_ptr cache_cfg, bool enable_graph_compiling, const std::string &attention_backend, - std::optional kv_cache_dtype) { + std::optional kv_cache_dtype, + size_t max_num_batched_tokens) { return std::make_shared( model_path, dist, @@ -47,7 +49,8 @@ inline void bind_infer_engine(py::module &m) { cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, infinilm::backends::parse_attention_backend(attention_backend), - kv_cache_dtype); + kv_cache_dtype, + max_num_batched_tokens); }), py::arg("model_path") = "", py::arg("distributed_config") = distributed::DistConfig(), @@ -55,7 +58,8 @@ inline void bind_infer_engine(py::module &m) { py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, py::arg("attention_backend") = "default", - py::arg("kv_cache_dtype") = py::none()) + py::arg("kv_cache_dtype") = py::none(), + py::arg("max_num_batched_tokens") = 2048) .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") @@ -76,6 +80,10 @@ inline void bind_infer_engine(py::module &m) { .def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)") .def( "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { + // IMPORTANT: Release the GIL before calling forward() to allow other Python threads + // to run concurrently during inference (which may block for a long time). 
+ // Do NOT remove this — without it, the GIL is held throughout inference and will + // deadlock or stall any other Python thread (e.g., request handling, scheduling). py::gil_scoped_release release; return self.forward(input); }, @@ -83,8 +91,8 @@ inline void bind_infer_engine(py::module &m) { .def( "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { - auto cfg = self.get_cache_config(); - return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) + auto cfg = self.get_cache_config(); + return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) .def("__repr__", [](const InferEngine &self) { return ""; }); py::class_(infer_engine, "Input") diff --git a/examples/bench.py b/examples/bench.py index 6672d6d7..29d15211 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -169,6 +169,7 @@ def __init__( cache_config=None, enable_graph=False, attn_backend="default", + max_num_batched_tokens: int = None, ) -> None: model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # @@ -182,6 +183,7 @@ def __init__( enable_graph_compiling=enable_graph, attention_backend=attn_backend, kv_cache_dtype=cfg.kv_cache_dtype, + max_num_batched_tokens=max_num_batched_tokens, ) # ---------------------------------------------------------------------------- # @@ -281,6 +283,7 @@ def run( enable_paged_attn = cfg.enable_paged_attn enable_graph = cfg.enable_graph attn_backend = cfg.attn + max_num_batched_tokens = cfg.max_num_batched_tokens if isinstance(batch_size, int): batch_size = [batch_size] @@ -322,6 +325,7 @@ def run( cache_config=cache_config, enable_graph=enable_graph, attn_backend=attn_backend, + max_num_batched_tokens=max_num_batched_tokens, ) # ---------------------------------------------------------------------------- # diff --git 
a/python/infinilm/base_config.py b/python/infinilm/base_config.py index aab5dd45..c0fd306b 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -44,7 +44,6 @@ class BaseConfig: """InfiniLM Unified Config - Command line argument parser""" def __init__(self): - self.parser = argparse.ArgumentParser(description="InfiniLM Unified Config") self._add_common_args() self.args, self.extra = self.parser.parse_known_args() @@ -70,6 +69,7 @@ def __init__(self): self.batch_size = self.args.batch_size self.max_batch_size = self.args.max_batch_size + self.max_num_batched_tokens = self.args.max_num_batched_tokens self.input_len = self.args.input_len self.output_len = self.args.output_len self.max_new_tokens = self.args.max_new_tokens @@ -155,6 +155,12 @@ def _add_common_args(self): default=8, help="maximum batch size for server", ) + self.parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=None, + help="maximum number of batched tokens for paged attention", + ) self.parser.add_argument( "--input-len", type=parse_list, default=10, help="input sequence length" ) diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 844989f4..5de13a58 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -30,6 +30,7 @@ def read_hf_config(model_path): ) return config_dict + # config.json (required) defines model architecture, while generation_config.json # (optional) defines generation behavior. 
They are kept as separate readers # because: 1) config.json must exist and requires model_type validation, @@ -43,6 +44,7 @@ def read_hf_generation_config(model_path): return json.load(f) return {} + @dataclass class GenerationConfig: max_new_tokens: int | None = None @@ -65,6 +67,7 @@ def __init__( enable_graph_compiling=False, attention_backend="default", kv_cache_dtype=None, + max_num_batched_tokens: int | None = None, ): self.hf_config = read_hf_config(model_path) self.hf_generation_config = read_hf_generation_config(model_path) @@ -72,6 +75,12 @@ def __init__( if device is None: device = infinicore.device() + max_position_embeddings = self.hf_config["max_position_embeddings"] + if max_num_batched_tokens is None: + max_num_batched_tokens = max_position_embeddings + assert 512 <= max_num_batched_tokens <= max_position_embeddings + self.max_num_batched_tokens = max_num_batched_tokens + hf_config_str = json.dumps(self.hf_config) super().__init__( hf_config_str, @@ -85,6 +94,7 @@ def __init__( if kv_cache_dtype is not None else None ), + max_num_batched_tokens, ) self.use_cache = False @@ -375,10 +385,14 @@ def reset_cache(self, cache_config): super().reset_cache(cache_config) def state_dict_keyname(self): - return sorted({name for state_dict in super().state_dict() for name in state_dict.keys()}) + return sorted( + {name for state_dict in super().state_dict() for name in state_dict.keys()} + ) def load_state_dict(self, state_dict, strict=None): - super().load_params({name: param._underlying for name, param in state_dict.items()}) + super().load_params( + {name: param._underlying for name, param in state_dict.items()} + ) def process_weights_after_loading(self): super().process_weights_after_loading() diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 953b35e5..dbe9baa6 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -6,6 +6,7 @@ - AsyncLLM class for asynchronous streaming (server use) """ +import os import 
asyncio import time import uuid @@ -74,6 +75,7 @@ class EngineConfig: enable_graph: bool = False attn_backend: str = "default" skip_load: bool = False + max_num_batched_tokens: int | None = None class LLMEngine: @@ -92,7 +94,9 @@ def __init__(self, config: EngineConfig): distributed_config=DistConfig(config.tensor_parallel_size), enable_graph_compiling=config.enable_graph, attention_backend=config.attn_backend, + max_num_batched_tokens=config.max_num_batched_tokens, ) + self.max_num_batched_tokens = self.model_engine.max_num_batched_tokens # Load model weights if not self.config.skip_load: @@ -121,6 +125,7 @@ def __init__(self, config: EngineConfig): max_batch_size=config.max_batch_size, num_blocks=config.num_blocks, block_size=config.block_size, + max_num_batched_tokens=self.max_num_batched_tokens, ) logger.info(f"Using Paged KV Cache with num_blocks={config.num_blocks}") else: diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index c5f4921a..b3852073 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -39,6 +39,7 @@ def __init__( max_batch_size: int = 16, num_blocks: int = 512, block_size: int = 256, + max_num_batched_tokens: int = 1024, ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() @@ -47,6 +48,8 @@ def __init__( self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) self.block_size = block_size + self.max_num_batched_tokens = max_num_batched_tokens + def add_request(self, request: InferenceRequest): if request is not None: request.status = RequestStatus.WAITING @@ -56,9 +59,13 @@ def schedule(self) -> Optional[SchedulerOutput]: """Schedule and return batch of requests to execute.""" scheduled_requests = [] is_prefill = False + current_num_batched_tokens = 0 # Process Waiting queue (prefill phase) - while len(scheduled_requests) < self.max_batch_size: + while ( + len(scheduled_requests) < self.max_batch_size + and current_num_batched_tokens < 
self.max_num_batched_tokens + ): try: req = self.waiting_queue.sync_q.get_nowait() except queue.Empty: @@ -91,6 +98,21 @@ def schedule(self) -> Optional[SchedulerOutput]: ) ) + num_tokens_this_step = req.get_prompt_length() - req.num_cached_tokens + if ( + current_num_batched_tokens + num_tokens_this_step + >= self.max_num_batched_tokens + ): + if req.num_cached_tokens > 0: + self.cache_manager.free_blocks(req.block_table) + req.block_table = [] + req.slot_mapping = [] + req.num_cached_tokens = 0 + + self.waiting_queue.sync_q.put(req) + break + + current_num_batched_tokens += num_tokens_this_step req.num_blocks = len(req.block_table) req.status = RequestStatus.RUNNING scheduled_requests.append(req) diff --git a/test/bench/test_benchmark.py b/test/bench/test_benchmark.py index c15c950f..476e45de 100644 --- a/test/bench/test_benchmark.py +++ b/test/bench/test_benchmark.py @@ -55,6 +55,7 @@ def __init__( enable_paged_attn=False, enable_graph=False, attn_backend="default", + max_num_batched_tokens: int | None = None, ): import transformers import infinicore @@ -119,6 +120,7 @@ def __init__( ), enable_graph_compiling=enable_graph, attention_backend=attn_backend, + max_num_batched_tokens=max_num_batched_tokens, ) # Enable KV cache for generation @@ -1125,6 +1127,7 @@ def main(): cfg.bench, cfg.enable_paged_attn, cfg.enable_graph, cfg.attn, + cfg.max_num_batched_tokens, )