diff --git a/src/base/softmax_infinilm.h b/src/base/softmax_infinilm.h
new file mode 100644
index 000000000..e4684c637
--- /dev/null
+++ b/src/base/softmax_infinilm.h
@@ -0,0 +1,90 @@
+#ifndef INFINI_OPS_BASE_SOFTMAX_INFINILM_H_
+#define INFINI_OPS_BASE_SOFTMAX_INFINILM_H_
+
+#include
+#include
+
+#include "operator.h"
+
+namespace infini::ops {
+
+// Backend-independent part of the `SoftmaxInfinilm` operator.
+//
+// The constructor records the shape, strides and dtype of `input` and `out`,
+// resolves a negative `dim` against `input.ndim()`, and derives `dim_size_`
+// (extent of `out` along `dim_`) and `row_count_` (number of slices along
+// `dim_`). Backends derive from this class and implement `operator()`.
+class SoftmaxInfinilm : public Operator {
+ public:
+  // Preconditions, checked with `assert`: `input` and `out` have the same
+  // shape and dtype, the resolved `dim` lies in `[0, out.ndim())`, `dtype`
+  // (when set) equals the dtype of `out`, and `out` has no broadcasted
+  // dimensions.
+  SoftmaxInfinilm(const Tensor input, const int64_t dim,
+                  const std::optional dtype, Tensor out)
+      : input_shape_{input.shape()},
+        input_strides_{input.strides()},
+        input_type_{input.dtype()},
+        out_shape_{out.shape()},
+        out_strides_{out.strides()},
+        out_type_{out.dtype()},
+        dim_{dim < 0 ? dim + static_cast(input.ndim()) : dim},
+        dtype_{dtype},
+        ndim_{out.ndim()},
+        dim_size_{out.size(dim_)},
+        // A zero-length softmax axis leaves no rows to process; dividing by
+        // `dim_size_` unconditionally would be a division by zero.
+        row_count_{dim_size_ == 0 ? Tensor::Size{0}
+                                  : out.numel() / dim_size_},
+        device_index_{out.device().index()} {
+    assert(input_shape_ == out_shape_ &&
+           "`SoftmaxInfinilm` input and output shapes must match");
+    assert(dim_ >= 0 && dim_ < static_cast(ndim_) &&
+           "`SoftmaxInfinilm` dim out of range");
+    assert((!dtype_.has_value() || dtype_.value() == out_type_) &&
+           "`SoftmaxInfinilm` dtype must match the output dtype");
+    assert(input_type_ == out_type_ &&
+           "`SoftmaxInfinilm` input and output dtypes must match");
+    assert(!out.HasBroadcastDim() &&
+           "`SoftmaxInfinilm` output must not have broadcasted dimensions");
+  }
+
+  // Runs the operator on `input` with `out` as the destination. Pure
+  // virtual: every backend supplies its own implementation.
+  virtual void operator()(const Tensor input, const int64_t dim,
+                          const std::optional dtype,
+                          Tensor out) const = 0;
+
+ protected:
+  Tensor::Shape input_shape_;
+
+  Tensor::Strides input_strides_;
+
+  DataType input_type_;
+
+  Tensor::Shape out_shape_;
+
+  Tensor::Strides out_strides_;
+
+  DataType out_type_;
+
+  // Softmax dimension after resolving a negative `dim`.
+  int64_t dim_{};
+
+  std::optional dtype_{};
+
+  // Number of dimensions of `out`.
+  Tensor::Size ndim_{0};
+
+  // Extent of `out` along `dim_`.
+  Tensor::Size dim_size_{0};
+
+  // Number of slices along `dim_`; zero when `dim_size_` is zero.
+  Tensor::Size row_count_{0};
+
+  int device_index_{0};
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/iluvatar/ops/softmax_infinilm/kernel.h b/src/native/cuda/iluvatar/ops/softmax_infinilm/kernel.h
new file
mode 100644
index 000000000..b36ff0638
--- /dev/null
+++ b/src/native/cuda/iluvatar/ops/softmax_infinilm/kernel.h
@@ -0,0 +1,24 @@
+#ifndef INFINI_OPS_ILUVATAR_SOFTMAX_INFINILM_KERNEL_H_
+#define INFINI_OPS_ILUVATAR_SOFTMAX_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/iluvatar/caster.cuh"
+#include "native/cuda/iluvatar/runtime_.h"
+#include "native/cuda/ops/softmax_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// Iluvatar backend: reuses the shared CUDA implementation and inherits its
+// constructors.
+template <>
+class Operator
+    : public CudaSoftmaxInfinilm> {
+ public:
+  using CudaSoftmaxInfinilm<
+      Runtime>::CudaSoftmaxInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/metax/ops/softmax_infinilm/kernel.h b/src/native/cuda/metax/ops/softmax_infinilm/kernel.h
new file mode 100644
index 000000000..fe0a9da51
--- /dev/null
+++ b/src/native/cuda/metax/ops/softmax_infinilm/kernel.h
@@ -0,0 +1,23 @@
+#ifndef INFINI_OPS_METAX_SOFTMAX_INFINILM_KERNEL_H_
+#define INFINI_OPS_METAX_SOFTMAX_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/metax/caster.cuh"
+#include "native/cuda/metax/runtime_.h"
+#include "native/cuda/ops/softmax_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// Metax backend: reuses the shared CUDA implementation and inherits its
+// constructors.
+template <>
+class Operator
+    : public CudaSoftmaxInfinilm> {
+ public:
+  using CudaSoftmaxInfinilm>::CudaSoftmaxInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/moore/ops/softmax_infinilm/kernel.h b/src/native/cuda/moore/ops/softmax_infinilm/kernel.h
new file mode 100644
index 000000000..a1d91f9b8
--- /dev/null
+++ b/src/native/cuda/moore/ops/softmax_infinilm/kernel.h
@@ -0,0 +1,24 @@
+#ifndef INFINI_OPS_MOORE_SOFTMAX_INFINILM_KERNEL_H_
+#define INFINI_OPS_MOORE_SOFTMAX_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/moore/caster.cuh"
+#include "native/cuda/moore/polyfills.cuh"
+#include "native/cuda/moore/runtime_.h"
+#include "native/cuda/ops/softmax_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// Moore backend: reuses the shared CUDA implementation and inherits its
+// constructors.
+template <>
+class Operator
+    : public CudaSoftmaxInfinilm> {
+ public:
+  using CudaSoftmaxInfinilm>::CudaSoftmaxInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/nvidia/ops/softmax_infinilm/kernel.h b/src/native/cuda/nvidia/ops/softmax_infinilm/kernel.h
new file mode 100644
index 000000000..d00d143c4
--- /dev/null
+++ b/src/native/cuda/nvidia/ops/softmax_infinilm/kernel.h
@@ -0,0 +1,24 @@
+#ifndef INFINI_OPS_NVIDIA_SOFTMAX_INFINILM_KERNEL_H_
+#define INFINI_OPS_NVIDIA_SOFTMAX_INFINILM_KERNEL_H_
+
+#include
+
+#include "native/cuda/nvidia/caster.cuh"
+#include "native/cuda/nvidia/runtime_.h"
+#include "native/cuda/ops/softmax_infinilm/kernel.h"
+
+namespace infini::ops {
+
+// NVIDIA backend: reuses the shared CUDA implementation and inherits its
+// constructors.
+template <>
+class Operator
+    : public CudaSoftmaxInfinilm> {
+ public:
+  using CudaSoftmaxInfinilm<
+      Runtime>::CudaSoftmaxInfinilm;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/ops/softmax_infinilm/kernel.cuh b/src/native/cuda/ops/softmax_infinilm/kernel.cuh
new file mode 100644
index 000000000..b2e177c60
--- /dev/null
+++ b/src/native/cuda/ops/softmax_infinilm/kernel.cuh
@@ -0,0 +1,139 @@
+#ifndef INFINI_OPS_CUDA_SOFTMAX_INFINILM_KERNEL_CUH_
+#define INFINI_OPS_CUDA_SOFTMAX_INFINILM_KERNEL_CUH_
+
+#include
+#include
+#include
+#include
+
+#include "native/cuda/caster.cuh"
+#include "native/cuda/kernel_commons.cuh"
+
+namespace infini::ops {
+
+namespace {
+
+// Binary maximum, used as the reduction operator of `BlockMax`.
+struct SoftmaxInfinilmMaxOp {
+  __device__ __forceinline__ float operator()(float a, float b) const {
+    return a > b ? a : b;
+  }
+};
+
+// Reduces `value` over the thread block with `SoftmaxInfinilmMaxOp`. The
+// kernel below only uses the value returned to thread 0.
+template
+__device__ __forceinline__ float BlockMax(float value) {
+  using BlockReduce = cub::BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  return BlockReduce(temp_storage).Reduce(value, SoftmaxInfinilmMaxOp());
+}
+
+// Sums `value` over the thread block. The kernel below only uses the value
+// returned to thread 0.
+template
+__device__ __forceinline__ float BlockSum(float value) {
+  using BlockReduce = cub::BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  return BlockReduce(temp_storage).Sum(value);
+}
+
+// Returns the offset of the first element of slice `row`. `row` is decomposed
+// into one coordinate per axis other than `dim`, last axis varying fastest,
+// and every coordinate is scaled by the stride of its axis.
+// NOTE(review): `strides` is signed but the sum is accumulated in `size_t`, so
+// negative strides depend on unsigned wrap-around -- confirm this is intended.
+__device__ __forceinline__ size_t SoftmaxInfinilmRowOffset(
+    size_t row, size_t ndim, size_t dim, const size_t* __restrict__ shape,
+    const ptrdiff_t* __restrict__ strides) {
+  size_t offset = 0;
+  for (size_t axis = ndim; axis > 0; --axis) {
+    size_t i = axis - 1;
+    if (i == dim) {
+      continue;
+    }
+    size_t coord = row % shape[i];
+    row /= shape[i];
+    offset += coord * strides[i];
+  }
+  return offset;
+}
+
+}  // namespace
+
+// Softmax along axis `dim`; each thread block handles one slice, whose index
+// is taken from the two-dimensional grid. Blocks past `row_count` exit.
+//
+// The slice is processed in three passes, with all arithmetic in `float`:
+//   1. find the maximum of the slice,
+//   2. store `expf(x - max)` in `out` and accumulate the sum,
+//   3. divide the stored values by the sum.
+//
+// `out` and `input` are intentionally not `__restrict__`: the operator is also
+// called with `out` being the same tensor as `input` (in-place softmax), which
+// a no-alias promise would make undefined. That exact aliasing is safe here
+// because a thread reads an element before it overwrites it and only touches
+// its own indices.
+template
+__global__ void SoftmaxInfinilmKernel(
+    T* out, const T* input, const size_t* __restrict__ shape,
+    const ptrdiff_t* __restrict__ out_strides,
+    const ptrdiff_t* __restrict__ input_strides, size_t row_count,
+    size_t dim_size, size_t ndim, size_t dim) {
+  size_t row = blockIdx.x + blockIdx.y * gridDim.x;
+  if (row >= row_count) {
+    return;
+  }
+
+  size_t input_base =
+      SoftmaxInfinilmRowOffset(row, ndim, dim, shape, input_strides);
+  size_t out_base =
+      SoftmaxInfinilmRowOffset(row, ndim, dim, shape, out_strides);
+  ptrdiff_t input_dim_stride = input_strides[dim];
+  ptrdiff_t out_dim_stride = out_strides[dim];
+
+  // Pass 1: per-thread maximum, then block-wide maximum.
+  float thread_max = -FLT_MAX;
+  for (size_t i = threadIdx.x; i < dim_size; i += block_size) {
+    float value = Caster::template Cast(
+        input[input_base + i * input_dim_stride]);
+    thread_max = thread_max > value ? thread_max : value;
+  }
+
+  // Thread 0 publishes the block-wide maximum through shared memory.
+  float block_max = BlockMax(thread_max);
+  __shared__ float max_value;
+  if (threadIdx.x == 0) {
+    max_value = block_max;
+  }
+  __syncthreads();
+
+  // Pass 2: write the unnormalized exponentials and accumulate their sum.
+  float thread_sum = 0.0f;
+  for (size_t i = threadIdx.x; i < dim_size; i += block_size) {
+    float value = Caster::template Cast(
+        input[input_base + i * input_dim_stride]);
+    float exp_value = expf(value - max_value);
+    thread_sum += exp_value;
+    out[out_base + i * out_dim_stride] =
+        Caster::template Cast(exp_value);
+  }
+
+  float block_sum = BlockSum(thread_sum);
+  __shared__ float sum_value;
+  if (threadIdx.x == 0) {
+    sum_value = block_sum;
+  }
+  __syncthreads();
+
+  // Pass 3: normalize the values stored by pass 2.
+  for (size_t i = threadIdx.x; i < dim_size; i += block_size) {
+    float value =
+        Caster::template Cast(out[out_base + i * out_dim_stride]);
+    out[out_base + i * out_dim_stride] =
+        Caster::template Cast(value / sum_value);
+  }
+}
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/ops/softmax_infinilm/kernel.h b/src/native/cuda/ops/softmax_infinilm/kernel.h
new file mode 100644
index 000000000..e9856f3f1
--- /dev/null
+++ b/src/native/cuda/ops/softmax_infinilm/kernel.h
@@ -0,0 +1,116 @@
+#ifndef INFINI_OPS_CUDA_SOFTMAX_INFINILM_KERNEL_H_
+#define INFINI_OPS_CUDA_SOFTMAX_INFINILM_KERNEL_H_
+
+#include
+#include
+#include
+#include
+
+#include "base/softmax_infinilm.h"
+#include "common/generic_utils.h"
+#include "data_type.h"
+#include "dispatcher.h"
+#include "native/cuda/kernel_commons.cuh"
+#include "native/cuda/ops/softmax_infinilm/kernel.cuh"
+#include "native/cuda/runtime_utils.h"
+
+namespace infini::ops {
+
+// `SoftmaxInfinilm` implementation shared by the CUDA-like backends.
+//
+// The constructor uploads the output shape and both stride arrays to the
+// device in a single allocation; `operator()` launches
+// `SoftmaxInfinilmKernel`, in which each thread block handles one slice.
+// NOTE(review): the class frees `d_metadata_` in its destructor but declares
+// no copy or move operations, so a copy would free the buffer twice -- confirm
+// that operators are never copied.
+template
+class CudaSoftmaxInfinilm : public SoftmaxInfinilm {
+ public:
+  CudaSoftmaxInfinilm(const Tensor input, const int64_t dim,
+                      const std::optional dtype, Tensor out)
+      : SoftmaxInfinilm{input, dim, dtype, out} {
+    size_t shape_size = ndim_ * sizeof(*d_shape_);
+    size_t strides_size = ndim_ * sizeof(*d_input_strides_);
+    const size_t metadata_size = shape_size + 2 * strides_size;
+    std::vector metadata(metadata_size);
+
+    Backend::Malloc((void**)&d_metadata_, metadata_size);
+
+    // Device buffer layout: [shape | input strides | output strides]. The
+    // host staging buffer is filled in the same order and copied in one go.
+    size_t offset = 0;
+    d_shape_ = reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, out_shape_.data(), shape_size);
+    offset += shape_size;
+
+    d_input_strides_ = reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, input_strides_.data(), strides_size);
+    offset += strides_size;
+
+    d_out_strides_ = reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, out_strides_.data(), strides_size);
+
+    Backend::Memcpy(d_metadata_, metadata.data(), metadata_size,
+                    Backend::MemcpyHostToDevice);
+  }
+
+  ~CudaSoftmaxInfinilm() { Backend::Free(d_metadata_); }
+
+  // Launches the kernel for `input` and `out`. `dim` and `dtype` are ignored
+  // here; the values captured at construction are used instead.
+  void operator()(const Tensor input, const int64_t dim,
+                  const std::optional dtype,
+                  Tensor out) const override {
+    (void)dim;
+    (void)dtype;
+
+    // Nothing to compute for an empty tensor. Without this guard `grid_x`
+    // would be zero, and it is used below both as a divisor and as a grid
+    // dimension.
+    if (row_count_ == 0) {
+      return;
+    }
+
+    auto cuda_stream =
+        static_cast(stream_ ? stream_ : 0);
+    int block_size = std::min(
+        RuntimeUtils::GetOptimalBlockSize(), 1024);
+
+    DispatchFunc>(
+        {static_cast(out_type_), block_size},
+        [&](auto list_tag) {
+          using T = TypeMapType(list_tag)>;
+          constexpr int kBlockSize = ListGet<1>(list_tag);
+
+          // A grid dimension is capped at 65535 here, so the rows are spread
+          // over a two-dimensional grid; surplus blocks exit in the kernel.
+          const unsigned grid_x =
+              static_cast(std::min(row_count_, 65535));
+          const unsigned grid_y = static_cast(
+              utils::CeilDiv(row_count_, static_cast(grid_x)));
+
+          SoftmaxInfinilmKernel
+              <<>>(
+                  reinterpret_cast(out.data()),
+                  reinterpret_cast(input.data()), d_shape_,
+                  d_out_strides_, d_input_strides_, row_count_, dim_size_,
+                  ndim_, dim_);
+        },
+        "CudaSoftmaxInfinilm::operator()");
+  }
+
+ private:
+  // Single device allocation backing the three arrays below.
+  std::byte* d_metadata_{nullptr};
+
+  Tensor::Size* d_shape_{nullptr};
+
+  Tensor::Stride* d_input_strides_{nullptr};
+
+  Tensor::Stride* d_out_strides_{nullptr};
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/tests/test_softmax_infinilm.py b/tests/test_softmax_infinilm.py
new file mode 100644
index 000000000..d2ff6b94c
--- /dev/null
+++ b/tests/test_softmax_infinilm.py
@@ -0,0 +1,66 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, get_stream, randn_strided
+
+
+# Builds the payload pairing `infini.ops.softmax_infinilm` with a
+# `torch.softmax` reference; with `inplace`, `out` is the `input` tensor itself.
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "shape, dim, inplace",
+    (
+        ((4, 4), 0, True),
+        ((4, 4), 0, False),
+        ((12, 16, 512, 512), 0, True),
+        ((12, 16, 512, 512), 0, False),
+        ((12, 16, 512, 512), 1, True),
+        ((12, 16, 512, 512), 1, False),
+        ((12, 16, 512, 512), 2, True),
+        ((12, 16, 512, 512), 2, False),
+        ((12, 16, 512, 512), 3, True),
+        ((12, 16, 512, 512), 3, False),
+        ((1, 16, 512, 512), 0, True),
+        ((1, 16, 512, 512), 0, False),
+        ((1, 16, 512, 512), 1, True),
+        ((1, 16, 512, 512), 1, False),
+        ((1, 16, 512, 512), 2, True),
+        ((1, 16, 512, 512), 2, False),
+        ((1, 16, 512, 512), 3, True),
+        ((1, 16, 512, 512), 3, False),
+    ),
+)
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float32, 1e-5, 3e-5),
+        (torch.float16, 1e-2, 1e-3),
+    ),
+)
+def test_softmax_infinilm(shape, dim, inplace, dtype, device, rtol, atol):
+    input = randn_strided(shape, None, dtype=dtype, device=device)
+    out = input if inplace else torch.empty_like(input)
+
+    return Payload(
+        lambda *args: _softmax_infinilm(*args, dim=dim),
+        lambda *args: _torch_softmax_infinilm(*args, dim=dim),
+        (input, out),
+        {},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+# Operator under test; `None` is passed for the optional dtype argument.
+def _softmax_infinilm(input, out, dim):
+    infini.ops.softmax_infinilm(input, dim, None, out, stream=get_stream(input.device))
+
+    return out
+
+
+# Reference implementation: `torch.softmax` copied into `out`.
+def _torch_softmax_infinilm(input, out, dim):
+    out.copy_(torch.softmax(input, dim=dim))
+
+    return out