Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/base/relu_infinilm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#ifndef INFINI_OPS_BASE_RELU_INFINILM_H_
#define INFINI_OPS_BASE_RELU_INFINILM_H_

#include <cassert>

#include "operator.h"

namespace infini::ops {

class ReluInfinilm : public Operator<ReluInfinilm> {
public:
ReluInfinilm(const Tensor input, Tensor out)
: input_shape_{input.shape()},
input_strides_{input.strides()},
input_type_{input.dtype()},
out_shape_{out.shape()},
out_strides_{out.strides()},
out_type_{out.dtype()},
output_size_{out.numel()},
ndim_{out.ndim()},
is_input_contiguous_{input.IsContiguous()},
is_out_contiguous_{out.IsContiguous()},
device_index_{out.device().index()} {
assert(input_shape_ == out_shape_ &&
"`ReluInfinilm` input and output shapes must match");
assert(input_type_ == out_type_ &&
"`ReluInfinilm` input and output dtypes must match");
assert(!out.HasBroadcastDim() &&
"`ReluInfinilm` output must not have broadcasted dimensions");
}

virtual void operator()(const Tensor input, Tensor out) const = 0;

protected:
Tensor::Shape input_shape_;

Tensor::Strides input_strides_;

DataType input_type_;

Tensor::Shape out_shape_;

Tensor::Strides out_strides_;

DataType out_type_;

Tensor::Size output_size_{0};

Tensor::Size ndim_{0};

bool is_input_contiguous_{false};

bool is_out_contiguous_{false};

int device_index_{0};
};

} // namespace infini::ops

#endif
21 changes: 21 additions & 0 deletions src/native/cuda/iluvatar/ops/relu_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INFINI_OPS_ILUVATAR_RELU_INFINILM_KERNEL_H_
#define INFINI_OPS_ILUVATAR_RELU_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/iluvatar/caster.cuh"
#include "native/cuda/iluvatar/runtime_.h"
#include "native/cuda/ops/relu_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<ReluInfinilm, Device::Type::kIluvatar>
: public CudaReluInfinilm<Runtime<Device::Type::kIluvatar>> {
public:
using CudaReluInfinilm<Runtime<Device::Type::kIluvatar>>::CudaReluInfinilm;
};

} // namespace infini::ops

#endif
21 changes: 21 additions & 0 deletions src/native/cuda/metax/ops/relu_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INFINI_OPS_METAX_RELU_INFINILM_KERNEL_H_
#define INFINI_OPS_METAX_RELU_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/metax/caster.cuh"
#include "native/cuda/metax/runtime_.h"
#include "native/cuda/ops/relu_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<ReluInfinilm, Device::Type::kMetax>
: public CudaReluInfinilm<Runtime<Device::Type::kMetax>> {
public:
using CudaReluInfinilm<Runtime<Device::Type::kMetax>>::CudaReluInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/moore/ops/relu_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_MOORE_RELU_INFINILM_KERNEL_H_
#define INFINI_OPS_MOORE_RELU_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/moore/caster.cuh"
#include "native/cuda/moore/polyfills.cuh"
#include "native/cuda/moore/runtime_.h"
#include "native/cuda/ops/relu_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<ReluInfinilm, Device::Type::kMoore>
: public CudaReluInfinilm<Runtime<Device::Type::kMoore>> {
public:
using CudaReluInfinilm<Runtime<Device::Type::kMoore>>::CudaReluInfinilm;
};

} // namespace infini::ops

#endif
21 changes: 21 additions & 0 deletions src/native/cuda/nvidia/ops/relu_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INFINI_OPS_NVIDIA_RELU_INFINILM_KERNEL_H_
#define INFINI_OPS_NVIDIA_RELU_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/relu_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<ReluInfinilm, Device::Type::kNvidia>
: public CudaReluInfinilm<Runtime<Device::Type::kNvidia>> {
public:
using CudaReluInfinilm<Runtime<Device::Type::kNvidia>>::CudaReluInfinilm;
};

} // namespace infini::ops

#endif
49 changes: 49 additions & 0 deletions src/native/cuda/ops/relu_infinilm/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_CUH_
#define INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_CUH_

#include <cstddef>

#include "native/cuda/caster.cuh"
#include "native/cuda/kernel_commons.cuh"

namespace infini::ops {

namespace {

template <Device::Type kDev, typename T>
__device__ __forceinline__ T ReluInfinilmValue(T x) {
const float v = Caster<kDev>::template Cast<float>(x);
return Caster<kDev>::template Cast<T>(v > 0.0f ? v : 0.0f);
}

template <Device::Type kDev>
__device__ __forceinline__ double ReluInfinilmValue(double x) {
return x > 0.0 ? x : 0.0;
}

} // namespace

template <Device::Type kDev, typename T, unsigned int block_size>
__global__ void ReluInfinilmKernel(T* __restrict__ out,
const T* __restrict__ input,
const size_t* __restrict__ out_shape,
const size_t* __restrict__ input_shape,
const ptrdiff_t* __restrict__ out_strides,
const ptrdiff_t* __restrict__ input_strides,
size_t output_size, size_t ndim,
bool out_contiguous, bool input_contiguous) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;

if (idx < output_size) {
size_t out_idx =
out_contiguous ? idx : IndexToOffset(idx, ndim, out_shape, out_strides);
size_t input_idx =
input_contiguous ? idx
: IndexToOffset(idx, ndim, input_shape, input_strides);
out[out_idx] = ReluInfinilmValue<kDev>(input[input_idx]);
}
}

} // namespace infini::ops

#endif
91 changes: 91 additions & 0 deletions src/native/cuda/ops/relu_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#ifndef INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_H_
#define INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_H_

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

#include "base/relu_infinilm.h"
#include "common/generic_utils.h"
#include "data_type.h"
#include "dispatcher.h"
#include "native/cuda/kernel_commons.cuh"
#include "native/cuda/ops/relu_infinilm/kernel.cuh"
#include "native/cuda/runtime_utils.h"

namespace infini::ops {

template <typename Backend>
class CudaReluInfinilm : public ReluInfinilm {
public:
CudaReluInfinilm(const Tensor input, Tensor out) : ReluInfinilm{input, out} {
size_t shape_size = ndim_ * sizeof(*d_input_shape_);
size_t strides_size = ndim_ * sizeof(*d_input_strides_);
const size_t metadata_size = 2 * (shape_size + strides_size);
std::vector<std::byte> metadata(metadata_size);

Backend::Malloc((void**)&d_metadata_, metadata_size);

size_t offset = 0;
d_input_shape_ = reinterpret_cast<Tensor::Size*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, input_shape_.data(), shape_size);
offset += shape_size;

d_out_shape_ = reinterpret_cast<Tensor::Size*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, out_shape_.data(), shape_size);
offset += shape_size;

d_input_strides_ = reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, input_strides_.data(), strides_size);
offset += strides_size;

d_out_strides_ = reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, out_strides_.data(), strides_size);

Backend::Memcpy(d_metadata_, metadata.data(), metadata_size,
Backend::MemcpyHostToDevice);
}

~CudaReluInfinilm() { Backend::Free(d_metadata_); }

void operator()(const Tensor input, Tensor out) const override {
auto cuda_stream =
static_cast<typename Backend::Stream>(stream_ ? stream_ : 0);
int block_size = std::min(
RuntimeUtils<Backend::kDeviceType>::GetOptimalBlockSize(), 1024);
dim3 block(std::min(static_cast<Tensor::Size>(block_size), output_size_));
dim3 grid(utils::CeilDiv(output_size_, block.x));

DispatchFunc<AllFloatTypes, List<128, 256, 512, 1024>>(
{static_cast<int64_t>(out_type_), block_size},
[&](auto list_tag) {
using T = TypeMapType<Backend::kDeviceType, ListGet<0>(list_tag)>;
constexpr int kBlockSize = ListGet<1>(list_tag);

ReluInfinilmKernel<Backend::kDeviceType, T, kBlockSize>
<<<grid, block, 0, cuda_stream>>>(
reinterpret_cast<T*>(out.data()),
reinterpret_cast<const T*>(input.data()), d_out_shape_,
d_input_shape_, d_out_strides_, d_input_strides_,
output_size_, ndim_, is_out_contiguous_,
is_input_contiguous_);
},
"CudaReluInfinilm::operator()");
}

private:
std::byte* d_metadata_{nullptr};

Tensor::Size* d_input_shape_{nullptr};

Tensor::Size* d_out_shape_{nullptr};

Tensor::Stride* d_input_strides_{nullptr};

Tensor::Stride* d_out_strides_{nullptr};
};

} // namespace infini::ops

#endif
68 changes: 68 additions & 0 deletions tests/test_relu_infinilm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import infini.ops
import pytest
import torch

from tests.utils import Payload, empty_strided, get_stream, rand_strided


@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize(
"shape, input_strides, out_strides, inplace",
(
((1, 3), None, None, False),
((1, 3), None, None, True),
((3, 3), None, None, False),
((3, 3), (5, 1), (5, 1), False),
((32, 20, 512), None, None, False),
((32, 20, 512), None, None, True),
((33, 333, 333), None, None, False),
((32, 256, 112, 112), None, None, False),
((3, 3, 13, 9, 17), None, None, False),
(
(3, 3, 13, 9, 17),
(19890, 6630, 510, 34, 1),
(19890, 6630, 510, 34, 1),
False,
),
),
)
@pytest.mark.parametrize(
("dtype", "rtol", "atol"),
(
(torch.float32, 1e-7, 1e-7),
(torch.float16, 1e-3, 1e-3),
(torch.bfloat16, 1e-3, 1e-3),
),
)
def test_relu_infinilm(
shape, input_strides, out_strides, inplace, dtype, device, rtol, atol
):
input = rand_strided(shape, input_strides, dtype=dtype, device=device)
input.mul_(2).sub_(1)
out = (
input
if inplace
else empty_strided(shape, out_strides, dtype=dtype, device=device)
)

return Payload(
_relu_infinilm,
_torch_relu_infinilm,
(input, out),
{},
rtol=rtol,
atol=atol,
)


def _relu_infinilm(input, out):
infini.ops.relu_infinilm(input, out, stream=get_stream(input.device))

return out


def _torch_relu_infinilm(input, out):
result = torch.nn.functional.relu(input)
out.copy_(result)

return out
Loading