Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 7 additions & 218 deletions csrc/cache/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

#include "../global_state/global_state.hpp"
#include "../utils.hpp"
#include "infinicore/ops.hpp"
#include <stdexcept>

namespace infinilm::cache {
// ==========================
Expand Down Expand Up @@ -32,58 +30,12 @@ StaticKVCacheConfig::max_cache_len() const {
return max_cache_len_;
}

namespace StaticKVCache {

// ==========================
// StaticKVCache
// ==========================

// Constructs the static KV cache held by one rank.
//
// Two tensors are allocated on rank_info.device, zero-filled with set_zeros,
// and the stream is synchronized before the constructor returns:
//   K: {num_layers, config.max_batch_size(), k_heads_per_rank, cache_len, k_dim}
//   V: {num_layers, config.max_batch_size(), v_heads_per_rank, cache_len, v_dim}
//
// cache_len is config.max_cache_len(), except when that value is 0 or the
// largest representable infinicore::Size; in those two cases
// max_positional_embedding is used instead.
StaticKVCache::StaticKVCache(
    infinicore::Size k_dim,
    infinicore::Size v_dim,
    infinicore::Size num_k_heads,
    infinicore::Size num_v_heads,
    infinicore::Size num_layers,
    infinicore::Size max_positional_embedding,
    infinicore::DataType dtype,
    const StaticKVCacheConfig &config,
    const engine::distributed::RankInfo &rank_info)
    : Cache(),
      k_dim_(k_dim),
      v_dim_(v_dim),
      rank_batch_size_(config.max_batch_size()),
      cache_len_((config.max_cache_len() != 0 && config.max_cache_len() != std::numeric_limits<infinicore::Size>::max())
                     ? config.max_cache_len()
                     : max_positional_embedding),
      rank_num_layers_(num_layers),
      dtype_(dtype) {

    const auto tp_size = rank_info.tp_size;

    // A rank keeps exactly one K head and one V head when both head counts
    // are below tp_size, the two counts are equal, and tp_size is a multiple
    // of that count. Otherwise each head count is integer-divided by tp_size.
    const bool fewer_heads_than_ranks = num_k_heads < tp_size && num_v_heads < tp_size;
    const bool one_head_per_rank = fewer_heads_than_ranks
                                && num_k_heads == num_v_heads
                                && tp_size % num_k_heads == 0;

    if (one_head_per_rank) {
        num_rank_k_heads_ = 1;
        num_rank_v_heads_ = 1;
    } else {
        num_rank_k_heads_ = num_k_heads / tp_size;
        num_rank_v_heads_ = num_v_heads / tp_size;
    }

    // Key cache storage, cleared right after allocation.
    k_caches_ = infinicore::Tensor::empty(
        {rank_num_layers_, rank_batch_size_, num_rank_k_heads_, cache_len_, k_dim_},
        dtype_,
        rank_info.device);
    set_zeros(k_caches_);

    // Value cache storage, cleared right after allocation.
    v_caches_ = infinicore::Tensor::empty(
        {rank_num_layers_, rank_batch_size_, num_rank_v_heads_, cache_len_, v_dim_},
        dtype_,
        rank_info.device);
    set_zeros(v_caches_);

    infinicore::context::syncStream();
}

infinicore::Tensor StaticKVCache::create_layer_kv_cache(
infinicore::Tensor create_layer_kv_cache(
const infinicore::Size k_dim,
const infinicore::Size v_dim,
const infinicore::Size num_k_heads,
Expand Down Expand Up @@ -120,45 +72,7 @@ infinicore::Tensor StaticKVCache::create_layer_kv_cache(

return kv_cache;
}

// Stores new key/value entries for one layer and returns that layer's K and V
// cache tensors.
//
// layer_idx              layer to update; asserted to be < rank_num_layers_.
// k, v                   new key/value tensors. k's dim 0 is asserted to equal
//                        rank_batch_size_; k's dim 2 is the number of
//                        positions written.
// past_sequence_lengths  with ENABLE_KV_CACHING it is forwarded unchanged to
//                        infinicore::op::kv_caching_. Otherwise it is copied to
//                        the CPU and element 0 is read as an int32 write
//                        offset along dim 2 of the layer cache.
//
// Returns {k_cache_layer, v_cache_layer}: the layer slices of k_caches_ and
// v_caches_ (dim 0 narrowed to layer_idx and squeezed away).
std::tuple<infinicore::Tensor, infinicore::Tensor>
StaticKVCache::update(size_t layer_idx,
                      const infinicore::Tensor &k,
                      const infinicore::Tensor &v,
                      const infinicore::Tensor &past_sequence_lengths) {
    ASSERT(layer_idx < rank_num_layers_);

    auto batch_size = k->size(0);
    auto update_len = k->size(2);

    ASSERT_EQ(batch_size, rank_batch_size_);

    auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
    auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);

#ifdef ENABLE_KV_CACHING
    infinicore::op::kv_caching_(
        k_cache_layer,
        v_cache_layer,
        k,
        v,
        past_sequence_lengths);
#else
    // NOTE(review): only element 0 is read, so every batch entry is written at
    // the same offset — presumably all sequences share one past length; confirm
    // against callers. No explicit device sync is issued before the host read;
    // verify that Tensor::to() is synchronous.
    const int32_t past_len = reinterpret_cast<int32_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];

    // Reject negative lengths before widening: a negative int32 converted to
    // size_t becomes a huge offset, and cache_pos + update_len could then wrap
    // around and slip past the capacity check below.
    ASSERT(past_len >= 0);
    const size_t cache_pos = static_cast<size_t>(past_len);

    auto result_len = cache_pos + update_len;
    ASSERT(result_len <= cache_len_);

    auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
    auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});

    k_cache_update->copy_from(k);
    v_cache_update->copy_from(v);
#endif

    return {k_cache_layer, v_cache_layer};
}
}; // namespace StaticKVCache

// ==========================
// PagedKVCacheConfig
Expand All @@ -185,56 +99,11 @@ PagedKVCacheConfig::block_size() const {
return block_size_;
}

namespace PagedKVCache {
// ==========================
// PagedKVCache
// ==========================
// Constructs the paged KV cache held by one rank.
//
// Two tensors are allocated on rank_info.device, zero-filled with set_zeros,
// and the stream is synchronized before the constructor returns:
//   K: [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
//   V: [num_layers, num_blocks, num_rank_v_heads, block_size, v_dim]
// num_blocks and block_size are taken from config.
PagedKVCache::PagedKVCache(
    infinicore::Size k_dim,
    infinicore::Size v_dim,
    infinicore::Size num_k_heads,
    infinicore::Size num_v_heads,
    infinicore::Size num_layers,
    infinicore::DataType dtype,
    const PagedKVCacheConfig &config,
    const engine::distributed::RankInfo &rank_info)
    : Cache(),
      k_dim_(k_dim),
      v_dim_(v_dim),
      rank_num_layers_(num_layers),
      dtype_(dtype),
      num_blocks_per_layer_(config.num_blocks()),
      block_size_(config.block_size()) {

    const auto tp_size = rank_info.tp_size;

    // A rank keeps exactly one K head and one V head when both head counts
    // are below tp_size, the two counts are equal, and tp_size is a multiple
    // of that count. Otherwise each head count is integer-divided by tp_size.
    const bool fewer_heads_than_ranks = num_k_heads < tp_size && num_v_heads < tp_size;
    const bool one_head_per_rank = fewer_heads_than_ranks
                                && num_k_heads == num_v_heads
                                && tp_size % num_k_heads == 0;

    if (one_head_per_rank) {
        num_rank_k_heads_ = 1;
        num_rank_v_heads_ = 1;
    } else {
        num_rank_k_heads_ = num_k_heads / tp_size;
        num_rank_v_heads_ = num_v_heads / tp_size;
    }

    // Key blocks: [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
    k_caches_ = infinicore::Tensor::empty(
        {rank_num_layers_, num_blocks_per_layer_, num_rank_k_heads_, block_size_, k_dim_},
        dtype_,
        rank_info.device);
    set_zeros(k_caches_);

    // Value blocks: [num_layers, num_blocks, num_rank_v_heads, block_size, v_dim]
    v_caches_ = infinicore::Tensor::empty(
        {rank_num_layers_, num_blocks_per_layer_, num_rank_v_heads_, block_size_, v_dim_},
        dtype_,
        rank_info.device);
    set_zeros(v_caches_);

    infinicore::context::syncStream();
}

infinicore::Tensor PagedKVCache::create_layer_kv_cache(
infinicore::Tensor create_layer_kv_cache(
infinicore::Size k_dim,
infinicore::Size v_dim,
infinicore::Size num_k_heads,
Expand Down Expand Up @@ -273,86 +142,6 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(

return kv_cache;
}

// Hands the new k/v tensors and slot_mapping to infinicore::op::paged_caching_
// together with the K and V block tensors of layer `layer_idx`, then returns
// those two block tensors as {k_cache_layer, v_cache_layer}.
std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
    size_t layer_idx,
    const infinicore::Tensor &k,
    const infinicore::Tensor &v,
    const infinicore::Tensor &slot_mapping) {

    // Fetch the per-layer block tensors once and reuse the pair as the result.
    auto layer_kv = get_paged_kv(layer_idx);
    auto &layer_keys = std::get<0>(layer_kv);
    auto &layer_values = std::get<1>(layer_kv);

    infinicore::op::paged_caching_(layer_keys, layer_values, k, v, slot_mapping);

    return layer_kv;
}

// Returns the K and V block tensors of a single layer: dim 0 of k_caches_ and
// v_caches_ is narrowed to `layer_idx` (length 1) and then squeezed away.
std::tuple<infinicore::Tensor, infinicore::Tensor>
PagedKVCache::get_paged_kv(size_t layer_idx) {
    auto layer_keys = k_caches_->narrow({{0, layer_idx, 1}});
    auto keys_view = layer_keys->squeeze(0);

    auto layer_values = v_caches_->narrow({{0, layer_idx, 1}});
    auto values_view = layer_values->squeeze(0);

    return {keys_view, values_view};
}

// Gathers the paged K/V blocks of one request into two contiguous tensors.
//
// layer_idx      layer whose paged cache is read.
// block_tables   I32 tensor indexed as [request, block index]; each entry is a
//                block id into the layer cache.
// cache_lens     I32 tensor indexed by request.
// input_offsets  I32 tensor; entries request_id and request_id + 1 are read.
// request_id     request to gather; must be < block_tables->size(0).
//
// The gathered length is
//   total_len = cache_lens[request_id]
//             + (input_offsets[request_id + 1] - input_offsets[request_id]).
//
// Returns {full_k, full_v} with shapes {num_rank_k_heads_, total_len, k_dim_}
// and {num_rank_v_heads_, total_len, v_dim_}, using the dtype and device of
// the layer cache.
std::tuple<infinicore::Tensor, infinicore::Tensor>
PagedKVCache::get_contiguous_kv(
    size_t layer_idx,
    const infinicore::Tensor block_tables,
    const infinicore::Tensor cache_lens,
    const infinicore::Tensor input_offsets,
    size_t request_id) {
    ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I32);
    ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I32);
    ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I32);

    // The host pointers below are indexed directly with request_id, so an
    // out-of-range id must be rejected before anything is read.
    const auto nreq = block_tables->size(0);
    ASSERT(request_id < nreq);

    auto block_tables_cpu = block_tables->to(infinicore::Device::cpu());
    auto cache_lens_cpu = cache_lens->to(infinicore::Device::cpu());
    auto input_offsets_cpu = input_offsets->to(infinicore::Device::cpu());
    infinicore::context::syncDevice();

    // [num_blocks, num_rank_v_heads, block_size, v_dim]
    auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);

    const auto *cache_lens_ptr = reinterpret_cast<const int32_t *>(cache_lens_cpu->data());
    const auto *input_offsets_ptr = reinterpret_cast<const int32_t *>(input_offsets_cpu->data());
    const int32_t total_len = cache_lens_ptr[request_id]
                            + (input_offsets_ptr[request_id + 1] - input_offsets_ptr[request_id]);

    // A negative length would turn into an enormous size_t in the shapes below.
    ASSERT(total_len >= 0);
    const size_t total = static_cast<size_t>(total_len);

    auto full_k = infinicore::Tensor::empty(
        {num_rank_k_heads_, total, k_dim_},
        k_cache_layer->dtype(), k_cache_layer->device());

    auto full_v = infinicore::Tensor::empty(
        {num_rank_v_heads_, total, v_dim_},
        v_cache_layer->dtype(), v_cache_layer->device());

    // Whole blocks are copied first; a trailing partial block (r tokens) follows.
    const size_t nblocks = total / block_size_;
    const size_t r = total % block_size_;

    for (size_t b = 0; b < nblocks; b++) {
        const int32_t raw_bid = *reinterpret_cast<const int32_t *>(
            block_tables_cpu->narrow({{0, request_id, 1}, {1, b, 1}})->data());
        ASSERT(raw_bid >= 0);
        const size_t bid = static_cast<size_t>(raw_bid);

        full_k->narrow({{1, b * block_size_, block_size_}})
            ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
        full_v->narrow({{1, b * block_size_, block_size_}})
            ->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
    }

    if (r > 0) {
        const int32_t raw_bid = *reinterpret_cast<const int32_t *>(
            block_tables_cpu->narrow({{0, request_id, 1}, {1, nblocks, 1}})->data());
        ASSERT(raw_bid >= 0);
        const size_t bid = static_cast<size_t>(raw_bid);

        // Only the first r positions of the last block hold data for this request.
        full_k->narrow({{1, nblocks * block_size_, r}})
            ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
        full_v->narrow({{1, nblocks * block_size_, r}})
            ->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
    }

    return {full_k, full_v};
}
}; // namespace PagedKVCache

} // namespace infinilm::cache
Loading