Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/infiniop/elementwise/bang/elementwise_bang_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ __mlu_device__ void launchOp(
// 2. Execute operation
Tdata *output_buffer = aligned_buf + N * max_batch;
Op op;
op(output_buffer, input_buffers[0], input_buffers[1], curr_batch, args...);
if constexpr (N == 1) {
op(output_buffer, input_buffers[0], input_buffers[0], curr_batch, args...);
} else {
op(output_buffer, input_buffers[0], input_buffers[1], curr_batch, args...);
}
__sync_compute();

// 3. Write back results
Expand Down
8 changes: 8 additions & 0 deletions src/infiniop/ops/gelu/bang/gelu_bang.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __GELU_BANG_H__
#define __GELU_BANG_H__

#include "../../../elementwise/bang/elementwise_bang.h"

ELEMENTWISE_DESCRIPTOR(gelu, bang)

#endif // __GELU_BANG_H__
56 changes: 56 additions & 0 deletions src/infiniop/ops/gelu/bang/gelu_bang.mlu
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include "gelu_bang.h"

LAUNCH_ELEMENTWISE_KERNEL(Gelu)

namespace op::gelu::bang {

typedef struct GeluOp {
static constexpr size_t num_inputs = 1;
template <typename Tdata, typename... Args>
static infiniStatus_t launch(Args... args) {
launchGeluKernel<Tdata>(args...);
return INFINI_STATUS_SUCCESS;
}
} GeluOp;

Descriptor::~Descriptor() = default;

infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
auto dtype = out_desc->dtype();

const auto &input_desc = input_desc_vec.at(0);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_SAME_SHAPE(out_desc->shape(), input_desc->shape());

CREATE_ELEMENTWISE_BANG_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *queue) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}

switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<GeluOp, half>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_BF16:
return _device_info->calculate<GeluOp, bfloat16_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_F32:
return _device_info->calculate<GeluOp, float>(_info, workspace, output, inputs, queue);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}

} // namespace op::gelu::bang
31 changes: 31 additions & 0 deletions src/infiniop/ops/gelu/bang/gelu_bang_internal.mlu
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#ifndef __GELU_BANG_INTERNAL_H__
#define __GELU_BANG_INTERNAL_H__

#include "../../../elementwise/bang/elementwise_bang_kernel.h"
#include "bang.h"
#include "bang_device_functions.h"
#include <cmath>
#include <type_traits>

typedef struct GeluOp {
public:
static constexpr size_t num_inputs = 1;
template <typename T>
__mlu_device__ void operator()(T *out, const T *input, const T *unused, size_t num_elements) const {
(void)unused;
constexpr float inv_sqrt2 = 0.70710678118654752440f;
for (size_t i = 0; i < num_elements; ++i) {
float x = static_cast<float>(input[i]);
float y = 0.5f * x * (1.0f + erff(x * inv_sqrt2));
out[i] = static_cast<T>(y);
}
}
} GeluOp;

LAUNCH_ELEMENTWISE_KERNEL_IMPL(Gelu, GeluOp)

LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Gelu, half)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Gelu, bfloat16_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Gelu, float)

#endif // __GELU_BANG_INTERNAL_H__
15 changes: 15 additions & 0 deletions src/infiniop/ops/gelu/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
#ifdef ENABLE_KUNLUN_API
#include "kunlun/gelu_kunlun.h"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/gelu_bang.h"
#endif

__INFINI_C infiniStatus_t infiniopCreateGeluDescriptor(
infiniopHandle_t handle,
Expand Down Expand Up @@ -52,6 +55,9 @@ __INFINI_C infiniStatus_t infiniopCreateGeluDescriptor(
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
Expand Down Expand Up @@ -92,6 +98,9 @@ __INFINI_C infiniStatus_t infiniopGetGeluWorkspaceSize(infiniopGeluDescriptor_t
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
Expand Down Expand Up @@ -140,6 +149,9 @@ __INFINI_C infiniStatus_t infiniopGelu(
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
Expand Down Expand Up @@ -182,6 +194,9 @@ infiniopDestroyGeluDescriptor(infiniopGeluDescriptor_t desc) {
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
Expand Down
6 changes: 4 additions & 2 deletions test/infiniop/gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
# shape, input_stride, output_stride
((13, 4), None, None),
((13, 4), (10, 1), (10, 1)),
#((13, 4), (0, 1), None),
# ((13, 4), (0, 1), None),
((13, 4, 4), None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1)),
#((13, 4, 4), (4, 0, 1), None),
# ((13, 4, 4), (4, 0, 1), None),
((16, 5632), None, None),
((16, 5632), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None),
Expand Down Expand Up @@ -89,6 +89,8 @@ def test(
input_stride is not None or output_stride is not None
):
return
if device == InfiniDeviceEnum.CAMBRICON and dtype == InfiniDtype.F64:
return

input = TestTensor(shape, input_stride, dtype, device)
if inplace == Inplace.INPLACE:
Expand Down
Loading