diff --git a/src/base/paged_attention_infinilm.h b/src/base/paged_attention_infinilm.h
new file mode 100644
index 000000000..627b5e4e7
--- /dev/null
+++ b/src/base/paged_attention_infinilm.h
@@ -0,0 +1,128 @@
+#ifndef INFINI_OPS_BASE_PAGED_ATTENTION_INFINILM_H_
+#define INFINI_OPS_BASE_PAGED_ATTENTION_INFINILM_H_
+
+#include
+#include
+#include
+
+#include "data_type.h"
+#include "operator.h"
+#include "tensor.h"
+
+namespace infini::ops {
+
+// Base operator for single-query attention over a paged KV cache.
+//
+// The constructor caches shapes and strides from its tensor arguments and
+// validates them with `assert`, so the checks disappear under `NDEBUG`.
+// Dimensions, as indexed below:
+//   q, out       : [num_seqs, num_heads, head_size]
+//   k/v_cache    : [block, num_kv_heads, block_size, head_size]
+//   block_tables : [num_seqs, max_num_blocks_per_seq]
+//   seq_lens     : [num_seqs]
+//   alibi_slopes : optional, float32, [num_heads]
+class PagedAttentionInfinilm : public Operator {
+ public:
+  PagedAttentionInfinilm(const Tensor q, const Tensor k_cache,
+                         const Tensor v_cache, const Tensor block_tables,
+                         const Tensor seq_lens,
+                         std::optional alibi_slopes, float scale,
+                         Tensor out)
+      : dtype_{q.dtype()},
+        index_dtype_{block_tables.dtype()},
+        scale_{scale},
+        num_seqs_{q.size(0)},
+        num_heads_{q.size(1)},
+        num_kv_heads_{k_cache.size(1)},
+        head_size_{q.size(2)},
+        block_size_{k_cache.size(2)},
+        max_num_blocks_per_seq_{block_tables.size(1)},
+        q_stride_{q.stride(0)},
+        q_head_stride_{q.stride(1)},
+        k_cache_block_stride_{k_cache.stride(0)},
+        k_cache_head_stride_{k_cache.stride(1)},
+        k_cache_slot_stride_{k_cache.stride(2)},
+        v_cache_block_stride_{v_cache.stride(0)},
+        v_cache_head_stride_{v_cache.stride(1)},
+        v_cache_slot_stride_{v_cache.stride(2)},
+        out_stride_{out.stride(0)},
+        out_head_stride_{out.stride(1)},
+        block_table_batch_stride_{block_tables.stride(0)},
+        seq_lens_stride_{seq_lens.stride(0)} {
+    // Rank checks. NOTE: the member initializers above have already indexed
+    // these dimensions, so a wrong rank is only diagnosed after the fact.
+    assert(q.ndim() == 3 && out.ndim() == 3);
+    assert(k_cache.ndim() == 4 && v_cache.ndim() == 4);
+    assert(block_tables.ndim() == 2 && seq_lens.ndim() == 1);
+    // Element and index dtypes.
+    assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16) &&
+           "`PagedAttentionInfinilm` supports float16 and bfloat16");
+    assert(out.dtype() == dtype_ && k_cache.dtype() == dtype_ &&
+           v_cache.dtype() == dtype_);
+    assert(IsIndexDtype(index_dtype_) && seq_lens.dtype() == index_dtype_);
+    // Shape agreement between the tensors.
+    assert(q.shape() == out.shape());
+    assert(k_cache.shape() == v_cache.shape());
+    assert(block_tables.size(0) == num_seqs_ && seq_lens.size(0) == num_seqs_);
+    assert(k_cache.size(1) == num_kv_heads_ &&
+           v_cache.size(1) == num_kv_heads_);
+    assert(k_cache.size(3) == head_size_ && v_cache.size(3) == head_size_);
+    assert((head_size_ == 64 || head_size_ == 128) &&
+           "`PagedAttentionInfinilm` supports head sizes 64 and 128");
+    // Reject an empty KV-head dimension first: the modulo below would
+    // otherwise divide by zero, which is undefined behaviour.
+    assert(num_kv_heads_ > 0 &&
+           "`PagedAttentionInfinilm` requires at least one KV head");
+    assert(num_heads_ % num_kv_heads_ == 0);
+    // The innermost dimension of every tensor must be contiguous.
+    assert(q.stride(2) == 1 && out.stride(2) == 1);
+    assert(k_cache.stride(3) == 1 && v_cache.stride(3) == 1);
+    assert(block_tables.stride(1) == 1 && seq_lens.stride(0) == 1);
+    // ALiBi slopes are optional; when given: float32, one per query head.
+    assert(!alibi_slopes.has_value() ||
+           (alibi_slopes->dtype() == DataType::kFloat32 &&
+            alibi_slopes->ndim() == 1 && alibi_slopes->size(0) == num_heads_ &&
+            alibi_slopes->stride(0) == 1));
+  }
+
+  // Runs the attention; implemented by each backend.
+  virtual void operator()(const Tensor q, const Tensor k_cache,
+                          const Tensor v_cache, const Tensor block_tables,
+                          const Tensor seq_lens,
+                          std::optional alibi_slopes, float scale,
+                          Tensor out) const = 0;
+
+ protected:
+  // Integer dtypes accepted for `block_tables` and `seq_lens`.
+  static bool IsIndexDtype(DataType dtype) {
+    return dtype == DataType::kInt32 || dtype == DataType::kInt64 ||
+           dtype == DataType::kUInt32;
+  }
+
+  // Cached at construction from the constructor arguments.
+  DataType dtype_;
+  DataType index_dtype_;
+  float scale_{1.0f};
+  std::size_t num_seqs_{0};
+  std::size_t num_heads_{0};
+  std::size_t num_kv_heads_{0};
+  std::size_t head_size_{0};
+  std::size_t block_size_{0};
+  std::size_t max_num_blocks_per_seq_{0};
+  Tensor::Stride q_stride_{0};
+  Tensor::Stride q_head_stride_{0};
+  Tensor::Stride k_cache_block_stride_{0};
+  Tensor::Stride k_cache_head_stride_{0};
+  Tensor::Stride k_cache_slot_stride_{0};
+  Tensor::Stride v_cache_block_stride_{0};
+  Tensor::Stride v_cache_head_stride_{0};
+  Tensor::Stride v_cache_slot_stride_{0};
+  Tensor::Stride out_stride_{0};
+  Tensor::Stride out_head_stride_{0};
+  Tensor::Stride block_table_batch_stride_{0};
+  Tensor::Stride seq_lens_stride_{0};
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/iluvatar/ops/paged_attention_infinilm/kernel.h
b/src/native/cuda/iluvatar/ops/paged_attention_infinilm/kernel.h
new file mode 100644
index 000000000..3284d3256
--- /dev/null
+++ b/src/native/cuda/iluvatar/ops/paged_attention_infinilm/kernel.h
@@ -0,0 +1,24 @@
+#ifndef INFINI_OPS_ILUVATAR_PAGED_ATTENTION_INFINILM_KERNEL_H_
+#define INFINI_OPS_ILUVATAR_PAGED_ATTENTION_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/iluvatar/caster.cuh"
+#include "native/cuda/iluvatar/runtime_.h"
+#include "native/cuda/ops/paged_attention_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// Iluvatar backend: reuses the shared `CudaPagedAttentionInfinilm`
+// implementation and inherits its constructors.
+template <>
+class Operator
+    : public CudaPagedAttentionInfinilm> {
+ public:
+  using CudaPagedAttentionInfinilm<
+      Runtime>::CudaPagedAttentionInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/metax/ops/paged_attention_infinilm/kernel.h b/src/native/cuda/metax/ops/paged_attention_infinilm/kernel.h
new file mode 100644
index 000000000..59287b74e
--- /dev/null
+++ b/src/native/cuda/metax/ops/paged_attention_infinilm/kernel.h
@@ -0,0 +1,24 @@
+#ifndef INFINI_OPS_METAX_PAGED_ATTENTION_INFINILM_KERNEL_H_
+#define INFINI_OPS_METAX_PAGED_ATTENTION_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/metax/caster.cuh"
+#include "native/cuda/metax/runtime_.h"
+#include "native/cuda/ops/paged_attention_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// MetaX backend: reuses the shared `CudaPagedAttentionInfinilm`
+// implementation and inherits its constructors.
+template <>
+class Operator
+    : public CudaPagedAttentionInfinilm> {
+ public:
+  using CudaPagedAttentionInfinilm<
+      Runtime>::CudaPagedAttentionInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/moore/ops/paged_attention_infinilm/kernel.h b/src/native/cuda/moore/ops/paged_attention_infinilm/kernel.h
new file mode 100644
index 000000000..7ee03230e
--- /dev/null
+++ b/src/native/cuda/moore/ops/paged_attention_infinilm/kernel.h
@@ -0,0 +1,24 @@
+#ifndef INFINI_OPS_MOORE_PAGED_ATTENTION_INFINILM_KERNEL_H_
+#define INFINI_OPS_MOORE_PAGED_ATTENTION_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/moore/caster.cuh"
+#include "native/cuda/moore/runtime_.h"
+#include "native/cuda/ops/paged_attention_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// Moore backend: reuses the shared `CudaPagedAttentionInfinilm`
+// implementation and inherits its constructors.
+template <>
+class Operator
+    : public CudaPagedAttentionInfinilm> {
+ public:
+  using CudaPagedAttentionInfinilm<
+      Runtime>::CudaPagedAttentionInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/nvidia/ops/paged_attention_infinilm/kernel.h b/src/native/cuda/nvidia/ops/paged_attention_infinilm/kernel.h
new file mode 100644
index 000000000..08286ffdf
--- /dev/null
+++ b/src/native/cuda/nvidia/ops/paged_attention_infinilm/kernel.h
@@ -0,0 +1,24 @@
+#ifndef INFINI_OPS_NVIDIA_PAGED_ATTENTION_INFINILM_KERNEL_H_
+#define INFINI_OPS_NVIDIA_PAGED_ATTENTION_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/nvidia/caster.cuh"
+#include "native/cuda/nvidia/runtime_.h"
+#include "native/cuda/ops/paged_attention_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// NVIDIA backend: reuses the shared `CudaPagedAttentionInfinilm`
+// implementation and inherits its constructors.
+template <>
+class Operator
+    : public CudaPagedAttentionInfinilm> {
+ public:
+  using CudaPagedAttentionInfinilm<
+      Runtime>::CudaPagedAttentionInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/ops/paged_attention_infinilm/kernel.cuh b/src/native/cuda/ops/paged_attention_infinilm/kernel.cuh
new file mode 100644
index 000000000..976a092a4
--- /dev/null
+++ b/src/native/cuda/ops/paged_attention_infinilm/kernel.cuh
@@ -0,0 +1,165 @@
+#ifndef INFINI_OPS_CUDA_PAGED_ATTENTION_INFINILM_KERNEL_CUH_
+#define INFINI_OPS_CUDA_PAGED_ATTENTION_INFINILM_KERNEL_CUH_
+
+#include
+#include
+#include
+#include
+
+namespace infini::ops {
+
+// Decode-time paged attention. Each thread block handles one
+// (sequence, query head) pair taken from `blockIdx.y` / `blockIdx.x`, and
+// thread `threadIdx.x` owns `kHeadSize / 32` consecutive head dimensions, so
+// the kernel expects 32 threads per block. Tokens are visited in order through
+// the block table while a running maximum `m` and normalizer `l` are kept in
+// base 2, which yields the softmax-weighted sum of V in a single pass.
+// A sequence with `seq_len <= 0` gets an all-zero output row, and tokens past
+// `max_num_blocks_per_seq * block_size` are not visited.
+template
+__global__ void PagedAttentionInfinilmDecodeWarpKernel(
+    TData* __restrict__ out, const TData* __restrict__ q,
+    const TData* __restrict__ k_cache, const TData* __restrict__ v_cache,
+    const TIndex* __restrict__ block_tables,
+    const TIndex* __restrict__ seq_lens, const float* __restrict__ alibi_slopes,
+    std::size_t num_heads, std::size_t num_kv_heads, float scale,
+    std::size_t max_num_blocks_per_seq, std::size_t block_size,
+    std::ptrdiff_t k_cache_block_stride, std::ptrdiff_t k_cache_head_stride,
+    std::ptrdiff_t k_cache_slot_stride, std::ptrdiff_t v_cache_block_stride,
+    std::ptrdiff_t v_cache_head_stride, std::ptrdiff_t v_cache_slot_stride,
+    std::ptrdiff_t q_stride, std::ptrdiff_t q_head_stride,
+    std::ptrdiff_t out_stride, std::ptrdiff_t out_head_stride,
+    std::ptrdiff_t block_table_batch_stride, std::ptrdiff_t seq_lens_stride) {
+  constexpr int kWarpSize = 32;
+  static_assert(kHeadSize == 64 || kHeadSize == 128,
+                "PagedAttentionInfinilm decode supports head sizes 64 and 128");
+  static_assert(kHeadSize % kWarpSize == 0,
+                "head size must be divisible by 32");
+
+  const int seq_idx = blockIdx.y;
+  const int head_idx = blockIdx.x;
+  const int lane = threadIdx.x;
+  constexpr int kDimsPerThread = kHeadSize / kWarpSize;
+  constexpr float kLog2e = 1.4426950408889634f;
+
+  // Scratch for the reduction and for broadcasting lane 0's results.
+  __shared__ float reduce_buf[kWarpSize];
+  __shared__ float state_buf[2];
+
+  const int seq_len = static_cast(seq_lens[seq_idx * seq_lens_stride]);
+  TData* out_ptr = out + seq_idx * out_stride + head_idx * out_head_stride;
+  // Empty sequence: every thread of the block takes this exit together.
+  if (seq_len <= 0) {
+#pragma unroll
+    for (int i = 0; i < kDimsPerThread; ++i) {
+      out_ptr[lane * kDimsPerThread + i] = static_cast(0.0f);
+    }
+    return;
+  }
+
+  const int queries_per_kv = static_cast(num_heads / num_kv_heads);
+  const int kv_head_idx = head_idx / queries_per_kv;
+  const float alibi_slope =
+      alibi_slopes == nullptr ? 0.0f : alibi_slopes[head_idx];
+  const float scale_log2 = scale * kLog2e;
+  const TIndex* block_table = block_tables + seq_idx * block_table_batch_stride;
+  const TData* q_ptr = q + seq_idx * q_stride + head_idx * q_head_stride;
+
+  // This thread's slice of the query, and its slice of the output sum.
+  float q_reg[kDimsPerThread];
+  float acc[kDimsPerThread];
+#pragma unroll
+  for (int i = 0; i < kDimsPerThread; ++i) {
+    const int dim = lane * kDimsPerThread + i;
+    q_reg[i] = static_cast(q_ptr[dim]);
+    acc[i] = 0.0f;
+  }
+
+  // Online-softmax state; only lane 0 updates it.
+  float m = -FLT_MAX;
+  float l = 0.0f;
+  const int page_block_size = static_cast(block_size);
+  int t_base = 0;
+  for (int logical_block = 0;
+       t_base < seq_len &&
+       logical_block < static_cast(max_num_blocks_per_seq);
+       ++logical_block, t_base += page_block_size) {
+    const int physical_block = static_cast(block_table[logical_block]);
+    const TData* k_base = k_cache + physical_block * k_cache_block_stride +
+                          kv_head_idx * k_cache_head_stride;
+    const TData* v_base = v_cache + physical_block * v_cache_block_stride +
+                          kv_head_idx * v_cache_head_stride;
+    const int token_end = min(page_block_size, seq_len - t_base);
+
+    for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) {
+      const int token_idx = t_base + token_in_block;
+      const TData* k_ptr = k_base + token_in_block * k_cache_slot_stride;
+      const TData* v_ptr = v_base + token_in_block * v_cache_slot_stride;
+
+      float qk = 0.0f;
+#pragma unroll
+      for (int i = 0; i < kDimsPerThread; ++i) {
+        const int dim = lane * kDimsPerThread + i;
+        qk += q_reg[i] * static_cast(k_ptr[dim]);
+      }
+
+      // Sum the partial dot products across the block via shared memory.
+      reduce_buf[lane] = qk;
+      __syncthreads();
+      for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
+        if (lane < offset) {
+          reduce_buf[lane] += reduce_buf[lane + offset];
+        }
+        __syncthreads();
+      }
+
+      float alpha = 1.0f;
+      float beta = 0.0f;
+      // Lane 0 folds the score into (m, l) and publishes the rescale factors.
+      if (lane == 0) {
+        float score = reduce_buf[0] * scale_log2;
+        if (alibi_slope != 0.0f) {
+          score +=
+              (alibi_slope * static_cast(token_idx - (seq_len - 1))) *
+              kLog2e;
+        }
+        const float m_new = fmaxf(m, score);
+        alpha = exp2f(m - m_new);
+        beta = exp2f(score - m_new);
+        l = l * alpha + beta;
+        m = m_new;
+        state_buf[0] = alpha;
+        state_buf[1] = beta;
+      }
+      __syncthreads();
+      alpha = state_buf[0];
+      beta = state_buf[1];
+
+#pragma unroll
+      for (int i = 0; i < kDimsPerThread; ++i) {
+        const int dim = lane * kDimsPerThread + i;
+        acc[i] = acc[i] * alpha + beta * static_cast(v_ptr[dim]);
+      }
+    }
+  }
+
+  // All lanes must have consumed alpha/beta of the last token before lane 0
+  // reuses `state_buf[0]` for the normalizer.
+  __syncthreads();
+  if (lane == 0) {
+    // The epsilon keeps the reciprocal finite when no token was visited.
+    state_buf[0] = 1.0f / (l + 1e-6f);
+  }
+  __syncthreads();
+  const float inv_l = state_buf[0];
+
+#pragma unroll
+  for (int i = 0; i < kDimsPerThread; ++i) {
+    const int dim = lane * kDimsPerThread + i;
+    out_ptr[dim] = static_cast(acc[i] * inv_l);
+  }
+}
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/ops/paged_attention_infinilm/kernel.h b/src/native/cuda/ops/paged_attention_infinilm/kernel.h
new file mode 100644
index 000000000..824b1be3d
--- /dev/null
+++ b/src/native/cuda/ops/paged_attention_infinilm/kernel.h
@@ -0,0 +1,82 @@
+#ifndef INFINI_OPS_CUDA_PAGED_ATTENTION_INFINILM_KERNEL_H_
+#define INFINI_OPS_CUDA_PAGED_ATTENTION_INFINILM_KERNEL_H_
+
+#include
+#include
+#include
+
+#include "base/paged_attention_infinilm.h"
+#include "data_type.h"
+#include "dispatcher.h"
+#include "native/cuda/kernel_commons.cuh"
+#include "native/cuda/ops/paged_attention_infinilm/kernel.cuh"
+#include "native/cuda/runtime_utils.h"
+
+namespace infini::ops {
+
+using PagedAttentionInfinilmIndexTypes =
+    List;
+
+// Launches `PagedAttentionInfinilmDecodeWarpKernel`, dispatching on the data
+// dtype, the index dtype and the head size cached by the base class.
+template
+class CudaPagedAttentionInfinilm : public PagedAttentionInfinilm {
+ public:
+  using PagedAttentionInfinilm::PagedAttentionInfinilm;
+
+  void operator()(const Tensor q, const Tensor k_cache, const Tensor v_cache,
+                  const Tensor block_tables, const Tensor seq_lens,
+                  std::optional alibi_slopes, float scale,
+                  Tensor out) const override {
+    auto cuda_stream =
+        static_cast(stream_ ? stream_ : 0);
+
+    // The call-time arguments must match what the constructor cached.
+    assert(out.dtype() == dtype_ && q.dtype() == dtype_);
+    assert(k_cache.dtype() == dtype_ && v_cache.dtype() == dtype_);
+    assert(block_tables.dtype() == index_dtype_ &&
+           seq_lens.dtype() == index_dtype_);
+    assert(scale == scale_);
+
+    assert((head_size_ == 64 || head_size_ == 128) &&
+           "PagedAttentionInfinilm supports head sizes 64 and 128");
+
+    // One thread block per (query head, sequence) pair.
+    dim3 grid(static_cast(num_heads_),
+              static_cast(num_seqs_));
+
+    DispatchFunc>(
+        {static_cast(dtype_), static_cast(index_dtype_),
+         static_cast(head_size_)},
+        [&](auto list_tag) {
+          using TData = TypeMapType(list_tag)>;
+          using TIndex =
+              TypeMapType(list_tag)>;
+          constexpr int kHeadSize = ListGet<2>(list_tag);
+
+          PagedAttentionInfinilmDecodeWarpKernel
+              <<>>(
+                  reinterpret_cast(out.data()),
+                  reinterpret_cast(q.data()),
+                  reinterpret_cast(k_cache.data()),
+                  reinterpret_cast(v_cache.data()),
+                  reinterpret_cast(block_tables.data()),
+                  reinterpret_cast(seq_lens.data()),
+                  alibi_slopes.has_value()
+                      ? reinterpret_cast(alibi_slopes->data())
+                      : nullptr,
+                  num_heads_, num_kv_heads_, scale, max_num_blocks_per_seq_,
+                  block_size_, k_cache_block_stride_, k_cache_head_stride_,
+                  k_cache_slot_stride_, v_cache_block_stride_,
+                  v_cache_head_stride_, v_cache_slot_stride_, q_stride_,
+                  q_head_stride_, out_stride_, out_head_stride_,
+                  block_table_batch_stride_, seq_lens_stride_);
+        },
+        "CudaPagedAttentionInfinilm::operator()");
+  }
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/tests/test_paged_attention_infinilm.py b/tests/test_paged_attention_infinilm.py
new file mode 100644
index 000000000..e4a76c9ec
--- /dev/null
+++ b/tests/test_paged_attention_infinilm.py
@@ -0,0 +1,192 @@
+import math
+
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, get_stream
+
+
+def get_alibi_slopes(n):
+    """Return `n` per-head ALiBi slopes as a list of floats."""
+    closest_power_of_2 = 2 ** math.floor(math.log2(n))
+    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
+    powers = [base**i for i in range(1, closest_power_of_2 + 1)]
+    if n > closest_power_of_2:
+        # The remaining heads take every other slope of the next power of two.
+        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
+        extra = [
+            extra_base**i for i in range(1, 2 * (n - closest_power_of_2) + 1, 2)
+        ]
+        powers += extra
+    return powers[:n]
+
+
+def ref_paged_attention_infinilm(
+    q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale
+):
+    """Reference paged attention computed per sequence with PyTorch."""
+    output = torch.empty_like(q)
+    num_heads = q.shape[1]
+    num_kv_heads = k_cache.shape[1]
+    queries_per_kv = num_heads // num_kv_heads
+    block_size = k_cache.shape[2]
+
+    for seq_id in range(q.shape[0]):
+        seq_len = seq_lens[seq_id].item()
+        table = block_tables[seq_id]
+        keys = []
+        values = []
+        for token_idx in range(seq_len):
+            block_id = table[token_idx // block_size].item()
+            block_offset = token_idx % block_size
+            keys.append(k_cache[block_id, :, block_offset, :])
+            values.append(v_cache[block_id, :, block_offset, :])
+
+        k = torch.stack(keys, dim=0)
+        v = torch.stack(values, dim=0)
+        if queries_per_kv > 1:
+            k = torch.repeat_interleave(k, queries_per_kv, dim=1)
+            v = torch.repeat_interleave(v, queries_per_kv, dim=1)
+
+        scores = torch.einsum("hd,khd->hk", q[seq_id], k).float() * scale
+        if alibi_slopes is not None:
+            pos = torch.arange(seq_len, device=q.device, dtype=torch.float32)
+            scores = scores + alibi_slopes.view(-1, 1) * (pos - seq_len + 1)
+
+        weights = torch.softmax(scores, dim=-1).to(q.dtype)
+        output[seq_id] = torch.einsum("hk,khd->hd", weights, v)
+
+    return output
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    (
+        "num_seqs",
+        "num_heads",
+        "num_kv_heads",
+        "head_size",
+        "block_size",
+        "max_seq_len",
+        "use_alibi",
+    ),
+    (
+        (1, 1, 1, 128, 16, 1024, False),
+        (4, 40, 40, 128, 16, 1024, False),
+        (6, 40, 40, 128, 16, 1024, False),
+        (3, 8, 8, 128, 16, 1024, False),
+        (3, 8, 8, 64, 16, 1024, False),
+        (8, 64, 8, 128, 16, 2048, False),
+        # ALiBi, with power-of-two and non-power-of-two head counts.
+        (3, 8, 8, 128, 16, 1024, True),
+        (4, 40, 40, 128, 16, 1024, True),
+    ),
+)
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float16, 1e-2, 1e-3),
+        (torch.bfloat16, 5e-2, 5e-3),
+    ),
+)
+def test_paged_attention_infinilm(
+    num_seqs,
+    num_heads,
+    num_kv_heads,
+    head_size,
+    block_size,
+    max_seq_len,
+    use_alibi,
+    implementation_index,
+    dtype,
+    device,
+    rtol,
+    atol,
+):
+    """Build random inputs for the operator and its PyTorch reference."""
+    scale = head_size**-0.5
+    max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+    num_blocks = num_seqs * max_blocks_per_seq
+
+    q = torch.randn((num_seqs, num_heads, head_size), dtype=dtype, device=device)
+    out = torch.empty_like(q)
+    k_cache = torch.randn(
+        (num_blocks, num_kv_heads, block_size, head_size), dtype=dtype, device=device
+    )
+    v_cache = torch.randn_like(k_cache)
+    seq_lens = torch.randint(
+        1, max_seq_len, (num_seqs,), dtype=torch.int64, device=device
+    )
+    block_tables = torch.arange(num_blocks, dtype=torch.int64, device=device).view(
+        num_seqs, max_blocks_per_seq
+    )
+    alibi_slopes = None
+    if use_alibi:
+        alibi_slopes = torch.tensor(
+            get_alibi_slopes(num_heads), dtype=torch.float32, device=device
+        )
+
+    return Payload(
+        lambda *args, **kwargs: _paged_attention_infinilm(
+            *args, **kwargs, implementation_index=implementation_index
+        ),
+        _torch_paged_attention_infinilm,
+        (q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes),
+        {"scale": scale, "out": out},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _paged_attention_infinilm(
+    q,
+    k_cache,
+    v_cache,
+    block_tables,
+    seq_lens,
+    alibi_slopes,
+    *,
+    scale,
+    out=None,
+    implementation_index=0,
+):
+    """Call the operator under test and return `out`."""
+    infini.ops.paged_attention_infinilm(
+        q,
+        k_cache,
+        v_cache,
+        block_tables,
+        seq_lens,
+        alibi_slopes,
+        scale,
+        out,
+        implementation_index=implementation_index,
+        stream=get_stream(q.device),
+    )
+
+    return out
+
+
+def _torch_paged_attention_infinilm(
+    q,
+    k_cache,
+    v_cache,
+    block_tables,
+    seq_lens,
+    alibi_slopes,
+    *,
+    scale,
+    out=None,
+):
+    """Run the reference and write the result into `out` when given."""
+    result = ref_paged_attention_infinilm(
+        q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale
+    )
+
+    if out is not None:
+        out.copy_(result)
+    else:
+        out = result
+
+    return out