Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions src/base/kv_caching_infinilm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#ifndef INFINI_OPS_BASE_KV_CACHING_INFINILM_H_
#define INFINI_OPS_BASE_KV_CACHING_INFINILM_H_

#include <cassert>

#include "operator.h"

namespace infini::ops {

class KvCachingInfinilm : public Operator<KvCachingInfinilm> {
public:
KvCachingInfinilm(const Tensor k, const Tensor v,
const Tensor past_kv_lengths, Tensor k_cache,
Tensor v_cache)
: k_cache_shape_{k_cache.shape()},
k_cache_strides_{k_cache.strides()},
v_cache_shape_{v_cache.shape()},
v_cache_strides_{v_cache.strides()},
k_shape_{k.shape()},
k_strides_{k.strides()},
v_shape_{v.shape()},
v_strides_{v.strides()},
past_kv_lengths_shape_{past_kv_lengths.shape()},
data_type_{k_cache.dtype()},
past_kv_lengths_type_{past_kv_lengths.dtype()},
batch_size_{k_cache.size(0)},
num_kv_heads_{k_cache.size(1)},
max_seq_len_{k_cache.size(2)},
seq_len_{k.size(2)},
hidden_size_{k_cache.size(3)},
output_size_{k.numel()},
device_index_{k_cache.device().index()} {
assert(k_cache.ndim() == 4 && v_cache.ndim() == 4 && k.ndim() == 4 &&
v.ndim() == 4 && "`KvCachingInfinilm` tensors must be 4D");
assert(k_cache_shape_ == v_cache_shape_ &&
"`KvCachingInfinilm` cache shapes must match");
assert(k_shape_ == v_shape_ &&
"`KvCachingInfinilm` source shapes must match");
assert(k.size(0) == batch_size_ && k.size(1) == num_kv_heads_ &&
k.size(3) == hidden_size_ &&
"`KvCachingInfinilm` source shape must match cache "
"batch/head/hidden dims");
assert(seq_len_ <= max_seq_len_ &&
"`KvCachingInfinilm` source sequence length exceeds cache length");
assert(k_cache.dtype() == v_cache.dtype() && k_cache.dtype() == k.dtype() &&
k_cache.dtype() == v.dtype() &&
"`KvCachingInfinilm` K/V tensors must have the same dtype");
assert(
(data_type_ == DataType::kFloat16 ||
data_type_ == DataType::kBFloat16 ||
data_type_ == DataType::kFloat32) &&
"`KvCachingInfinilm` K/V dtype must be float16, bfloat16, or float32");
assert((past_kv_lengths_type_ == DataType::kInt32 ||
past_kv_lengths_type_ == DataType::kInt64) &&
"`KvCachingInfinilm` past_kv_lengths dtype must be int32 or int64");
assert(past_kv_lengths.ndim() == 1 &&
past_kv_lengths.size(0) == batch_size_ &&
"`KvCachingInfinilm` past_kv_lengths shape must be (batch_size,)");
assert(!k_cache.HasBroadcastDim() && !v_cache.HasBroadcastDim() &&
"`KvCachingInfinilm` caches must not have broadcasted dimensions");
}

virtual void operator()(const Tensor k, const Tensor v,
const Tensor past_kv_lengths, Tensor k_cache,
Tensor v_cache) const = 0;

protected:
Tensor::Shape k_cache_shape_;

Tensor::Strides k_cache_strides_;

Tensor::Shape v_cache_shape_;

Tensor::Strides v_cache_strides_;

Tensor::Shape k_shape_;

Tensor::Strides k_strides_;

Tensor::Shape v_shape_;

Tensor::Strides v_strides_;

Tensor::Shape past_kv_lengths_shape_;

DataType data_type_;

DataType past_kv_lengths_type_;

Tensor::Size batch_size_{0};

Tensor::Size num_kv_heads_{0};

Tensor::Size max_seq_len_{0};

Tensor::Size seq_len_{0};

Tensor::Size hidden_size_{0};

Tensor::Size output_size_{0};

int device_index_{0};
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/iluvatar/ops/kv_caching_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_ILUVATAR_KV_CACHING_INFINILM_KERNEL_H_
#define INFINI_OPS_ILUVATAR_KV_CACHING_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/iluvatar/caster.cuh"
#include "native/cuda/iluvatar/runtime_.h"
#include "native/cuda/ops/kv_caching_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<KvCachingInfinilm, Device::Type::kIluvatar>
: public CudaKvCachingInfinilm<Runtime<Device::Type::kIluvatar>> {
public:
using CudaKvCachingInfinilm<
Runtime<Device::Type::kIluvatar>>::CudaKvCachingInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/metax/ops/kv_caching_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_METAX_KV_CACHING_INFINILM_KERNEL_H_
#define INFINI_OPS_METAX_KV_CACHING_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/metax/caster.cuh"
#include "native/cuda/metax/runtime_.h"
#include "native/cuda/ops/kv_caching_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<KvCachingInfinilm, Device::Type::kMetax>
: public CudaKvCachingInfinilm<Runtime<Device::Type::kMetax>> {
public:
using CudaKvCachingInfinilm<
Runtime<Device::Type::kMetax>>::CudaKvCachingInfinilm;
};

} // namespace infini::ops

#endif
23 changes: 23 additions & 0 deletions src/native/cuda/moore/ops/kv_caching_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef INFINI_OPS_MOORE_KV_CACHING_INFINILM_KERNEL_H_
#define INFINI_OPS_MOORE_KV_CACHING_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/moore/caster.cuh"
#include "native/cuda/moore/polyfills.cuh"
#include "native/cuda/moore/runtime_.h"
#include "native/cuda/ops/kv_caching_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<KvCachingInfinilm, Device::Type::kMoore>
: public CudaKvCachingInfinilm<Runtime<Device::Type::kMoore>> {
public:
using CudaKvCachingInfinilm<
Runtime<Device::Type::kMoore>>::CudaKvCachingInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/nvidia/ops/kv_caching_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_NVIDIA_KV_CACHING_INFINILM_KERNEL_H_
#define INFINI_OPS_NVIDIA_KV_CACHING_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/kv_caching_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<KvCachingInfinilm, Device::Type::kNvidia>
: public CudaKvCachingInfinilm<Runtime<Device::Type::kNvidia>> {
public:
using CudaKvCachingInfinilm<
Runtime<Device::Type::kNvidia>>::CudaKvCachingInfinilm;
};

} // namespace infini::ops

#endif
48 changes: 48 additions & 0 deletions src/native/cuda/ops/kv_caching_infinilm/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_CUH_
#define INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_CUH_

#include <cstddef>
#include <cstdint>

namespace infini::ops {

template <typename T, typename TIndex, unsigned int block_size>
__global__ void KvCachingInfinilmKernel(
T* __restrict__ k_cache, T* __restrict__ v_cache, const T* __restrict__ k,
const T* __restrict__ v, const TIndex* __restrict__ past_kv_lengths,
const ptrdiff_t* __restrict__ k_cache_strides,
const ptrdiff_t* __restrict__ v_cache_strides,
const ptrdiff_t* __restrict__ k_strides,
const ptrdiff_t* __restrict__ v_strides, size_t output_size,
size_t num_kv_heads, size_t seq_len, size_t hidden_size) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;

for (size_t idx = tid; idx < output_size; idx += blockDim.x * gridDim.x) {
size_t offset = idx;
size_t d = offset % hidden_size;
offset /= hidden_size;
size_t s = offset % seq_len;
offset /= seq_len;
size_t h = offset % num_kv_heads;
size_t b = offset / num_kv_heads;

size_t cache_s = static_cast<size_t>(past_kv_lengths[b]) + s;
ptrdiff_t k_cache_offset = b * k_cache_strides[0] + h * k_cache_strides[1] +
cache_s * k_cache_strides[2] +
d * k_cache_strides[3];
ptrdiff_t v_cache_offset = b * v_cache_strides[0] + h * v_cache_strides[1] +
cache_s * v_cache_strides[2] +
d * v_cache_strides[3];
ptrdiff_t k_offset = b * k_strides[0] + h * k_strides[1] +
s * k_strides[2] + d * k_strides[3];
ptrdiff_t v_offset = b * v_strides[0] + h * v_strides[1] +
s * v_strides[2] + d * v_strides[3];

k_cache[k_cache_offset] = k[k_offset];
v_cache[v_cache_offset] = v[v_offset];
}
}

} // namespace infini::ops

#endif
105 changes: 105 additions & 0 deletions src/native/cuda/ops/kv_caching_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#ifndef INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_H_
#define INFINI_OPS_CUDA_KV_CACHING_INFINILM_KERNEL_H_

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

#include "base/kv_caching_infinilm.h"
#include "common/generic_utils.h"
#include "data_type.h"
#include "dispatcher.h"
#include "native/cuda/ops/kv_caching_infinilm/kernel.cuh"
#include "native/cuda/runtime_utils.h"

namespace infini::ops {

template <typename Backend>
class CudaKvCachingInfinilm : public KvCachingInfinilm {
public:
CudaKvCachingInfinilm(const Tensor k, const Tensor v,
const Tensor past_kv_lengths, Tensor k_cache,
Tensor v_cache)
: KvCachingInfinilm{k, v, past_kv_lengths, k_cache, v_cache} {
constexpr size_t ndim = 4;
size_t strides_size = ndim * sizeof(*d_k_cache_strides_);
const size_t metadata_size = 4 * strides_size;
std::vector<std::byte> metadata(metadata_size);

Backend::Malloc((void**)&d_metadata_, metadata_size);

size_t offset = 0;
d_k_cache_strides_ =
reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, k_cache_strides_.data(),
strides_size);
offset += strides_size;

d_v_cache_strides_ =
reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, v_cache_strides_.data(),
strides_size);
offset += strides_size;

d_k_strides_ = reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, k_strides_.data(), strides_size);
offset += strides_size;

d_v_strides_ = reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, v_strides_.data(), strides_size);

Backend::Memcpy(d_metadata_, metadata.data(), metadata_size,
Backend::MemcpyHostToDevice);
}

~CudaKvCachingInfinilm() { Backend::Free(d_metadata_); }

void operator()(const Tensor k, const Tensor v, const Tensor past_kv_lengths,
Tensor k_cache, Tensor v_cache) const override {
auto cuda_stream =
static_cast<typename Backend::Stream>(stream_ ? stream_ : 0);
int block_size = std::min(
RuntimeUtils<Backend::kDeviceType>::GetOptimalBlockSize(), 1024);
dim3 block(std::min(static_cast<Tensor::Size>(block_size), output_size_));
dim3 grid(utils::CeilDiv(output_size_, block.x));

using IndexTypes = List<DataType::kInt32, DataType::kInt64>;
DispatchFunc<AllFloatTypes, IndexTypes, List<128, 256, 512, 1024>>(
{static_cast<int64_t>(data_type_),
static_cast<int64_t>(past_kv_lengths_type_), block_size},
[&](auto list_tag) {
using T = TypeMapType<Backend::kDeviceType, ListGet<0>(list_tag)>;
using TIndex =
TypeMapType<Backend::kDeviceType, ListGet<1>(list_tag)>;
constexpr int kBlockSize = ListGet<2>(list_tag);

KvCachingInfinilmKernel<T, TIndex, kBlockSize>
<<<grid, block, 0, cuda_stream>>>(
reinterpret_cast<T*>(k_cache.data()),
reinterpret_cast<T*>(v_cache.data()),
reinterpret_cast<const T*>(k.data()),
reinterpret_cast<const T*>(v.data()),
reinterpret_cast<const TIndex*>(past_kv_lengths.data()),
d_k_cache_strides_, d_v_cache_strides_, d_k_strides_,
d_v_strides_, output_size_, num_kv_heads_, seq_len_,
hidden_size_);
},
"CudaKvCachingInfinilm::operator()");
}

private:
std::byte* d_metadata_{nullptr};

Tensor::Stride* d_k_cache_strides_{nullptr};

Tensor::Stride* d_v_cache_strides_{nullptr};

Tensor::Stride* d_k_strides_{nullptr};

Tensor::Stride* d_v_strides_{nullptr};
};

} // namespace infini::ops

#endif
Loading
Loading