From fa23860c1b9428758702b7c604b8bde8855dcd34 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Wed, 10 Jun 2026 14:52:24 +0800 Subject: [PATCH] feat: add cuda relu infinilm --- src/base/relu_infinilm.h | 60 ++++++++++++ .../cuda/iluvatar/ops/relu_infinilm/kernel.h | 21 +++++ .../cuda/metax/ops/relu_infinilm/kernel.h | 21 +++++ .../cuda/moore/ops/relu_infinilm/kernel.h | 22 +++++ .../cuda/nvidia/ops/relu_infinilm/kernel.h | 21 +++++ src/native/cuda/ops/relu_infinilm/kernel.cuh | 49 ++++++++++ src/native/cuda/ops/relu_infinilm/kernel.h | 91 +++++++++++++++++++ tests/test_relu_infinilm.py | 68 ++++++++++++++ 8 files changed, 353 insertions(+) create mode 100644 src/base/relu_infinilm.h create mode 100644 src/native/cuda/iluvatar/ops/relu_infinilm/kernel.h create mode 100644 src/native/cuda/metax/ops/relu_infinilm/kernel.h create mode 100644 src/native/cuda/moore/ops/relu_infinilm/kernel.h create mode 100644 src/native/cuda/nvidia/ops/relu_infinilm/kernel.h create mode 100644 src/native/cuda/ops/relu_infinilm/kernel.cuh create mode 100644 src/native/cuda/ops/relu_infinilm/kernel.h create mode 100644 tests/test_relu_infinilm.py diff --git a/src/base/relu_infinilm.h b/src/base/relu_infinilm.h new file mode 100644 index 000000000..c2b223aaf --- /dev/null +++ b/src/base/relu_infinilm.h @@ -0,0 +1,60 @@ +#ifndef INFINI_OPS_BASE_RELU_INFINILM_H_ +#define INFINI_OPS_BASE_RELU_INFINILM_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +class ReluInfinilm : public Operator { + public: + ReluInfinilm(const Tensor input, Tensor out) + : input_shape_{input.shape()}, + input_strides_{input.strides()}, + input_type_{input.dtype()}, + out_shape_{out.shape()}, + out_strides_{out.strides()}, + out_type_{out.dtype()}, + output_size_{out.numel()}, + ndim_{out.ndim()}, + is_input_contiguous_{input.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()}, + device_index_{out.device().index()} { + assert(input_shape_ == out_shape_ && + "`ReluInfinilm` input and output shapes must match"); + assert(input_type_ == out_type_ && + "`ReluInfinilm` input and output dtypes must match"); + assert(!out.HasBroadcastDim() && + "`ReluInfinilm` output must not have broadcasted dimensions"); + } + + virtual void operator()(const Tensor input, Tensor out) const = 0; + + protected: + Tensor::Shape input_shape_; + + Tensor::Strides input_strides_; + + DataType input_type_; + + Tensor::Shape out_shape_; + + Tensor::Strides out_strides_; + + DataType out_type_; + + Tensor::Size output_size_{0}; + + Tensor::Size ndim_{0}; + + bool is_input_contiguous_{false}; + + bool is_out_contiguous_{false}; + + int device_index_{0}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/iluvatar/ops/relu_infinilm/kernel.h b/src/native/cuda/iluvatar/ops/relu_infinilm/kernel.h new file mode 100644 index 000000000..faf1c0790 --- /dev/null +++ b/src/native/cuda/iluvatar/ops/relu_infinilm/kernel.h @@ -0,0 +1,21 @@ +#ifndef INFINI_OPS_ILUVATAR_RELU_INFINILM_KERNEL_H_ +#define INFINI_OPS_ILUVATAR_RELU_INFINILM_KERNEL_H_ + +#include + +#include "native/cuda/iluvatar/caster.cuh" +#include "native/cuda/iluvatar/runtime_.h" +#include "native/cuda/ops/relu_infinilm/kernel.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaReluInfinilm> { + public: + using CudaReluInfinilm>::CudaReluInfinilm; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/metax/ops/relu_infinilm/kernel.h b/src/native/cuda/metax/ops/relu_infinilm/kernel.h new file mode 100644 index 000000000..7e0aa3af5 --- /dev/null +++ b/src/native/cuda/metax/ops/relu_infinilm/kernel.h @@ -0,0 +1,21 @@ +#ifndef INFINI_OPS_METAX_RELU_INFINILM_KERNEL_H_ +#define INFINI_OPS_METAX_RELU_INFINILM_KERNEL_H_ + +#include + +#include "native/cuda/metax/caster.cuh" +#include "native/cuda/metax/runtime_.h" +#include "native/cuda/ops/relu_infinilm/kernel.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaReluInfinilm> { + public: + using CudaReluInfinilm>::CudaReluInfinilm; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/moore/ops/relu_infinilm/kernel.h b/src/native/cuda/moore/ops/relu_infinilm/kernel.h new file mode 100644 index 000000000..4de90b8fa --- /dev/null +++ b/src/native/cuda/moore/ops/relu_infinilm/kernel.h @@ -0,0 +1,22 @@ +#ifndef INFINI_OPS_MOORE_RELU_INFINILM_KERNEL_H_ +#define INFINI_OPS_MOORE_RELU_INFINILM_KERNEL_H_ + +#include + +#include "native/cuda/moore/caster.cuh" +#include "native/cuda/moore/polyfills.cuh" +#include "native/cuda/moore/runtime_.h" +#include "native/cuda/ops/relu_infinilm/kernel.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaReluInfinilm> { + public: + using CudaReluInfinilm>::CudaReluInfinilm; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/nvidia/ops/relu_infinilm/kernel.h b/src/native/cuda/nvidia/ops/relu_infinilm/kernel.h new file mode 100644 index 000000000..37f140caa --- /dev/null +++ b/src/native/cuda/nvidia/ops/relu_infinilm/kernel.h @@ -0,0 +1,21 @@ +#ifndef INFINI_OPS_NVIDIA_RELU_INFINILM_KERNEL_H_ +#define INFINI_OPS_NVIDIA_RELU_INFINILM_KERNEL_H_ + +#include + +#include "native/cuda/nvidia/caster.cuh" +#include "native/cuda/nvidia/runtime_.h" +#include "native/cuda/ops/relu_infinilm/kernel.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaReluInfinilm> { + public: + using CudaReluInfinilm>::CudaReluInfinilm; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/ops/relu_infinilm/kernel.cuh b/src/native/cuda/ops/relu_infinilm/kernel.cuh new file mode 100644 index 000000000..b933d827d --- /dev/null +++ b/src/native/cuda/ops/relu_infinilm/kernel.cuh @@ -0,0 +1,49 @@ +#ifndef INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_CUH_ +#define INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_CUH_ + +#include + +#include "native/cuda/caster.cuh" +#include "native/cuda/kernel_commons.cuh" + +namespace infini::ops { + +namespace { + +template +__device__ __forceinline__ T ReluInfinilmValue(T x) { + const float v = Caster::template Cast(x); + return Caster::template Cast(v > 0.0f ? v : 0.0f); +} + +template +__device__ __forceinline__ double ReluInfinilmValue(double x) { + return x > 0.0 ? x : 0.0; +} + +} // namespace + +template +__global__ void ReluInfinilmKernel(T* __restrict__ out, + const T* __restrict__ input, + const size_t* __restrict__ out_shape, + const size_t* __restrict__ input_shape, + const ptrdiff_t* __restrict__ out_strides, + const ptrdiff_t* __restrict__ input_strides, + size_t output_size, size_t ndim, + bool out_contiguous, bool input_contiguous) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < output_size) { + size_t out_idx = + out_contiguous ? idx : IndexToOffset(idx, ndim, out_shape, out_strides); + size_t input_idx = + input_contiguous ? idx + : IndexToOffset(idx, ndim, input_shape, input_strides); + out[out_idx] = ReluInfinilmValue(input[input_idx]); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/ops/relu_infinilm/kernel.h b/src/native/cuda/ops/relu_infinilm/kernel.h new file mode 100644 index 000000000..975fcb82c --- /dev/null +++ b/src/native/cuda/ops/relu_infinilm/kernel.h @@ -0,0 +1,91 @@ +#ifndef INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_H_ +#define INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_H_ + +#include +#include +#include +#include + +#include "base/relu_infinilm.h" +#include "common/generic_utils.h" +#include "data_type.h" +#include "dispatcher.h" +#include "native/cuda/kernel_commons.cuh" +#include "native/cuda/ops/relu_infinilm/kernel.cuh" +#include "native/cuda/runtime_utils.h" + +namespace infini::ops { + +template +class CudaReluInfinilm : public ReluInfinilm { + public: + CudaReluInfinilm(const Tensor input, Tensor out) : ReluInfinilm{input, out} { + size_t shape_size = ndim_ * sizeof(*d_input_shape_); + size_t strides_size = ndim_ * sizeof(*d_input_strides_); + const size_t metadata_size = 2 * (shape_size + strides_size); + std::vector metadata(metadata_size); + + Backend::Malloc((void**)&d_metadata_, metadata_size); + + size_t offset = 0; + d_input_shape_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(metadata.data() + offset, input_shape_.data(), shape_size); + offset += shape_size; + + d_out_shape_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(metadata.data() + offset, out_shape_.data(), shape_size); + offset += shape_size; + + d_input_strides_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(metadata.data() + offset, input_strides_.data(), strides_size); + offset += strides_size; + + d_out_strides_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(metadata.data() + offset, out_strides_.data(), strides_size); + + Backend::Memcpy(d_metadata_, metadata.data(), metadata_size, + Backend::MemcpyHostToDevice); + } + + ~CudaReluInfinilm() { Backend::Free(d_metadata_); } + + void operator()(const Tensor input, Tensor out) const override { + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + int block_size = std::min( + RuntimeUtils::GetOptimalBlockSize(), 1024); + dim3 block(std::min(static_cast(block_size), output_size_)); + dim3 grid(utils::CeilDiv(output_size_, block.x)); + + DispatchFunc>( + {static_cast(out_type_), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + ReluInfinilmKernel + <<>>( + reinterpret_cast(out.data()), + reinterpret_cast(input.data()), d_out_shape_, + d_input_shape_, d_out_strides_, d_input_strides_, + output_size_, ndim_, is_out_contiguous_, + is_input_contiguous_); + }, + "CudaReluInfinilm::operator()"); + } + + private: + std::byte* d_metadata_{nullptr}; + + Tensor::Size* d_input_shape_{nullptr}; + + Tensor::Size* d_out_shape_{nullptr}; + + Tensor::Stride* d_input_strides_{nullptr}; + + Tensor::Stride* d_out_strides_{nullptr}; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_relu_infinilm.py b/tests/test_relu_infinilm.py new file mode 100644 index 000000000..2ac23c5d0 --- /dev/null +++ b/tests/test_relu_infinilm.py @@ -0,0 +1,68 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_stream, rand_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, out_strides, inplace", + ( + ((1, 3), None, None, False), + ((1, 3), None, None, True), + ((3, 3), None, None, False), + ((3, 3), (5, 1), (5, 1), False), + ((32, 20, 512), None, None, False), + ((32, 20, 512), None, None, True), + ((33, 333, 333), None, None, False), + ((32, 256, 112, 112), None, None, False), + ((3, 3, 13, 9, 17), None, None, False), + ( + (3, 3, 13, 9, 17), + (19890, 6630, 510, 34, 1), + (19890, 6630, 510, 34, 1), + False, + ), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-3, 1e-3), + ), +) +def test_relu_infinilm( + shape, input_strides, out_strides, inplace, dtype, device, rtol, atol +): + input = rand_strided(shape, input_strides, dtype=dtype, device=device) + input.mul_(2).sub_(1) + out = ( + input + if inplace + else empty_strided(shape, out_strides, dtype=dtype, device=device) + ) + + return Payload( + _relu_infinilm, + _torch_relu_infinilm, + (input, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _relu_infinilm(input, out): + infini.ops.relu_infinilm(input, out, stream=get_stream(input.device)) + + return out + + +def _torch_relu_infinilm(input, out): + result = torch.nn.functional.relu(input) + out.copy_(result) + + return out