Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/base/paged_caching_infinilm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#ifndef INFINI_OPS_BASE_PAGED_CACHING_INFINILM_H_
#define INFINI_OPS_BASE_PAGED_CACHING_INFINILM_H_

#include <cassert>
#include <cstddef>

#include "data_type.h"
#include "operator.h"
#include "tensor.h"

namespace infini::ops {

class PagedCachingInfinilm : public Operator<PagedCachingInfinilm> {
public:
PagedCachingInfinilm(const Tensor k, const Tensor v,
const Tensor slot_mapping, Tensor k_cache,
Tensor v_cache)
: dtype_{k.dtype()},
num_tokens_{slot_mapping.size(0)},
num_kv_heads_{k.size(1)},
head_size_{k.size(2)},
block_size_{k_cache.size(2)},
k_src_stride_{k.stride(0)},
v_src_stride_{v.stride(0)},
k_cache_block_stride_{k_cache.stride(0)},
v_cache_block_stride_{v_cache.stride(0)},
k_cache_head_stride_{k_cache.stride(1)},
v_cache_head_stride_{v_cache.stride(1)},
k_cache_slot_stride_{k_cache.stride(2)},
v_cache_slot_stride_{v_cache.stride(2)} {
assert(k.ndim() == 3 && v.ndim() == 3 &&
"`PagedCachingInfinilm` requires `k` and `v` to be 3D");
assert(k_cache.ndim() == 4 && v_cache.ndim() == 4 &&
"`PagedCachingInfinilm` requires 4D cache tensors");
assert(slot_mapping.ndim() == 1 &&
"`PagedCachingInfinilm` requires 1D slot mapping");
assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
dtype_ == DataType::kFloat32) &&
"`PagedCachingInfinilm` supports float16, bfloat16, and float32");
assert(v.dtype() == dtype_ && k_cache.dtype() == dtype_ &&
v_cache.dtype() == dtype_);
assert(slot_mapping.dtype() == DataType::kInt64 &&
"`PagedCachingInfinilm` requires int64 slot mapping");
assert(k.shape() == v.shape());
assert(k_cache.shape() == v_cache.shape());
assert(k_cache.size(1) == num_kv_heads_ && k_cache.size(3) == head_size_);
assert(k.stride(2) == 1 && v.stride(2) == 1);
assert(k_cache.stride(3) == 1 && v_cache.stride(3) == 1);
}

virtual void operator()(const Tensor k, const Tensor v,
const Tensor slot_mapping, Tensor k_cache,
Tensor v_cache) const = 0;

protected:
DataType dtype_;

std::size_t num_tokens_{0};

std::size_t num_kv_heads_{0};

std::size_t head_size_{0};

std::size_t block_size_{0};

Tensor::Stride k_src_stride_{0};

Tensor::Stride v_src_stride_{0};

Tensor::Stride k_cache_block_stride_{0};

Tensor::Stride v_cache_block_stride_{0};

Tensor::Stride k_cache_head_stride_{0};

Tensor::Stride v_cache_head_stride_{0};

Tensor::Stride k_cache_slot_stride_{0};

Tensor::Stride v_cache_slot_stride_{0};
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/iluvatar/ops/paged_caching_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_ILUVATAR_PAGED_CACHING_INFINILM_KERNEL_H_
#define INFINI_OPS_ILUVATAR_PAGED_CACHING_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/iluvatar/caster.cuh"
#include "native/cuda/iluvatar/runtime_.h"
#include "native/cuda/ops/paged_caching_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<PagedCachingInfinilm, Device::Type::kIluvatar>
: public CudaPagedCachingInfinilm<Runtime<Device::Type::kIluvatar>> {
public:
using CudaPagedCachingInfinilm<
Runtime<Device::Type::kIluvatar>>::CudaPagedCachingInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/metax/ops/paged_caching_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_METAX_PAGED_CACHING_INFINILM_KERNEL_H_
#define INFINI_OPS_METAX_PAGED_CACHING_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/metax/caster.cuh"
#include "native/cuda/metax/runtime_.h"
#include "native/cuda/ops/paged_caching_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<PagedCachingInfinilm, Device::Type::kMetax>
: public CudaPagedCachingInfinilm<Runtime<Device::Type::kMetax>> {
public:
using CudaPagedCachingInfinilm<
Runtime<Device::Type::kMetax>>::CudaPagedCachingInfinilm;
};

} // namespace infini::ops

#endif
30 changes: 30 additions & 0 deletions src/native/cuda/moore/ops/paged_caching_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef INFINI_OPS_MOORE_PAGED_CACHING_INFINILM_KERNEL_H_
#define INFINI_OPS_MOORE_PAGED_CACHING_INFINILM_KERNEL_H_

#include <utility>

// clang-format off
#include <musa_runtime.h>
// clang-format on

#include "native/cuda/moore/caster.cuh"
#include "native/cuda/moore/runtime_.h"
#include "native/cuda/ops/paged_caching_infinilm/kernel.h"

namespace infini::ops {

struct MoorePagedCachingInfinilmBackend : Runtime<Device::Type::kMoore> {
static constexpr int max_block_size = 1024;
};

template <>
class Operator<PagedCachingInfinilm, Device::Type::kMoore>
: public CudaPagedCachingInfinilm<MoorePagedCachingInfinilmBackend> {
public:
using CudaPagedCachingInfinilm<
MoorePagedCachingInfinilmBackend>::CudaPagedCachingInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/nvidia/ops/paged_caching_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_NVIDIA_PAGED_CACHING_INFINILM_KERNEL_H_
#define INFINI_OPS_NVIDIA_PAGED_CACHING_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/paged_caching_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<PagedCachingInfinilm, Device::Type::kNvidia>
: public CudaPagedCachingInfinilm<Runtime<Device::Type::kNvidia>> {
public:
using CudaPagedCachingInfinilm<
Runtime<Device::Type::kNvidia>>::CudaPagedCachingInfinilm;
};

} // namespace infini::ops

#endif
49 changes: 49 additions & 0 deletions src/native/cuda/ops/paged_caching_infinilm/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef INFINI_OPS_CUDA_PAGED_CACHING_INFINILM_KERNEL_CUH_
#define INFINI_OPS_CUDA_PAGED_CACHING_INFINILM_KERNEL_CUH_

#include <cstddef>
#include <cstdint>

namespace infini::ops {

template <typename Tdata, int kBlockSize>
__global__ void PagedCachingInfinilmKernel(
Tdata* __restrict__ k_cache, Tdata* __restrict__ v_cache,
const Tdata* __restrict__ k, const Tdata* __restrict__ v,
const int64_t* __restrict__ slot_mapping, std::size_t head_size,
std::size_t cache_block_size, std::ptrdiff_t k_src_stride,
std::ptrdiff_t v_src_stride, std::ptrdiff_t k_cache_block_stride,
std::ptrdiff_t v_cache_block_stride, std::ptrdiff_t k_cache_head_stride,
std::ptrdiff_t v_cache_head_stride, std::ptrdiff_t k_cache_slot_stride,
std::ptrdiff_t v_cache_slot_stride) {
auto head_idx = static_cast<std::size_t>(blockIdx.x);
auto token_idx = static_cast<std::size_t>(blockIdx.y);
int64_t slot = slot_mapping[token_idx];

if (slot < 0) {
return;
}

auto physical_block_idx = static_cast<std::size_t>(slot) / cache_block_size;
auto block_offset = static_cast<std::size_t>(slot) % cache_block_size;

const Tdata* k_src = k + token_idx * k_src_stride +
head_idx * static_cast<std::ptrdiff_t>(head_size);
const Tdata* v_src = v + token_idx * v_src_stride +
head_idx * static_cast<std::ptrdiff_t>(head_size);
Tdata* k_dst = k_cache + physical_block_idx * k_cache_block_stride +
head_idx * k_cache_head_stride +
block_offset * k_cache_slot_stride;
Tdata* v_dst = v_cache + physical_block_idx * v_cache_block_stride +
head_idx * v_cache_head_stride +
block_offset * v_cache_slot_stride;

for (std::size_t i = threadIdx.x; i < head_size; i += kBlockSize) {
k_dst[i] = k_src[i];
v_dst[i] = v_src[i];
}
}

} // namespace infini::ops

#endif
63 changes: 63 additions & 0 deletions src/native/cuda/ops/paged_caching_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#ifndef INFINI_OPS_CUDA_PAGED_CACHING_INFINILM_KERNEL_H_
#define INFINI_OPS_CUDA_PAGED_CACHING_INFINILM_KERNEL_H_

#include <algorithm>
#include <cassert>
#include <cstdint>

#include "base/paged_caching_infinilm.h"
#include "data_type.h"
#include "dispatcher.h"
#include "native/cuda/kernel_commons.cuh"
#include "native/cuda/ops/paged_caching_infinilm/kernel.cuh"
#include "native/cuda/runtime_utils.h"

namespace infini::ops {

template <typename Backend>
class CudaPagedCachingInfinilm : public PagedCachingInfinilm {
public:
using PagedCachingInfinilm::PagedCachingInfinilm;

void operator()(const Tensor k, const Tensor v, const Tensor slot_mapping,
Tensor k_cache, Tensor v_cache) const override {
assert(k.dtype() == dtype_ && v.dtype() == dtype_ &&
k_cache.dtype() == dtype_ && v_cache.dtype() == dtype_);
assert(slot_mapping.dtype() == DataType::kInt64);

auto cuda_stream =
static_cast<typename Backend::Stream>(stream_ ? stream_ : 0);
int block_size =
std::min(RuntimeUtils<Backend::kDeviceType>::GetOptimalBlockSize(),
BackendMaxBlockSize<Backend>::value);

dim3 grid(static_cast<unsigned>(num_kv_heads_),
static_cast<unsigned>(num_tokens_));

DispatchFunc<
ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
SupportedCudaBlockSizesType<BackendMaxBlockSize<Backend>::value>>(
{static_cast<int64_t>(dtype_), block_size},
[&](auto list_tag) {
using T = TypeMapType<Backend::kDeviceType, ListGet<0>(list_tag)>;
constexpr int kBlockSize = ListGet<1>(list_tag);

PagedCachingInfinilmKernel<T, kBlockSize>
<<<grid, kBlockSize, 0, cuda_stream>>>(
reinterpret_cast<T*>(k_cache.data()),
reinterpret_cast<T*>(v_cache.data()),
reinterpret_cast<const T*>(k.data()),
reinterpret_cast<const T*>(v.data()),
reinterpret_cast<const int64_t*>(slot_mapping.data()),
head_size_, block_size_, k_src_stride_, v_src_stride_,
k_cache_block_stride_, v_cache_block_stride_,
k_cache_head_stride_, v_cache_head_stride_,
k_cache_slot_stride_, v_cache_slot_stride_);
},
"CudaPagedCachingInfinilm::operator()");
}
};

} // namespace infini::ops

#endif
Loading
Loading