Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions src/base/topksoftmax_infinilm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#ifndef INFINI_OPS_BASE_TOPKSOFTMAX_INFINILM_H_
#define INFINI_OPS_BASE_TOPKSOFTMAX_INFINILM_H_

#include <cassert>

#include "operator.h"

namespace infini::ops {

// Device-independent base of the top-k softmax operator.
//
// The constructor snapshots shape/stride/dtype metadata of the `input`,
// `values` and `indices` tensors together with the `topk` and `norm`
// parameters, and validates them with `assert`. Backends derive from this
// class and implement `operator()`.
//
// NOTE: all validation below uses `assert`, so it is compiled out when
// `NDEBUG` is defined; release builds perform no argument checking here.
class TopksoftmaxInfinilm : public Operator<TopksoftmaxInfinilm> {
public:
// Records tensor metadata and checks the operator's preconditions.
//
// Parameters:
//   input   - 2D tensor; dtype must be float16, bfloat16, float32 or float64.
//   topk    - number of entries selected per row; must be in
//             (0, input.size(1)].
//   norm    - stored as `norm_`; its effect is defined by the backend
//             implementation.
//   values  - float32 output of shape (input.size(0), topk), without
//             broadcasted dimensions.
//   indices - int32 output with the same shape as `values`, without
//             broadcasted dimensions.
//
// NOTE(review): `input.size(0)` / `input.size(1)` are evaluated in the
// initializer list, i.e. before the `input.ndim() == 2` assertion runs.
TopksoftmaxInfinilm(const Tensor input, const int64_t topk, const bool norm,
Tensor values, Tensor indices)
: input_shape_{input.shape()},
input_strides_{input.strides()},
input_type_{input.dtype()},
values_shape_{values.shape()},
values_strides_{values.strides()},
values_type_{values.dtype()},
indices_shape_{indices.shape()},
indices_strides_{indices.strides()},
indices_type_{indices.dtype()},
topk_{topk},
norm_{norm},
row_count_{input.size(0)},
width_{input.size(1)},
device_index_{values.device().index()} {
// Input must be a matrix: rows are reduced independently along dim 1.
assert(input.ndim() == 2 &&
"`TopksoftmaxInfinilm` input must be a 2D tensor");
// At least one and at most `width_` entries can be selected per row.
assert(topk_ > 0 && topk_ <= static_cast<int64_t>(width_) &&
"`TopksoftmaxInfinilm` topk must be in (0, input.size(1)]");
// Both outputs share one shape: (row_count_, topk_).
assert(values_shape_ == indices_shape_ &&
"`TopksoftmaxInfinilm` values and indices shapes must match");
assert(
values_shape_.size() == 2 && values_shape_[0] == row_count_ &&
values_shape_[1] == static_cast<Tensor::Size>(topk_) &&
"`TopksoftmaxInfinilm` outputs must have shape (input.size(0), topk)");
// Output dtypes are fixed regardless of the input dtype.
assert(values_type_ == DataType::kFloat32 &&
"`TopksoftmaxInfinilm` values output must be float32");
assert(indices_type_ == DataType::kInt32 &&
"`TopksoftmaxInfinilm` indices output must be int32");
assert((input_type_ == DataType::kFloat16 ||
input_type_ == DataType::kBFloat16 ||
input_type_ == DataType::kFloat32 ||
input_type_ == DataType::kFloat64) &&
"`TopksoftmaxInfinilm` input must be a floating point tensor");
// Outputs are written to, so they must not alias elements via broadcasting.
assert(
!values.HasBroadcastDim() && !indices.HasBroadcastDim() &&
"`TopksoftmaxInfinilm` outputs must not have broadcasted dimensions");
}

// Runs the operator; implemented by each backend. The arguments mirror the
// constructor's.
virtual void operator()(const Tensor input, const int64_t topk,
const bool norm, Tensor values,
Tensor indices) const = 0;

protected:
// Shape of `input` as captured at construction.
Tensor::Shape input_shape_;

// Strides of `input` as captured at construction.
Tensor::Strides input_strides_;

// Element type of `input`.
DataType input_type_;

// Shape of the `values` output.
Tensor::Shape values_shape_;

// Strides of the `values` output.
Tensor::Strides values_strides_;

// Element type of the `values` output (asserted to be float32).
DataType values_type_;

// Shape of the `indices` output.
Tensor::Shape indices_shape_;

// Strides of the `indices` output.
Tensor::Strides indices_strides_;

// Element type of the `indices` output (asserted to be int32).
DataType indices_type_;

// Number of entries selected per row.
int64_t topk_{0};

// The `norm` flag passed at construction.
bool norm_{false};

// `input.size(0)`: number of rows.
Tensor::Size row_count_{0};

// `input.size(1)`: number of columns per row.
Tensor::Size width_{0};

// Device index taken from the `values` tensor.
int device_index_{0};
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/iluvatar/ops/topksoftmax_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_ILUVATAR_TOPKSOFTMAX_INFINILM_KERNEL_H_
#define INFINI_OPS_ILUVATAR_TOPKSOFTMAX_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/iluvatar/caster.cuh"
#include "native/cuda/iluvatar/runtime_.h"
#include "native/cuda/ops/topksoftmax_infinilm/kernel.h"

namespace infini::ops {

// Iluvatar backend of `TopksoftmaxInfinilm`.
//
// Adds no behaviour of its own: it derives from `CudaTopksoftmaxInfinilm`
// instantiated with the Iluvatar runtime (presumably declared in the included
// "native/cuda/ops/topksoftmax_infinilm/kernel.h" -- not shown here).
template <>
class Operator<TopksoftmaxInfinilm, Device::Type::kIluvatar>
: public CudaTopksoftmaxInfinilm<Runtime<Device::Type::kIluvatar>> {
public:
// Inherit the base implementation's constructors unchanged.
using CudaTopksoftmaxInfinilm<
Runtime<Device::Type::kIluvatar>>::CudaTopksoftmaxInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/metax/ops/topksoftmax_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_METAX_TOPKSOFTMAX_INFINILM_KERNEL_H_
#define INFINI_OPS_METAX_TOPKSOFTMAX_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/metax/caster.cuh"
#include "native/cuda/metax/runtime_.h"
#include "native/cuda/ops/topksoftmax_infinilm/kernel.h"

namespace infini::ops {

// Metax backend of `TopksoftmaxInfinilm`.
//
// Adds no behaviour of its own: it derives from `CudaTopksoftmaxInfinilm`
// instantiated with the Metax runtime (presumably declared in the included
// "native/cuda/ops/topksoftmax_infinilm/kernel.h" -- not shown here).
template <>
class Operator<TopksoftmaxInfinilm, Device::Type::kMetax>
: public CudaTopksoftmaxInfinilm<Runtime<Device::Type::kMetax>> {
public:
// Inherit the base implementation's constructors unchanged.
using CudaTopksoftmaxInfinilm<
Runtime<Device::Type::kMetax>>::CudaTopksoftmaxInfinilm;
};

} // namespace infini::ops

#endif
23 changes: 23 additions & 0 deletions src/native/cuda/moore/ops/topksoftmax_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef INFINI_OPS_MOORE_TOPKSOFTMAX_INFINILM_KERNEL_H_
#define INFINI_OPS_MOORE_TOPKSOFTMAX_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/moore/caster.cuh"
#include "native/cuda/moore/polyfills.cuh"
#include "native/cuda/moore/runtime_.h"
#include "native/cuda/ops/topksoftmax_infinilm/kernel.h"

namespace infini::ops {

// Moore backend of `TopksoftmaxInfinilm`.
//
// Adds no behaviour of its own: it derives from `CudaTopksoftmaxInfinilm`
// instantiated with the Moore runtime (presumably declared in the included
// "native/cuda/ops/topksoftmax_infinilm/kernel.h" -- not shown here).
template <>
class Operator<TopksoftmaxInfinilm, Device::Type::kMoore>
: public CudaTopksoftmaxInfinilm<Runtime<Device::Type::kMoore>> {
public:
// Inherit the base implementation's constructors unchanged.
using CudaTopksoftmaxInfinilm<
Runtime<Device::Type::kMoore>>::CudaTopksoftmaxInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/nvidia/ops/topksoftmax_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_NVIDIA_TOPKSOFTMAX_INFINILM_KERNEL_H_
#define INFINI_OPS_NVIDIA_TOPKSOFTMAX_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/topksoftmax_infinilm/kernel.h"

namespace infini::ops {

// NVIDIA backend of `TopksoftmaxInfinilm`.
//
// Adds no behaviour of its own: it derives from `CudaTopksoftmaxInfinilm`
// instantiated with the NVIDIA runtime (presumably declared in the included
// "native/cuda/ops/topksoftmax_infinilm/kernel.h" -- not shown here).
template <>
class Operator<TopksoftmaxInfinilm, Device::Type::kNvidia>
: public CudaTopksoftmaxInfinilm<Runtime<Device::Type::kNvidia>> {
public:
// Inherit the base implementation's constructors unchanged.
using CudaTopksoftmaxInfinilm<
Runtime<Device::Type::kNvidia>>::CudaTopksoftmaxInfinilm;
};

} // namespace infini::ops

#endif
169 changes: 169 additions & 0 deletions src/native/cuda/ops/topksoftmax_infinilm/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#ifndef INFINI_OPS_CUDA_TOPKSOFTMAX_INFINILM_KERNEL_CUH_
#define INFINI_OPS_CUDA_TOPKSOFTMAX_INFINILM_KERNEL_CUH_

#include <cfloat>
#include <cmath>
#include <cstddef>

#include "native/cuda/caster.cuh"

namespace infini::ops {

namespace {

// Block-wide maximum.
//
// Every thread contributes `value`; every thread receives the maximum over
// the block.
//
// Requirements:
//   - Must be reached by all threads of the block (uses `__syncthreads()`).
//   - The scratch array has `block_size` entries and is indexed by
//     `threadIdx.x`, so the block must be 1D with `blockDim.x == block_size`.
//
// Works for any `block_size >= 1`, not only powers of two. For power-of-two
// sizes the pairing is identical to a plain halving reduction.
template <unsigned int block_size>
__device__ __forceinline__ float TopksoftmaxInfinilmBlockMax(float value) {
  __shared__ float scratch[block_size];
  scratch[threadIdx.x] = value;
  __syncthreads();

  // `active` is the number of live entries in scratch[0, active). Each step
  // folds the upper `active - half` entries onto the lowest ones; reads come
  // from [half, active) and writes go to [0, active - half), which never
  // overlap because `active - half <= half`.
  unsigned int active = block_size;
  while (active > 1) {
    const unsigned int half = (active + 1) / 2;
    if (threadIdx.x + half < active) {
      const float other = scratch[threadIdx.x + half];
      scratch[threadIdx.x] =
          scratch[threadIdx.x] > other ? scratch[threadIdx.x] : other;
    }
    __syncthreads();
    active = half;
  }

  const float result = scratch[0];
  // The scratch buffer is static shared memory reused by every call of this
  // instantiation. Wait until all threads have read the result so that a
  // subsequent call cannot overwrite `scratch[0]` too early.
  __syncthreads();
  return result;
}

// Block-wide sum.
//
// Every thread contributes `value`; every thread receives the sum over the
// block.
//
// Requirements:
//   - Must be reached by all threads of the block (uses `__syncthreads()`).
//   - The scratch array has `block_size` entries and is indexed by
//     `threadIdx.x`, so the block must be 1D with `blockDim.x == block_size`.
//
// Works for any `block_size >= 1`, not only powers of two. For power-of-two
// sizes the pairing (and therefore the floating-point summation order) is
// identical to a plain halving reduction.
template <unsigned int block_size>
__device__ __forceinline__ float TopksoftmaxInfinilmBlockSum(float value) {
  __shared__ float scratch[block_size];
  scratch[threadIdx.x] = value;
  __syncthreads();

  // `active` is the number of live partial sums in scratch[0, active). Each
  // step folds the upper `active - half` entries onto the lowest ones; reads
  // come from [half, active) and writes go to [0, active - half), which never
  // overlap because `active - half <= half`.
  unsigned int active = block_size;
  while (active > 1) {
    const unsigned int half = (active + 1) / 2;
    if (threadIdx.x + half < active) {
      scratch[threadIdx.x] += scratch[threadIdx.x + half];
    }
    __syncthreads();
    active = half;
  }

  const float result = scratch[0];
  // The scratch buffer is static shared memory reused by every call of this
  // instantiation (the kernel calls this helper more than once). Wait until
  // all threads have read the result so that the next call cannot overwrite
  // `scratch[0]` too early.
  __syncthreads();
  return result;
}

// Ordering predicate used for top-k selection.
//
// Returns true when the candidate (`value`, `index`) should replace the
// current best (`best_value`, `best_index`): a strictly larger value wins,
// and on equal values the smaller index wins. Any comparison involving NaN
// yields false.
__device__ __forceinline__ bool TopksoftmaxInfinilmBetter(float value,
                                                          int index,
                                                          float best_value,
                                                          int best_index) {
  if (value > best_value) {
    return true;
  }
  if (value == best_value) {
    // Tie on value: prefer the lower index.
    return index < best_index;
  }
  return false;
}

// Block-wide arg-max over (value, index) pairs.
//
// On entry each thread passes its own candidate in `value` / `index`; on exit
// every thread holds the block-wide winner according to
// `TopksoftmaxInfinilmBetter` (largest value, ties broken by lowest index).
//
// Requirements:
//   - Must be reached by all threads of the block (uses `__syncthreads()`).
//   - The scratch arrays have `block_size` entries and are indexed by
//     `threadIdx.x`, so the block must be 1D with `blockDim.x == block_size`.
//
// Works for any `block_size >= 1`, not only powers of two. For power-of-two
// sizes the pairing is identical to a plain halving reduction.
template <unsigned int block_size>
__device__ __forceinline__ void TopksoftmaxInfinilmBlockBest(float& value,
                                                             int& index) {
  __shared__ float best_values[block_size];
  __shared__ int best_indices[block_size];
  best_values[threadIdx.x] = value;
  best_indices[threadIdx.x] = index;
  __syncthreads();

  // `active` is the number of live candidates in [0, active). Each step
  // merges the upper `active - half` candidates into the lowest ones; reads
  // come from [half, active) and writes go to [0, active - half), which never
  // overlap because `active - half <= half`.
  unsigned int active = block_size;
  while (active > 1) {
    const unsigned int half = (active + 1) / 2;
    if (threadIdx.x + half < active) {
      const float other_value = best_values[threadIdx.x + half];
      const int other_index = best_indices[threadIdx.x + half];
      if (TopksoftmaxInfinilmBetter(other_value, other_index,
                                    best_values[threadIdx.x],
                                    best_indices[threadIdx.x])) {
        best_values[threadIdx.x] = other_value;
        best_indices[threadIdx.x] = other_index;
      }
    }
    __syncthreads();
    active = half;
  }

  value = best_values[0];
  index = best_indices[0];
  // The scratch buffers are static shared memory reused by every call of this
  // instantiation, and this helper is called in a loop. Wait until all
  // threads have read the winner so the next call cannot overwrite slot 0
  // too early.
  __syncthreads();
}

} // namespace

// Per-row softmax followed by top-k selection.
//
// Launch layout:
//   - One thread block per row; the row id is flattened from a 2D grid as
//     `blockIdx.y * gridDim.x + blockIdx.x`. Blocks past `row_count` exit.
//   - Threads stride over columns by `block_size` starting at `threadIdx.x`,
//     and the reduction helpers index shared memory by `threadIdx.x`, so the
//     block must be 1D with `blockDim.x == block_size`.
//
// Parameters:
//   values          - float output; element (row, rank) receives the rank-th
//                     largest softmax probability of the row.
//   indices         - int32 output; element (row, rank) receives the column
//                     of that probability. Already written ranks are read
//                     back from this buffer to exclude selected columns.
//   input           - input matrix of element type `T`, converted to float
//                     through `Caster<kDev>`.
//   *_strides       - two-element stride arrays (in elements) for the
//                     corresponding tensor, readable from device code.
//   row_count/width - number of rows / columns of `input`.
//   topk            - number of entries selected per row.
//   norm            - when true, the selected values of each row are divided
//                     by their sum.
//
// Algorithm per row: block max -> block sum of exp(x - max) -> `topk` passes,
// each picking the best not-yet-selected column (ties resolved in favour of
// the lower column index) -> optional renormalisation.
template <unsigned int block_size, Device::Type kDev, typename T>
__global__ void TopksoftmaxInfinilmKernel(
    float* __restrict__ values, int32_t* __restrict__ indices,
    const T* __restrict__ input, const ptrdiff_t* __restrict__ input_strides,
    const ptrdiff_t* __restrict__ values_strides,
    const ptrdiff_t* __restrict__ indices_strides, size_t row_count,
    size_t width, size_t topk, bool norm) {
  // Widen before multiplying so the flattened block id cannot wrap in 32-bit
  // unsigned arithmetic.
  const size_t row = static_cast<size_t>(blockIdx.y) * gridDim.x + blockIdx.x;
  if (row >= row_count) {
    return;  // Uniform for the whole block, so no barrier is skipped.
  }

  // Strides are signed. Do all offset arithmetic in `ptrdiff_t` instead of
  // letting the signed strides be converted to `size_t`.
  const ptrdiff_t row_index = static_cast<ptrdiff_t>(row);
  const ptrdiff_t input_col_stride = input_strides[1];
  const ptrdiff_t values_col_stride = values_strides[1];
  const ptrdiff_t indices_col_stride = indices_strides[1];
  const ptrdiff_t input_base = row_index * input_strides[0];
  const ptrdiff_t values_base = row_index * values_strides[0];
  const ptrdiff_t indices_base = row_index * indices_strides[0];

  // Pass 1: row maximum, used to stabilise the exponentials.
  float thread_max = -FLT_MAX;
  for (size_t col = threadIdx.x; col < width; col += block_size) {
    const float value = Caster<kDev>::template Cast<float>(
        input[input_base + static_cast<ptrdiff_t>(col) * input_col_stride]);
    thread_max = thread_max > value ? thread_max : value;
  }

  const float max_value = TopksoftmaxInfinilmBlockMax<block_size>(thread_max);

  // Pass 2: softmax denominator.
  float thread_sum = 0.0f;
  for (size_t col = threadIdx.x; col < width; col += block_size) {
    const float value = Caster<kDev>::template Cast<float>(
        input[input_base + static_cast<ptrdiff_t>(col) * input_col_stride]);
    thread_sum += expf(value - max_value);
  }

  const float softmax_sum =
      TopksoftmaxInfinilmBlockSum<block_size>(thread_sum);

  // Pass 3: select the top-k entries one rank at a time.
  for (size_t rank = 0; rank < topk; ++rank) {
    // (-FLT_MAX, -1) marks "no candidate"; any real probability beats it.
    float thread_best = -FLT_MAX;
    int thread_index = -1;

    for (size_t col = threadIdx.x; col < width; col += block_size) {
      // Skip columns that already won an earlier rank.
      bool selected = false;
      for (size_t prev = 0; prev < rank; ++prev) {
        if (indices[indices_base +
                    static_cast<ptrdiff_t>(prev) * indices_col_stride] ==
            static_cast<int32_t>(col)) {
          selected = true;
          break;
        }
      }

      if (!selected) {
        const float value = Caster<kDev>::template Cast<float>(
            input[input_base +
                  static_cast<ptrdiff_t>(col) * input_col_stride]);
        const float softmax_value = expf(value - max_value) / softmax_sum;
        if (TopksoftmaxInfinilmBetter(softmax_value, static_cast<int>(col),
                                      thread_best, thread_index)) {
          thread_best = softmax_value;
          thread_index = static_cast<int>(col);
        }
      }
    }

    TopksoftmaxInfinilmBlockBest<block_size>(thread_best, thread_index);

    if (threadIdx.x == 0) {
      values[values_base + static_cast<ptrdiff_t>(rank) * values_col_stride] =
          thread_best;
      indices[indices_base +
              static_cast<ptrdiff_t>(rank) * indices_col_stride] =
          thread_index;
    }
    // Make this rank's outputs visible to the whole block before the next
    // pass reads `indices` (and before the normalisation reads `values`).
    __syncthreads();
  }

  // Optional pass 4: rescale the selected probabilities to sum to one.
  if (norm) {
    float thread_topk_sum = 0.0f;
    for (size_t rank = threadIdx.x; rank < topk; rank += block_size) {
      thread_topk_sum +=
          values[values_base +
                 static_cast<ptrdiff_t>(rank) * values_col_stride];
    }
    const float topk_sum =
        TopksoftmaxInfinilmBlockSum<block_size>(thread_topk_sum);

    for (size_t rank = threadIdx.x; rank < topk; rank += block_size) {
      const ptrdiff_t offset =
          values_base + static_cast<ptrdiff_t>(rank) * values_col_stride;
      values[offset] = values[offset] / topk_sum;
    }
  }
}

} // namespace infini::ops

#endif
Loading
Loading