diff --git a/src/base/paged_attention_prefill_infinilm.h b/src/base/paged_attention_prefill_infinilm.h new file mode 100644 index 000000000..07091b6e4 --- /dev/null +++ b/src/base/paged_attention_prefill_infinilm.h @@ -0,0 +1,129 @@ +#ifndef INFINI_OPS_BASE_PAGED_ATTENTION_PREFILL_INFINILM_H_ +#define INFINI_OPS_BASE_PAGED_ATTENTION_PREFILL_INFINILM_H_ + +#include +#include +#include + +#include "data_type.h" +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +/* Backend-independent base of the paged-attention prefill operator. The constructor snapshots dtypes, sizes and strides from its tensor arguments and asserts the supported layout; operator() is pure virtual and supplied by a backend. */ class PagedAttentionPrefillInfinilm + : public Operator { + public: + PagedAttentionPrefillInfinilm(const Tensor q, const Tensor k_cache, + const Tensor v_cache, const Tensor block_tables, + const Tensor seq_lens, + const Tensor cum_seq_lens_q, + std::optional alibi_slopes, float scale, + Tensor out) + : dtype_{q.dtype()}, + index_dtype_{block_tables.dtype()}, + scale_{scale}, + num_seqs_{seq_lens.size(0)}, + total_q_tokens_{q.size(0)}, + num_heads_{q.size(1)}, + num_kv_heads_{k_cache.size(1)}, + head_size_{q.size(2)}, + block_size_{k_cache.size(2)}, + max_num_blocks_per_seq_{block_tables.size(1)}, + q_stride_{q.stride(0)}, + q_head_stride_{q.stride(1)}, + k_cache_block_stride_{k_cache.stride(0)}, + k_cache_head_stride_{k_cache.stride(1)}, + k_cache_slot_stride_{k_cache.stride(2)}, + v_cache_block_stride_{v_cache.stride(0)}, + v_cache_head_stride_{v_cache.stride(1)}, + v_cache_slot_stride_{v_cache.stride(2)}, + out_stride_{out.stride(0)}, + out_head_stride_{out.stride(1)}, + block_table_batch_stride_{block_tables.stride(0)} { /* NOTE(review): the member initializers above already read size(1), size(2) and stride(...) of the tensors; the rank asserts below only run afterwards, so they cannot guard those reads. */ + assert(q.ndim() == 3 && out.ndim() == 3); + assert(k_cache.ndim() == 4 && v_cache.ndim() == 4); + assert(block_tables.ndim() == 2 && seq_lens.ndim() == 1 && + cum_seq_lens_q.ndim() == 1); + assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16) && + "`PagedAttentionPrefillInfinilm` supports float16 and bfloat16"); + assert(out.dtype() == dtype_ && k_cache.dtype() == dtype_ && + v_cache.dtype() == dtype_); + assert(IsIndexDtype(index_dtype_) && 
seq_lens.dtype() == index_dtype_ && + cum_seq_lens_q.dtype() == index_dtype_); /* cum_seq_lens_q carries one more entry than there are sequences */ + assert(cum_seq_lens_q.size(0) == num_seqs_ + 1); + assert(q.shape() == out.shape()); + assert(k_cache.shape() == v_cache.shape()); + assert(block_tables.size(0) == num_seqs_); + assert(k_cache.size(1) == num_kv_heads_ && + v_cache.size(1) == num_kv_heads_); + assert(k_cache.size(3) == head_size_ && v_cache.size(3) == head_size_); + assert((head_size_ == 64 || head_size_ == 128) && + "`PagedAttentionPrefillInfinilm` supports head sizes 64 and 128"); + assert(num_heads_ % num_kv_heads_ == 0); /* the innermost dimension of q, out and both caches must have unit stride */ + assert(q.stride(2) == 1 && out.stride(2) == 1); + assert(k_cache.stride(3) == 1 && v_cache.stride(3) == 1); /* alibi_slopes is optional; when present it must be a unit-stride float32 vector with num_heads_ entries */ + assert(!alibi_slopes.has_value() || + (alibi_slopes->dtype() == DataType::kFloat32 && + alibi_slopes->ndim() == 1 && alibi_slopes->size(0) == num_heads_ && + alibi_slopes->stride(0) == 1)); + } + + /* Runs the operator; pure virtual, takes the same argument list as the constructor. */ virtual void operator()(const Tensor q, const Tensor k_cache, + const Tensor v_cache, const Tensor block_tables, + const Tensor seq_lens, const Tensor cum_seq_lens_q, + std::optional alibi_slopes, float scale, + Tensor out) const = 0; + + protected: + /* True when dtype is kInt32, kInt64 or kUInt32; the constructor requires this of block_tables, seq_lens and cum_seq_lens_q. */ static bool IsIndexDtype(DataType dtype) { + return dtype == DataType::kInt32 || dtype == DataType::kInt64 || + dtype == DataType::kUInt32; + } + + /* Everything below is captured once in the constructor from the tensors passed to it. */ DataType dtype_; + + DataType index_dtype_; + + float scale_{1.0f}; + + std::size_t num_seqs_{0}; + + std::size_t total_q_tokens_{0}; + + std::size_t num_heads_{0}; + + std::size_t num_kv_heads_{0}; + + std::size_t head_size_{0}; + + std::size_t block_size_{0}; + + std::size_t max_num_blocks_per_seq_{0}; + + Tensor::Stride q_stride_{0}; + + Tensor::Stride q_head_stride_{0}; + + Tensor::Stride k_cache_block_stride_{0}; + + Tensor::Stride k_cache_head_stride_{0}; + + Tensor::Stride k_cache_slot_stride_{0}; + + Tensor::Stride v_cache_block_stride_{0}; + + Tensor::Stride v_cache_head_stride_{0}; + + Tensor::Stride v_cache_slot_stride_{0}; + + Tensor::Stride out_stride_{0}; + + Tensor::Stride 
out_head_stride_{0}; + + Tensor::Stride block_table_batch_stride_{0}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/iluvatar/ops/paged_attention_prefill_infinilm/kernel.h b/src/native/cuda/iluvatar/ops/paged_attention_prefill_infinilm/kernel.h new file mode 100644 index 000000000..d2ad9806e --- /dev/null +++ b/src/native/cuda/iluvatar/ops/paged_attention_prefill_infinilm/kernel.h @@ -0,0 +1,23 @@ +#ifndef INFINI_OPS_ILUVATAR_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_H_ +#define INFINI_OPS_ILUVATAR_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_H_ + +#include + +#include "native/cuda/iluvatar/caster.cuh" +#include "native/cuda/iluvatar/runtime_.h" +#include "native/cuda/ops/paged_attention_prefill_infinilm/kernel.h" + +namespace infini::ops { + +/* Operator specialization declared by the iluvatar header: derives from the shared CudaPagedAttentionPrefillInfinilm and inherits its constructors unchanged. */ template <> +class Operator + : public CudaPagedAttentionPrefillInfinilm< + Runtime> { + public: + using CudaPagedAttentionPrefillInfinilm< + Runtime>::CudaPagedAttentionPrefillInfinilm; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/metax/ops/paged_attention_prefill_infinilm/kernel.h b/src/native/cuda/metax/ops/paged_attention_prefill_infinilm/kernel.h new file mode 100644 index 000000000..dedd3c053 --- /dev/null +++ b/src/native/cuda/metax/ops/paged_attention_prefill_infinilm/kernel.h @@ -0,0 +1,22 @@ +#ifndef INFINI_OPS_METAX_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_H_ +#define INFINI_OPS_METAX_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_H_ + +#include + +#include "native/cuda/metax/caster.cuh" +#include "native/cuda/metax/runtime_.h" +#include "native/cuda/ops/paged_attention_prefill_infinilm/kernel.h" + +namespace infini::ops { + +/* Operator specialization declared by the metax header: derives from the shared CudaPagedAttentionPrefillInfinilm and inherits its constructors unchanged. */ template <> +class Operator + : public CudaPagedAttentionPrefillInfinilm> { + public: + using CudaPagedAttentionPrefillInfinilm< + Runtime>::CudaPagedAttentionPrefillInfinilm; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/moore/ops/paged_attention_prefill_infinilm/kernel.h 
b/src/native/cuda/moore/ops/paged_attention_prefill_infinilm/kernel.h new file mode 100644 index 000000000..71c06b1f2 --- /dev/null +++ b/src/native/cuda/moore/ops/paged_attention_prefill_infinilm/kernel.h @@ -0,0 +1,22 @@ +#ifndef INFINI_OPS_MOORE_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_H_ +#define INFINI_OPS_MOORE_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_H_ + +#include + +#include "native/cuda/moore/caster.cuh" +#include "native/cuda/moore/runtime_.h" +#include "native/cuda/ops/paged_attention_prefill_infinilm/kernel.h" + +namespace infini::ops { + +/* Operator specialization declared by the moore header: derives from the shared CudaPagedAttentionPrefillInfinilm and inherits its constructors unchanged. */ template <> +class Operator + : public CudaPagedAttentionPrefillInfinilm> { + public: + using CudaPagedAttentionPrefillInfinilm< + Runtime>::CudaPagedAttentionPrefillInfinilm; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/nvidia/ops/paged_attention_prefill_infinilm/kernel.h b/src/native/cuda/nvidia/ops/paged_attention_prefill_infinilm/kernel.h new file mode 100644 index 000000000..809590e6b --- /dev/null +++ b/src/native/cuda/nvidia/ops/paged_attention_prefill_infinilm/kernel.h @@ -0,0 +1,22 @@ +#ifndef INFINI_OPS_NVIDIA_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_H_ +#define INFINI_OPS_NVIDIA_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_H_ + +#include + +#include "native/cuda/nvidia/caster.cuh" +#include "native/cuda/nvidia/runtime_.h" +#include "native/cuda/ops/paged_attention_prefill_infinilm/kernel.h" + +namespace infini::ops { + +/* Operator specialization declared by the nvidia header: derives from the shared CudaPagedAttentionPrefillInfinilm and inherits its constructors unchanged. */ template <> +class Operator + : public CudaPagedAttentionPrefillInfinilm> { + public: + using CudaPagedAttentionPrefillInfinilm< + Runtime>::CudaPagedAttentionPrefillInfinilm; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/ops/paged_attention_prefill_infinilm/kernel.cuh b/src/native/cuda/ops/paged_attention_prefill_infinilm/kernel.cuh new file mode 100644 index 000000000..43a1d7dca --- /dev/null +++ b/src/native/cuda/ops/paged_attention_prefill_infinilm/kernel.cuh @@ -0,0 +1,186 @@ +#ifndef 
INFINI_OPS_CUDA_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_CUH_ +#define INFINI_OPS_CUDA_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_CUH_ + +#include +#include +#include +#include + +namespace infini::ops { + +/* Binary-searches cum_seq_lens_q for the sequence whose half-open range [cum[i], cum[i + 1]) contains token_idx; entries 0 through num_seqs may be read. Returns 0 when no range matches, so the caller has to re-validate the result. Presumably cum_seq_lens_q is non-decreasing - TODO confirm against the caller. */ template +__device__ __forceinline__ std::size_t PagedPrefillFindSeqId( + std::size_t token_idx, const TIndex* cum_seq_lens_q, std::size_t num_seqs) { + std::size_t low = 0; + std::size_t high = num_seqs; + + while (low < high) { + std::size_t mid = (low + high) >> 1; + std::size_t begin = static_cast(cum_seq_lens_q[mid]); + std::size_t end = static_cast(cum_seq_lens_q[mid + 1]); + + if (token_idx >= begin && token_idx < end) { + return mid; + } + + if (token_idx < begin) { + high = mid; + } else { + low = mid + 1; + } + } + + /* not found: 0 is a fallback value, not a valid match */ return 0; +} + +/* Attention for one (query token, head) pair per thread block: blockIdx.x selects the token and blockIdx.y the head. Keys and values are read page by page through block_tables and combined with a streaming softmax. NOTE(review): the shared reduce_buf and the lane-based indexing assume blockDim.x equals kWarpSize; the launch configuration is not visible here - confirm. */ template +__global__ void PagedAttentionPrefillInfinilmKernel( + TData* __restrict__ out, const TData* __restrict__ q, + const TData* __restrict__ k_cache, const TData* __restrict__ v_cache, + const TIndex* __restrict__ block_tables, + const TIndex* __restrict__ seq_lens, + const TIndex* __restrict__ cum_seq_lens_q, + const float* __restrict__ alibi_slopes, std::size_t num_heads, + std::size_t num_kv_heads, float scale, std::size_t max_num_blocks_per_seq, + std::size_t block_size, std::ptrdiff_t k_cache_block_stride, + std::ptrdiff_t k_cache_head_stride, std::ptrdiff_t k_cache_slot_stride, + std::ptrdiff_t v_cache_block_stride, std::ptrdiff_t v_cache_head_stride, + std::ptrdiff_t v_cache_slot_stride, std::ptrdiff_t q_stride, + std::ptrdiff_t q_head_stride, std::ptrdiff_t out_stride, + std::ptrdiff_t out_head_stride, std::ptrdiff_t block_table_batch_stride, + std::size_t num_seqs) { + constexpr int kWarpSize = 32; + static_assert(kHeadSize == 64 || kHeadSize == 128, + "PagedAttentionPrefillInfinilm supports head sizes 64 and 128"); + static_assert(kHeadSize % kWarpSize == 0, + "head size must be divisible by 32"); + + const std::size_t global_token_idx = blockIdx.x; + const std::size_t head_idx = blockIdx.y; + const int lane = 
threadIdx.x; + /* each lane owns kDimsPerThread consecutive head dimensions (dim = lane * kDimsPerThread + i below) */ constexpr int kDimsPerThread = kHeadSize / kWarpSize; + /* log2(e): scores are kept in base 2 so that exp2f can be used further down */ constexpr float kLog2e = 1.4426950408889634f; + + /* reduce_buf holds per-lane partial dot products; state_buf broadcasts two scalars computed by lane 0 */ __shared__ float reduce_buf[kWarpSize]; + __shared__ float state_buf[2]; + + const std::size_t seq_idx = + PagedPrefillFindSeqId(global_token_idx, cum_seq_lens_q, num_seqs); + const std::size_t q_begin = static_cast(cum_seq_lens_q[seq_idx]); + const std::size_t q_end = + static_cast(cum_seq_lens_q[seq_idx + 1]); + const int q_len = static_cast(q_end - q_begin); + const int q_token_local = static_cast(global_token_idx - q_begin); + /* also rejects the fallback index 0 returned when the token lies outside every range; the test depends only on block-wide values, so the whole block returns together before any barrier */ if (q_token_local < 0 || q_token_local >= q_len) { + return; + } + + const int total_kv_len = static_cast(seq_lens[seq_idx]); + const int history_len = total_kv_len - q_len; + /* causal limit: the history plus the new tokens up to and including this query */ const int allowed_k_len = history_len + q_token_local + 1; + if (allowed_k_len <= 0) { + return; + } + + /* several query heads share one key/value head */ const int queries_per_kv = static_cast(num_heads / num_kv_heads); + const int kv_head_idx = static_cast(head_idx) / queries_per_kv; + const float alibi_slope = + alibi_slopes == nullptr ? 
0.0f : alibi_slopes[head_idx]; + const float scale_log2 = scale * kLog2e; + const TIndex* block_table = block_tables + seq_idx * block_table_batch_stride; + + const TData* q_ptr = + q + global_token_idx * q_stride + head_idx * q_head_stride; + TData* out_ptr = + out + global_token_idx * out_stride + head_idx * out_head_stride; + + float q_reg[kDimsPerThread]; + float acc[kDimsPerThread]; +#pragma unroll + for (int i = 0; i < kDimsPerThread; ++i) { + const int dim = lane * kDimsPerThread + i; + q_reg[i] = static_cast(q_ptr[dim]); + acc[i] = 0.0f; + } + + /* m: running maximum score (base-2 domain); l: running softmax denominator. Only lane 0 updates them. */ float m = -FLT_MAX; + float l = 0.0f; + const int page_block_size = static_cast(block_size); + int t_base = 0; + /* walks the block table page by page; stops at allowed_k_len or when max_num_blocks_per_seq entries are used up, whichever comes first */ for (int logical_block = 0; + t_base < allowed_k_len && + logical_block < static_cast(max_num_blocks_per_seq); + ++logical_block, t_base += page_block_size) { + const int physical_block = static_cast(block_table[logical_block]); + const TData* k_base = k_cache + physical_block * k_cache_block_stride + + kv_head_idx * k_cache_head_stride; + const TData* v_base = v_cache + physical_block * v_cache_block_stride + + kv_head_idx * v_cache_head_stride; + const int token_end = min(page_block_size, allowed_k_len - t_base); + + for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) { + const int token_idx = t_base + token_in_block; + const TData* k_ptr = k_base + token_in_block * k_cache_slot_stride; + const TData* v_ptr = v_base + token_in_block * v_cache_slot_stride; + + float qk = 0.0f; +#pragma unroll + for (int i = 0; i < kDimsPerThread; ++i) { + const int dim = lane * kDimsPerThread + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + + /* shared-memory tree reduction of the per-lane partial dot products into reduce_buf[0] */ reduce_buf[lane] = qk; + __syncthreads(); + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + if (lane < offset) { + reduce_buf[lane] += reduce_buf[lane + offset]; + } + __syncthreads(); + } + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = reduce_buf[0] * scale_log2; + if (alibi_slope != 0.0f) { + score += 
(alibi_slope * + static_cast(token_idx - (allowed_k_len - 1))) * + kLog2e; /* bias scales with (key position - query position), which is never positive here */ + } + const float m_new = fmaxf(m, score); + /* alpha rescales what has been accumulated so far, beta is the weight of the current key */ alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + state_buf[0] = alpha; + state_buf[1] = beta; + } + __syncthreads(); + alpha = state_buf[0]; + beta = state_buf[1]; + +#pragma unroll + for (int i = 0; i < kDimsPerThread; ++i) { + const int dim = lane * kDimsPerThread + i; + acc[i] = acc[i] * alpha + beta * static_cast(v_ptr[dim]); + } + } + } + + /* NOTE(review): no barrier separates the last loop iteration's reads of state_buf[0] from this write by lane 0; this looks like a possible shared-memory race if the lanes do not run in lockstep - confirm. The 1e-6f term keeps the division finite. */ if (lane == 0) { + state_buf[0] = 1.0f / (l + 1e-6f); + } + __syncthreads(); + const float inv_l = state_buf[0]; + +#pragma unroll + for (int i = 0; i < kDimsPerThread; ++i) { + const int dim = lane * kDimsPerThread + i; + out_ptr[dim] = static_cast(acc[i] * inv_l); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/ops/paged_attention_prefill_infinilm/kernel.h b/src/native/cuda/ops/paged_attention_prefill_infinilm/kernel.h new file mode 100644 index 000000000..6bec289ec --- /dev/null +++ b/src/native/cuda/ops/paged_attention_prefill_infinilm/kernel.h @@ -0,0 +1,81 @@ +#ifndef INFINI_OPS_CUDA_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_H_ +#define INFINI_OPS_CUDA_PAGED_ATTENTION_PREFILL_INFINILM_KERNEL_H_ + +#include +#include +#include + +#include "base/paged_attention_prefill_infinilm.h" +#include "data_type.h" +#include "dispatcher.h" +#include "native/cuda/kernel_commons.cuh" +#include "native/cuda/ops/paged_attention_prefill_infinilm/kernel.cuh" +#include "native/cuda/runtime_utils.h" + +namespace infini::ops { + +using PagedAttentionPrefillInfinilmIndexTypes = + List; + +template +class CudaPagedAttentionPrefillInfinilm : public PagedAttentionPrefillInfinilm { + public: + using PagedAttentionPrefillInfinilm::PagedAttentionPrefillInfinilm; + + void operator()(const Tensor q, const Tensor k_cache, const Tensor v_cache, + const Tensor block_tables, const Tensor seq_lens, + const Tensor cum_seq_lens_q, + std::optional alibi_slopes, float 
scale, + Tensor out) const override { /* Launches PagedAttentionPrefillInfinilmKernel on a grid of total_q_tokens_ by num_heads_ blocks. Sizes and strides are the members captured at construction; the call-time tensors only supply data pointers and are checked against the stored dtypes. */ + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + + assert(out.dtype() == dtype_ && q.dtype() == dtype_); + assert(k_cache.dtype() == dtype_ && v_cache.dtype() == dtype_); + assert(block_tables.dtype() == index_dtype_ && + seq_lens.dtype() == index_dtype_ && + cum_seq_lens_q.dtype() == index_dtype_); + /* exact float comparison: the caller has to pass the very scale given at construction */ assert(scale == scale_); + + assert((head_size_ == 64 || head_size_ == 128) && + "PagedAttentionPrefillInfinilm supports head sizes 64 and 128"); + + /* NOTE(review): the grid uses the token and head counts recorded at construction, not the shape of the q passed to this call - confirm callers never reuse the operator with a different shape */ dim3 grid(static_cast(total_q_tokens_), + static_cast(num_heads_)); + + /* picks the kernel instantiation from the triple (data dtype, index dtype, head size) */ DispatchFunc>( + {static_cast(dtype_), static_cast(index_dtype_), + static_cast(head_size_)}, + [&](auto list_tag) { + using TData = TypeMapType(list_tag)>; + using TIndex = + TypeMapType(list_tag)>; + constexpr int kHeadSize = ListGet<2>(list_tag); + + PagedAttentionPrefillInfinilmKernel + <<>>( + reinterpret_cast(out.data()), + reinterpret_cast(q.data()), + reinterpret_cast(k_cache.data()), + reinterpret_cast(v_cache.data()), + reinterpret_cast(block_tables.data()), + reinterpret_cast(seq_lens.data()), + reinterpret_cast(cum_seq_lens_q.data()), + alibi_slopes.has_value() + ? 
reinterpret_cast(alibi_slopes->data()) + : nullptr, /* a null slope pointer makes the kernel skip the ALiBi bias */ + num_heads_, num_kv_heads_, scale, max_num_blocks_per_seq_, + block_size_, k_cache_block_stride_, k_cache_head_stride_, + k_cache_slot_stride_, v_cache_block_stride_, + v_cache_head_stride_, v_cache_slot_stride_, q_stride_, + q_head_stride_, out_stride_, out_head_stride_, + block_table_batch_stride_, num_seqs_); + }, + "CudaPagedAttentionPrefillInfinilm::operator()"); + } +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_paged_attention_prefill_infinilm.py b/tests/test_paged_attention_prefill_infinilm.py new file mode 100644 index 000000000..5fc687925 --- /dev/null +++ b/tests/test_paged_attention_prefill_infinilm.py @@ -0,0 +1,248 @@ +import infini.ops +import pytest +import torch + +from tests.utils import get_stream + + +class SimpleCacheManager: + def __init__(self, num_blocks, block_size): + self.block_size = block_size + self.free_blocks = list(range(num_blocks)) + self.request_to_blocks = {} + self.request_to_len = {} + + def allocate_slots(self, request_id, num_new_tokens): + if request_id not in self.request_to_len: + self.request_to_len[request_id] = 0 + self.request_to_blocks[request_id] = [] + + start_pos = self.request_to_len[request_id] + new_total_len = start_pos + num_new_tokens + needed_blocks = (new_total_len + self.block_size - 1) // self.block_size + added_blocks = needed_blocks - len(self.request_to_blocks[request_id]) + + for _ in range(added_blocks): + self.request_to_blocks[request_id].append(self.free_blocks.pop(0)) + + self.request_to_len[request_id] = new_total_len + return self.request_to_blocks[request_id], new_total_len + + +def ref_paged_attention_prefill_infinilm( + q, k_cache, v_cache, block_tables, seq_lens, cum_seq_lens_q, scale +): + block_size = k_cache.shape[2] + outputs = torch.zeros_like(q) + num_seqs = cum_seq_lens_q.numel() - 1 + + for seq_id in range(num_seqs): + q_begin = cum_seq_lens_q[seq_id].item() + q_end = cum_seq_lens_q[seq_id + 1].item() + 
num_new = q_end - q_begin + total_len = seq_lens[seq_id].item() + history_len = total_len - num_new + + table = block_tables[seq_id] + keys = [] + values = [] + for pos in range(total_len): + block_id = table[pos // block_size].item() + block_offset = pos % block_size + keys.append(k_cache[block_id, :, block_offset, :]) + values.append(v_cache[block_id, :, block_offset, :]) + + k = torch.stack(keys, dim=0) + v = torch.stack(values, dim=0) + q_seq = q[q_begin:q_end] + + scores = torch.einsum("qhd,khd->hqk", q_seq, k).float() * scale + mask = torch.full((num_new, total_len), float("-inf"), device=q.device) + for q_idx in range(num_new): + mask[q_idx, : history_len + q_idx + 1] = 0.0 + + weights = torch.softmax(scores + mask.unsqueeze(0), dim=-1).to(q.dtype) + outputs[q_begin:q_end] = torch.einsum("hqk,khd->qhd", weights, v) + + return outputs + + +@pytest.mark.parametrize( + ( + "num_seqs", + "num_heads", + "num_kv_heads", + "head_size", + "block_size", + "max_step_len", + "num_rounds", + "index_dtype", + ), + ( + (1, 1, 1, 128, 8, 16, 1, torch.int32), + (1, 1, 1, 128, 8, 16, 1, torch.int64), + (1, 4, 4, 128, 8, 16, 4, torch.int32), + (1, 4, 4, 128, 8, 16, 4, torch.int64), + (2, 8, 8, 128, 16, 32, 2, torch.int32), + (2, 8, 8, 128, 16, 32, 2, torch.int64), + (4, 16, 16, 128, 8, 64, 3, torch.int32), + (4, 16, 16, 128, 8, 64, 3, torch.int64), + (8, 64, 64, 128, 8, 16, 5, torch.int32), + (8, 64, 64, 128, 8, 16, 5, torch.int64), + (16, 128, 128, 128, 8, 16, 4, torch.int32), + (16, 128, 128, 128, 8, 16, 4, torch.int64), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 2e-2, 2e-2), + ), +) +def test_paged_attention_prefill_infinilm( + num_seqs, + num_heads, + num_kv_heads, + head_size, + block_size, + max_step_len, + num_rounds, + index_dtype, + implementation_index, + dtype, + device, + rtol, + atol, +): + max_tokens = num_seqs * max_step_len * num_rounds + num_blocks = (max_tokens + block_size - 1) // 
block_size + num_seqs + 4 + manager = SimpleCacheManager(num_blocks, block_size) + scale = head_size**-0.5 + + k_cache = torch.empty( + (num_blocks, num_kv_heads, block_size, head_size), + dtype=dtype, + device=device, + ) + v_cache = torch.empty_like(k_cache) + + for _ in range(num_rounds): + query_lens = torch.randint(1, max_step_len + 1, (num_seqs,)) + q_total_tokens = query_lens.sum().item() + q = torch.empty( + (q_total_tokens, num_heads, head_size), dtype=dtype, device=device + ) + + seq_lens_list = [] + block_tables_list = [] + cum_seq_lens_q = [0] + + for seq_id in range(num_seqs): + cur_q_len = query_lens[seq_id].item() + table, total_len = manager.allocate_slots(seq_id, cur_q_len) + history_len = total_len - cur_q_len + seq_lens_list.append(total_len) + block_tables_list.append(table) + + k_new = torch.randn( + (cur_q_len, num_kv_heads, head_size), dtype=dtype, device=device + ) + v_new = torch.randn_like(k_new) + q_new = torch.randn( + (cur_q_len, num_heads, head_size), dtype=dtype, device=device + ) + q_begin = cum_seq_lens_q[-1] + q[q_begin : q_begin + cur_q_len] = q_new + + for token_idx in range(cur_q_len): + logical_pos = history_len + token_idx + block_id = table[logical_pos // block_size] + block_offset = logical_pos % block_size + k_cache[block_id, :, block_offset, :] = k_new[token_idx] + v_cache[block_id, :, block_offset, :] = v_new[token_idx] + + cum_seq_lens_q.append(q_begin + cur_q_len) + + max_blocks = max(len(table) for table in block_tables_list) + padded_tables = [ + table + [0] * (max_blocks - len(table)) for table in block_tables_list + ] + block_tables = torch.tensor(padded_tables, dtype=index_dtype, device=device) + seq_lens = torch.tensor(seq_lens_list, dtype=index_dtype, device=device) + cum_seq_lens_q = torch.tensor(cum_seq_lens_q, dtype=index_dtype, device=device) + out = torch.empty_like(q) + + actual = _paged_attention_prefill_infinilm( + q, + k_cache, + v_cache, + block_tables, + seq_lens, + cum_seq_lens_q, + scale=scale, + 
out=out, + implementation_index=implementation_index, + ) + expected = _torch_paged_attention_prefill_infinilm( + q, + k_cache, + v_cache, + block_tables, + seq_lens, + cum_seq_lens_q, + scale=scale, + ) + assert torch.allclose(actual, expected, rtol=rtol, atol=atol) + + +def _paged_attention_prefill_infinilm( + q, + k_cache, + v_cache, + block_tables, + seq_lens, + cum_seq_lens_q, + *, + scale, + out=None, + implementation_index=0, +): + infini.ops.paged_attention_prefill_infinilm( + q, + k_cache, + v_cache, + block_tables, + seq_lens, + cum_seq_lens_q, + None, + scale, + out, + implementation_index=implementation_index, + stream=get_stream(q.device), + ) + + return out + + +def _torch_paged_attention_prefill_infinilm( + q, + k_cache, + v_cache, + block_tables, + seq_lens, + cum_seq_lens_q, + *, + scale, + out=None, +): + result = ref_paged_attention_prefill_infinilm( + q, k_cache, v_cache, block_tables, seq_lens, cum_seq_lens_q, scale + ) + + if out is not None: + out.copy_(result) + else: + out = result + + return out