diff --git a/src/base/kv_caching_infinilm.h b/src/base/kv_caching_infinilm.h
new file mode 100644
index 000000000..195c210c1
--- /dev/null
+++ b/src/base/kv_caching_infinilm.h
@@ -0,0 +1,128 @@
+#ifndef INFINI_OPS_BASE_KV_CACHING_INFINILM_H_
+#define INFINI_OPS_BASE_KV_CACHING_INFINILM_H_
+
+#include
+
+#include "operator.h"
+
+namespace infini::ops {
+
+// Backend-independent base of the KV-cache append operator. The constructor
+// records the shapes, strides, and dtypes of its arguments and validates
+// them with `assert`, so the checks disappear when `NDEBUG` is defined.
+class KvCachingInfinilm : public Operator {
+ public:
+  // Requirements asserted below: `k`, `v`, `k_cache`, and `v_cache` are 4D
+  // and share one dtype (float16, bfloat16, or float32); `v` has the shape
+  // of `k` and `v_cache` the shape of `k_cache`; `k` matches `k_cache` in
+  // dimensions 0, 1, and 3 and is no longer than it in dimension 2;
+  // `past_kv_lengths` is a 1D int32/int64 tensor with `k_cache.size(0)`
+  // entries; neither cache has a broadcast dimension.
+  //
+  // NOTE(review): the member initializers call `size(0..3)` before the
+  // `ndim() == 4` assertion in the body runs.
+  KvCachingInfinilm(const Tensor k, const Tensor v,
+                    const Tensor past_kv_lengths, Tensor k_cache,
+                    Tensor v_cache)
+      : k_cache_shape_{k_cache.shape()},
+        k_cache_strides_{k_cache.strides()},
+        v_cache_shape_{v_cache.shape()},
+        v_cache_strides_{v_cache.strides()},
+        k_shape_{k.shape()},
+        k_strides_{k.strides()},
+        v_shape_{v.shape()},
+        v_strides_{v.strides()},
+        past_kv_lengths_shape_{past_kv_lengths.shape()},
+        data_type_{k_cache.dtype()},
+        past_kv_lengths_type_{past_kv_lengths.dtype()},
+        batch_size_{k_cache.size(0)},
+        num_kv_heads_{k_cache.size(1)},
+        max_seq_len_{k_cache.size(2)},
+        seq_len_{k.size(2)},
+        hidden_size_{k_cache.size(3)},
+        output_size_{k.numel()},
+        device_index_{k_cache.device().index()} {
+    assert(k_cache.ndim() == 4 && v_cache.ndim() == 4 && k.ndim() == 4 &&
+           v.ndim() == 4 && "`KvCachingInfinilm` tensors must be 4D");
+    assert(k_cache_shape_ == v_cache_shape_ &&
+           "`KvCachingInfinilm` cache shapes must match");
+    assert(k_shape_ == v_shape_ &&
+           "`KvCachingInfinilm` source shapes must match");
+    assert(k.size(0) == batch_size_ && k.size(1) == num_kv_heads_ &&
+           k.size(3) == hidden_size_ &&
+           "`KvCachingInfinilm` source shape must match cache "
+           "batch/head/hidden dims");
+    assert(seq_len_ <= max_seq_len_ &&
+           "`KvCachingInfinilm` source sequence length exceeds cache length");
+    assert(k_cache.dtype() == v_cache.dtype() && k_cache.dtype() == k.dtype() &&
+           k_cache.dtype() == v.dtype() &&
+           "`KvCachingInfinilm` K/V tensors must have the same dtype");
+    assert(
+        (data_type_ == DataType::kFloat16 ||
+         data_type_ == DataType::kBFloat16 ||
+         data_type_ == DataType::kFloat32) &&
+        "`KvCachingInfinilm` K/V dtype must be float16, bfloat16, or float32");
+    assert((past_kv_lengths_type_ == DataType::kInt32 ||
+            past_kv_lengths_type_ == DataType::kInt64) &&
+           "`KvCachingInfinilm` past_kv_lengths dtype must be int32 or int64");
+    assert(past_kv_lengths.ndim() == 1 &&
+           past_kv_lengths.size(0) == batch_size_ &&
+           "`KvCachingInfinilm` past_kv_lengths shape must be (batch_size,)");
+    assert(!k_cache.HasBroadcastDim() && !v_cache.HasBroadcastDim() &&
+           "`KvCachingInfinilm` caches must not have broadcasted dimensions");
+  }
+
+  // Performs the cache update. Implemented by each backend.
+  virtual void operator()(const Tensor k, const Tensor v,
+                          const Tensor past_kv_lengths, Tensor k_cache,
+                          Tensor v_cache) const = 0;
+
+ protected:
+  Tensor::Shape k_cache_shape_;
+
+  Tensor::Strides k_cache_strides_;
+
+  Tensor::Shape v_cache_shape_;
+
+  Tensor::Strides v_cache_strides_;
+
+  Tensor::Shape k_shape_;
+
+  Tensor::Strides k_strides_;
+
+  Tensor::Shape v_shape_;
+
+  Tensor::Strides v_strides_;
+
+  Tensor::Shape past_kv_lengths_shape_;
+
+  // Dtype of `k_cache`; asserted equal to the dtype of the other K/V tensors.
+  DataType data_type_;
+
+  DataType past_kv_lengths_type_;
+
+  // `k_cache.size(0)`.
+  Tensor::Size batch_size_{0};
+
+  // `k_cache.size(1)`.
+  Tensor::Size num_kv_heads_{0};
+
+  // `k_cache.size(2)`.
+  Tensor::Size max_seq_len_{0};
+
+  // `k.size(2)`.
+  Tensor::Size seq_len_{0};
+
+  // `k_cache.size(3)`.
+  Tensor::Size hidden_size_{0};
+
+  // `k.numel()`: the number of elements written to each cache per call.
+  Tensor::Size output_size_{0};
+
+  // Device index of `k_cache`.
+  int device_index_{0};
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/iluvatar/ops/kv_caching_infinilm/kernel.h b/src/native/cuda/iluvatar/ops/kv_caching_infinilm/kernel.h
new file mode 100644
index 000000000..3f3018864
--- /dev/null
+++ b/src/native/cuda/iluvatar/ops/kv_caching_infinilm/kernel.h
@@ -0,0 +1,24 @@
+#ifndef INFINI_OPS_ILUVATAR_KV_CACHING_INFINILM_KERNEL_H_
+#define INFINI_OPS_ILUVATAR_KV_CACHING_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/iluvatar/caster.cuh"
+#include "native/cuda/iluvatar/runtime_.h"
+#include "native/cuda/ops/kv_caching_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// Iluvatar backend: reuses the shared CUDA implementation and inherits its
+// constructors.
+template <>
+class Operator
+    : public CudaKvCachingInfinilm> {
+ public:
+  using CudaKvCachingInfinilm<
+      Runtime>::CudaKvCachingInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/metax/ops/kv_caching_infinilm/kernel.h b/src/native/cuda/metax/ops/kv_caching_infinilm/kernel.h
new file mode 100644
index 000000000..92a24226f
--- /dev/null
+++ b/src/native/cuda/metax/ops/kv_caching_infinilm/kernel.h
@@ -0,0 +1,24 @@
+#ifndef INFINI_OPS_METAX_KV_CACHING_INFINILM_KERNEL_H_
+#define INFINI_OPS_METAX_KV_CACHING_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/metax/caster.cuh"
+#include "native/cuda/metax/runtime_.h"
+#include "native/cuda/ops/kv_caching_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// MetaX backend: reuses the shared CUDA implementation and inherits its
+// constructors.
+template <>
+class Operator
+    : public CudaKvCachingInfinilm> {
+ public:
+  using CudaKvCachingInfinilm<
+      Runtime>::CudaKvCachingInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/moore/ops/kv_caching_infinilm/kernel.h b/src/native/cuda/moore/ops/kv_caching_infinilm/kernel.h
new file mode 100644
index 000000000..223cd8049
--- /dev/null
+++ b/src/native/cuda/moore/ops/kv_caching_infinilm/kernel.h
@@ -0,0 +1,25 @@
+#ifndef INFINI_OPS_MOORE_KV_CACHING_INFINILM_KERNEL_H_
+#define INFINI_OPS_MOORE_KV_CACHING_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/moore/caster.cuh"
+#include "native/cuda/moore/polyfills.cuh"
+#include "native/cuda/moore/runtime_.h"
+#include "native/cuda/ops/kv_caching_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// Moore backend: reuses the shared CUDA implementation and inherits its
+// constructors.
+template <>
+class Operator
+    : public CudaKvCachingInfinilm> {
+ public:
+  using CudaKvCachingInfinilm<
+      Runtime>::CudaKvCachingInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/nvidia/ops/kv_caching_infinilm/kernel.h b/src/native/cuda/nvidia/ops/kv_caching_infinilm/kernel.h
new file mode 100644
index 000000000..b4e7405bf
--- /dev/null
+++ b/src/native/cuda/nvidia/ops/kv_caching_infinilm/kernel.h
@@ -0,0 +1,24 @@
+#ifndef INFINI_OPS_NVIDIA_KV_CACHING_INFINILM_KERNEL_H_
+#define INFINI_OPS_NVIDIA_KV_CACHING_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/nvidia/caster.cuh"
+#include "native/cuda/nvidia/runtime_.h"
+#include "native/cuda/ops/kv_caching_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// NVIDIA backend: reuses the shared CUDA implementation and inherits its
+// constructors.
+template <>
+class Operator
+    : public CudaKvCachingInfinilm> {
+ public:
+  using CudaKvCachingInfinilm<
+      Runtime>::CudaKvCachingInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/ops/kv_caching_infinilm/kernel.cuh b/src/native/cuda/ops/kv_caching_infinilm/kernel.cuh
new file mode 100644
index 000000000..5704adb22
--- /dev/null
+++ b/src/native/cuda/ops/kv_caching_infinilm/kernel.cuh
@@ -0,0 +1,61 @@
+#ifndef INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_CUH_
+#define INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_CUH_
+
+#include
+#include
+
+namespace infini::ops {
+
+// Copies `output_size` elements of `k` / `v` into `k_cache` / `v_cache`. A
+// linear index is decomposed into (b, h, s, d) with `d` varying fastest, and
+// the element is stored at sequence position `past_kv_lengths[b] + s` of the
+// cache. Each `*_strides` array holds four element strides, one per
+// dimension.
+//
+// The kernel does not check `past_kv_lengths[b] + s` against the cache
+// extent, so the caller must guarantee that every write stays in bounds.
+template
+__global__ void KvCachingInfinilmKernel(
+    T* __restrict__ k_cache, T* __restrict__ v_cache, const T* __restrict__ k,
+    const T* __restrict__ v, const TIndex* __restrict__ past_kv_lengths,
+    const ptrdiff_t* __restrict__ k_cache_strides,
+    const ptrdiff_t* __restrict__ v_cache_strides,
+    const ptrdiff_t* __restrict__ k_strides,
+    const ptrdiff_t* __restrict__ v_strides, size_t output_size,
+    size_t num_kv_heads, size_t seq_len, size_t hidden_size) {
+  // Widen before multiplying: the built-in index variables are 32-bit, and
+  // their product can exceed that range for large launches.
+  const size_t tid =
+      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  const size_t step = static_cast<size_t>(blockDim.x) * gridDim.x;
+
+  for (size_t idx = tid; idx < output_size; idx += step) {
+    size_t offset = idx;
+    size_t d = offset % hidden_size;
+    offset /= hidden_size;
+    size_t s = offset % seq_len;
+    offset /= seq_len;
+    size_t h = offset % num_kv_heads;
+    size_t b = offset / num_kv_heads;
+
+    // Sequence position of this element inside the cache.
+    size_t cache_s = static_cast(past_kv_lengths[b]) + s;
+    ptrdiff_t k_cache_offset = b * k_cache_strides[0] + h * k_cache_strides[1] +
+                               cache_s * k_cache_strides[2] +
+                               d * k_cache_strides[3];
+    ptrdiff_t v_cache_offset = b * v_cache_strides[0] + h * v_cache_strides[1] +
+                               cache_s * v_cache_strides[2] +
+                               d * v_cache_strides[3];
+    ptrdiff_t k_offset = b * k_strides[0] + h * k_strides[1] +
+                         s * k_strides[2] + d * k_strides[3];
+    ptrdiff_t v_offset = b * v_strides[0] + h * v_strides[1] +
+                         s * v_strides[2] + d * v_strides[3];
+
+    k_cache[k_cache_offset] = k[k_offset];
+    v_cache[v_cache_offset] = v[v_offset];
+  }
+}
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/ops/kv_caching_infinilm/kernel.h b/src/native/cuda/ops/kv_caching_infinilm/kernel.h
new file mode 100644
index 000000000..79c836dc6
--- /dev/null
+++ b/src/native/cuda/ops/kv_caching_infinilm/kernel.h
@@ -0,0 +1,127 @@
+#ifndef INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_H_
+#define INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_H_
+
+#include
+#include
+#include
+#include
+
+#include "base/kv_caching_infinilm.h"
+#include "common/generic_utils.h"
+#include "data_type.h"
+#include "dispatcher.h"
+#include "native/cuda/ops/kv_caching_infinilm/kernel.cuh"
+#include "native/cuda/runtime_utils.h"
+
+namespace infini::ops {
+
+// Shared CUDA-style implementation. The constructor packs the four stride
+// arrays into one device allocation; `operator()` dispatches on the dtypes
+// and launches `KvCachingInfinilmKernel`.
+template
+class CudaKvCachingInfinilm : public KvCachingInfinilm {
+ public:
+  // Validates the arguments through the base class, then uploads the strides
+  // of `k_cache`, `v_cache`, `k`, and `v` (four entries each, in that order).
+  CudaKvCachingInfinilm(const Tensor k, const Tensor v,
+                        const Tensor past_kv_lengths, Tensor k_cache,
+                        Tensor v_cache)
+      : KvCachingInfinilm{k, v, past_kv_lengths, k_cache, v_cache} {
+    constexpr size_t ndim = 4;
+    size_t strides_size = ndim * sizeof(*d_k_cache_strides_);
+    const size_t metadata_size = 4 * strides_size;
+    std::vector metadata(metadata_size);
+
+    Backend::Malloc((void**)&d_metadata_, metadata_size);
+
+    size_t offset = 0;
+    d_k_cache_strides_ =
+        reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, k_cache_strides_.data(),
+                strides_size);
+    offset += strides_size;
+
+    d_v_cache_strides_ =
+        reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, v_cache_strides_.data(),
+                strides_size);
+    offset += strides_size;
+
+    d_k_strides_ = reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, k_strides_.data(), strides_size);
+    offset += strides_size;
+
+    d_v_strides_ = reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, v_strides_.data(), strides_size);
+
+    Backend::Memcpy(d_metadata_, metadata.data(), metadata_size,
+                    Backend::MemcpyHostToDevice);
+  }
+
+  // Releases the device copy of the strides.
+  ~CudaKvCachingInfinilm() { Backend::Free(d_metadata_); }
+
+  // Not copyable: the destructor frees `d_metadata_`, so a copy would free
+  // the same allocation a second time.
+  CudaKvCachingInfinilm(const CudaKvCachingInfinilm&) = delete;
+
+  CudaKvCachingInfinilm& operator=(const CudaKvCachingInfinilm&) = delete;
+
+  // Writes `k` / `v` into `k_cache` / `v_cache` using the strides captured
+  // at construction time.
+  void operator()(const Tensor k, const Tensor v, const Tensor past_kv_lengths,
+                  Tensor k_cache, Tensor v_cache) const override {
+    // An empty source makes `block.x` zero, and `CeilDiv` would then divide
+    // by zero; there is nothing to copy in that case.
+    if (output_size_ == 0) {
+      return;
+    }
+
+    auto cuda_stream =
+        static_cast(stream_ ? stream_ : 0);
+    int block_size = std::min(
+        RuntimeUtils::GetOptimalBlockSize(), 1024);
+    dim3 block(std::min(static_cast(block_size), output_size_));
+    dim3 grid(utils::CeilDiv(output_size_, block.x));
+
+    using IndexTypes = List;
+    DispatchFunc>(
+        {static_cast(data_type_),
+         static_cast(past_kv_lengths_type_), block_size},
+        [&](auto list_tag) {
+          using T = TypeMapType(list_tag)>;
+          using TIndex =
+              TypeMapType(list_tag)>;
+          constexpr int kBlockSize = ListGet<2>(list_tag);
+
+          KvCachingInfinilmKernel
+              <<>>(
+                  reinterpret_cast(k_cache.data()),
+                  reinterpret_cast(v_cache.data()),
+                  reinterpret_cast(k.data()),
+                  reinterpret_cast(v.data()),
+                  reinterpret_cast(past_kv_lengths.data()),
+                  d_k_cache_strides_, d_v_cache_strides_, d_k_strides_,
+                  d_v_strides_, output_size_, num_kv_heads_, seq_len_,
+                  hidden_size_);
+        },
+        "CudaKvCachingInfinilm::operator()");
+  }
+
+ private:
+  // Single device allocation that backs the four stride arrays below.
+  std::byte* d_metadata_{nullptr};
+
+  // Device pointers into `d_metadata_`, one per tensor.
+  Tensor::Stride* d_k_cache_strides_{nullptr};
+
+  Tensor::Stride* d_v_cache_strides_{nullptr};
+
+  Tensor::Stride* d_k_strides_{nullptr};
+
+  Tensor::Stride* d_v_strides_{nullptr};
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/tests/test_kv_caching_infinilm.py b/tests/test_kv_caching_infinilm.py
new file mode 100644
index 000000000..5fee28508
--- /dev/null
+++ b/tests/test_kv_caching_infinilm.py
@@ -0,0 +1,68 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import clone_strided, empty_strided, get_stream, randn_strided
+
+
+@pytest.mark.parametrize(
+    "cache_shape, cache_strides, seq_len",
+    (
+        ((1, 1, 8, 1), None, 3),
+        ((1, 8, 32, 32), None, 7),
+        ((8, 8, 64, 32), None, 5),
+        ((1, 32, 8, 64), (32768, 1024, 64, 1), 4),
+        ((4, 8, 32, 16), (65536, 8192, 256, 16), 7),
+        ((8, 16, 64, 128), (8388608, 524288, 8192, 1), 3),
+        ((1, 2, 2304, 128), (589824, 294912, 128, 1), 9),
+    ),
+)
+@pytest.mark.parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16))
+@pytest.mark.parametrize("index_dtype", (torch.int64, torch.int32))
+def test_kv_caching_infinilm(
+    cache_shape, cache_strides, seq_len, dtype, index_dtype, device
+):
+    """Check the operator against an in-place PyTorch reference, bit for bit."""
+    batch, heads, max_seq_len, hidden = cache_shape
+    k_cache = randn_strided(cache_shape, cache_strides, dtype=dtype, device=device)
+    v_cache = randn_strided(cache_shape, cache_strides, dtype=dtype, device=device)
+    k = randn_strided((batch, heads, seq_len, hidden), None, dtype=dtype, device=device)
+    v = randn_strided((batch, heads, seq_len, hidden), None, dtype=dtype, device=device)
+    past_kv_lengths = _make_past_lengths(
+        batch, max_seq_len - seq_len, dtype=index_dtype, device=device
+    )
+
+    expected_k_cache = clone_strided(k_cache)
+    expected_v_cache = clone_strided(v_cache)
+    _torch_kv_caching_infinilm(
+        expected_k_cache, expected_v_cache, k, v, past_kv_lengths
+    )
+
+    infini.ops.kv_caching_infinilm(
+        k,
+        v,
+        past_kv_lengths,
+        k_cache,
+        v_cache,
+        stream=get_stream(k_cache.device),
+    )
+
+    torch.testing.assert_close(k_cache, expected_k_cache, rtol=0, atol=0)
+    torch.testing.assert_close(v_cache, expected_v_cache, rtol=0, atol=0)
+
+
+def _make_past_lengths(batch, high, *, dtype, device):
+    """Return `arange(batch) % max(high, 1)` converted to `dtype`."""
+    values = torch.arange(batch, dtype=torch.int64, device=device) % max(high, 1)
+    return values.to(dtype)
+
+
+def _torch_kv_caching_infinilm(k_cache, v_cache, k, v, past_kv_lengths):
+    """Reference: write `k` / `v` at `past_kv_lengths[b]` along dim 2, in place."""
+    batch, heads, _, _ = k_cache.shape
+    seq_len = k.shape[2]
+    for b in range(batch):
+        past_len = int(past_kv_lengths[b].item())
+        for h in range(heads):
+            k_cache[b, h, past_len : past_len + seq_len, :] = k[b, h, :, :]
+            v_cache[b, h, past_len : past_len + seq_len, :] = v[b, h, :, :]