Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/infiniop/ops/layer_norm/bang/layer_norm_bang.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Header for the `bang` backend of the layer_norm operator.
// It is pulled in by operator.cc when ENABLE_CAMBRICON_API is defined.
#ifndef __LAYER_NORM_BANG_H__
#define __LAYER_NORM_BANG_H__

#include "../layer_norm.h"

// NOTE(review): DESCRIPTOR is presumably a macro from ../layer_norm.h that
// declares op::layer_norm::bang::Descriptor (the .mlu file defines its
// members) -- confirm against that header.
DESCRIPTOR(bang)

#endif // __LAYER_NORM_BANG_H__
288 changes: 288 additions & 0 deletions src/infiniop/ops/layer_norm/bang/layer_norm_bang.mlu
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
#include "../../../devices/bang/common_bang.h"
#include "../../../reduce/bang/reduce_bang.h"
#include "layer_norm_bang.h"

#include <algorithm>
#include <type_traits>

// Scratch buffer declared with the __nram__ qualifier. layerNormKernel carves
// it into a reduction area followed by per-chunk staging areas (input, output,
// standardization, weight, bias caches and three float work arrays).
__nram__ char nram_buffer[NRAM_MAX_SIZE];

/**
 * Copies `n` contiguous elements of type T from `src` into the staging
 * buffer `cache` (GDRAM2NRAM), then writes their float representation to
 * `dst`.
 *
 * @param dst   Destination float array; receives `n` values.
 * @param cache Staging array of at least `n` elements of T.
 * @param src   Source array; `n` contiguous elements are read.
 * @param n     Number of elements to transfer.
 *
 * half and bfloat16_t are widened with the matching BANG conversion; every
 * other T is copied as raw floats.
 */
template <typename T>
__mlu_device__ void loadToFloat(float *dst, T *cache, const T *src, size_t n) {
    constexpr bool is_half = std::is_same<T, half>::value;
    constexpr bool is_bf16 = std::is_same<T, bfloat16_t>::value;

    const size_t raw_bytes = n * sizeof(T);
    __memcpy(cache, src, raw_bytes, GDRAM2NRAM);

    if constexpr (!is_half && !is_bf16) {
        // No conversion needed: move the staged values over unchanged.
        __memcpy(dst, cache, n * sizeof(float), NRAM2NRAM);
    } else if constexpr (is_half) {
        __bang_half2float(dst, cache, n);
    } else {
        __bang_bfloat162float(dst, cache, n);
    }
}

/**
 * Converts `n` floats from `src` into type T inside the staging buffer
 * `cache`, then copies the staged elements to `dst` (NRAM2GDRAM).
 *
 * @param dst   Destination array; `n` contiguous elements are written.
 * @param cache Staging array of at least `n` elements of T.
 * @param src   Source float array holding `n` values.
 * @param n     Number of elements to transfer.
 *
 * half and bfloat16_t are narrowed with the matching BANG conversion; every
 * other T is staged as raw floats.
 */
template <typename T>
__mlu_device__ void storeFromFloat(T *dst, T *cache, float *src, size_t n) {
    constexpr bool is_half = std::is_same<T, half>::value;
    constexpr bool is_bf16 = std::is_same<T, bfloat16_t>::value;

    if constexpr (!is_half && !is_bf16) {
        // No conversion needed: stage the floats unchanged.
        __memcpy(cache, src, n * sizeof(float), NRAM2NRAM);
    } else if constexpr (is_half) {
        __bang_float2half(cache, src, n);
    } else {
        __bang_float2bfloat16(cache, src, n);
    }

    const size_t out_bytes = n * sizeof(T);
    __memcpy(dst, cache, out_bytes, NRAM2GDRAM);
}

/**
 * Writes a single float `value` to `dst` as one element of type T.
 *
 * @param dst   Destination element.
 * @param cache Staging storage; it is reinterpreted as float storage, so it
 *              must be able to hold one float.
 * @param value Value to store.
 *
 * NOTE(review): the float source and the T destination handed to
 * storeFromFloat share the same address, so the conversion runs in place --
 * confirm the BANG conversion/copy primitives permit overlapping operands.
 */
template <typename T>
__mlu_device__ void storeScalarFromFloat(T *dst, T *cache, float value) {
    float *staging = reinterpret_cast<float *>(cache);
    *staging = value;
    storeFromFloat<T>(dst, cache, staging, 1);
}

/**
 * Returns the number of elements reserved for each per-chunk NRAM partition.
 *
 * The caller uses the result both as the stride between consecutive
 * partitions of the scratch buffer and as the upper bound for the chunk it
 * processes at a time (it clamps the chunk to the elements that remain).
 *
 * @param dim         Length of the normalized dimension.
 * @param data_size   sizeof the data element type.
 * @param weight_size sizeof the weight/bias element type.
 * @return Partition size in elements. Whenever the scratch buffer can hold at
 *         least 64 elements per partition the result is a multiple of 64, so
 *         every partition starts at an offset that is a multiple of
 *         64 * sizeof(element). Previously a `dim` below 64 was returned
 *         as-is, which for odd `dim` and 2-byte elements placed the float
 *         work arrays at an address that is not 4-byte aligned.
 */
__mlu_device__ int getMaxBatchSize(size_t dim, size_t data_size, size_t weight_size) {
    constexpr int reduce_buffer_size = 128 / sizeof(float);
    constexpr int batch_alignment = 64;

    // Budget: whole buffer minus a 256-byte reserve and the reduction area,
    // divided by a (conservative) per-element footprint.
    const size_t usable_bytes = NRAM_MAX_SIZE - 256 - reduce_buffer_size * sizeof(float);
    const size_t bytes_per_element = 4 * data_size + 2 * weight_size + 4 * sizeof(float);
    const int capacity = static_cast<int>(usable_bytes / bytes_per_element);

    if (capacity < batch_alignment) {
        // Too little room to align; fall back to the plain clamp. Compare in
        // size_t so a very large `dim` cannot wrap when narrowed to int.
        const int clamped = dim < static_cast<size_t>(capacity) ? static_cast<int>(dim) : capacity;
        return std::max(clamped, 1);
    }

    const int aligned_capacity = (capacity / batch_alignment) * batch_alignment;
    if (dim >= static_cast<size_t>(aligned_capacity)) {
        return aligned_capacity;
    }

    // dim < aligned_capacity, so rounding up stays within the capacity and
    // cannot overflow.
    const int needed = static_cast<int>((dim + batch_alignment - 1) / batch_alignment) * batch_alignment;
    return std::max(needed, batch_alignment);
}

// Layer normalization kernel over the last dimension.
//
// For every row it writes three results:
//   std_deviation   = sqrt(mean((x - mean(x))^2) + eps)   (one value per row)
//   standardization = (x - mean(x)) / std_deviation
//   output          = standardization * weight (+ bias when has_bias)
//
// Parameters:
//   output / standardization / std_deviation / input - tensor base pointers.
//   weight / bias - per-element scale and shift along the normalized dimension.
//   shape         - tensor shape; axes 0..ndim-2 are used to decode a row index.
//   *_strides     - element strides of the matching tensor. std_deviation_strides
//                   is only indexed for axes 0..ndim-2.
//   rows          - number of rows to normalize.
//   dim           - length of the normalized (last) dimension.
//   eps           - value added to the variance before the square root.
//
// Rows are distributed across tasks: each task handles rows taskId,
// taskId + taskDim, and so on. A row is processed in chunks of at most
// max_batch_size elements, and the input is re-read for each of three passes.
//
// NOTE(review): chunk pointers are advanced by `processed * stride`, but
// loadToFloat/storeFromFloat transfer `current` contiguous elements. This looks
// correct only when the last-dimension strides (and weight/bias strides) are 1
// -- confirm that callers guarantee this.
template <typename Tdata, typename Tweight>
__mlu_global__ void layerNormKernel(
Tdata *__restrict__ output,
Tdata *__restrict__ standardization,
Tdata *__restrict__ std_deviation,
const Tdata *__restrict__ input,
const Tweight *__restrict__ weight,
const Tweight *__restrict__ bias,
const size_t *__restrict__ shape,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ standardization_strides,
const ptrdiff_t *__restrict__ std_deviation_strides,
const ptrdiff_t *__restrict__ input_strides,
size_t rows,
size_t dim,
int ndim,
ptrdiff_t weight_stride,
ptrdiff_t bias_stride,
float eps,
bool has_bias) {

// The first 128 bytes of the scratch buffer are the reduction area.
constexpr int reduce_buffer_size = 128 / sizeof(float);
int max_batch_size = getMaxBatchSize(dim, sizeof(Tdata), sizeof(Tweight));

// Partition the scratch buffer: every region below holds max_batch_size
// elements of its own type and starts where the previous one ends.
float *reduction_buffer = reinterpret_cast<float *>(nram_buffer);
Tdata *input_cache = reinterpret_cast<Tdata *>(reduction_buffer + reduce_buffer_size);
Tdata *output_cache = input_cache + max_batch_size;
Tdata *std_cache = output_cache + max_batch_size;
Tweight *weight_cache = reinterpret_cast<Tweight *>(std_cache + max_batch_size);
Tweight *bias_cache = weight_cache + max_batch_size;
float *input_float = reinterpret_cast<float *>(bias_cache + max_batch_size);
float *weight_float = input_float + max_batch_size;
float *bias_float = weight_float + max_batch_size;

for (size_t row = taskId; row < rows; row += taskDim) {
// Decode the flat row index into coordinates over axes ndim-2..0 and
// accumulate the element offset of this row in each tensor.
ptrdiff_t input_offset = 0;
ptrdiff_t output_offset = 0;
ptrdiff_t standardization_offset = 0;
ptrdiff_t std_deviation_offset = 0;
size_t tmp = row;
for (int axis = ndim - 2; axis >= 0; --axis) {
size_t coord = tmp % shape[axis];
tmp /= shape[axis];
input_offset += static_cast<ptrdiff_t>(coord) * input_strides[axis];
output_offset += static_cast<ptrdiff_t>(coord) * output_strides[axis];
standardization_offset += static_cast<ptrdiff_t>(coord) * standardization_strides[axis];
std_deviation_offset += static_cast<ptrdiff_t>(coord) * std_deviation_strides[axis];
}

const Tdata *input_row = input + input_offset;
Tdata *output_row = output + output_offset;
Tdata *standardization_row = standardization + standardization_offset;

// Pass 1: sum of the row, accumulated chunk by chunk.
float sum = 0.0f;
size_t processed = 0;
while (processed < dim) {
size_t current = std::min(static_cast<size_t>(max_batch_size), dim - processed);
loadToFloat<Tdata>(input_float, input_cache, input_row + processed * input_strides[ndim - 1], current);
// Chunks of 128+ elements go through sumInternal and the result is read
// from reduction_buffer[0]; shorter chunks are summed element by element.
if (current >= 128) {
op::common_bang::reduce_op::sumInternal(reduction_buffer, input_float, current);
sum += reduction_buffer[0];
} else {
for (size_t i = 0; i < current; ++i) {
sum += input_float[i];
}
}
processed += current;
}

// Pass 2: sum of squared deviations from the mean.
float mean = sum / static_cast<float>(dim);
float variance_sum = 0.0f;
processed = 0;
while (processed < dim) {
size_t current = std::min(static_cast<size_t>(max_batch_size), dim - processed);
loadToFloat<Tdata>(input_float, input_cache, input_row + processed * input_strides[ndim - 1], current);
__bang_sub_scalar(input_float, input_float, mean, current);
__bang_mul(input_float, input_float, input_float, current);
if (current >= 128) {
op::common_bang::reduce_op::sumInternal(reduction_buffer, input_float, current);
variance_sum += reduction_buffer[0];
} else {
for (size_t i = 0; i < current; ++i) {
variance_sum += input_float[i];
}
}
processed += current;
}

// Per-row standard deviation (variance divides by dim, eps added inside the
// square root); it is stored once per row.
float std_value = sqrtf(variance_sum / static_cast<float>(dim) + eps);
float inv_std = 1.0f / std_value;
storeScalarFromFloat<Tdata>(std_deviation + std_deviation_offset, std_cache, std_value);

// Pass 3: normalize, store the standardized values, then scale/shift and
// store the final output.
processed = 0;
while (processed < dim) {
size_t current = std::min(static_cast<size_t>(max_batch_size), dim - processed);
loadToFloat<Tdata>(input_float, input_cache, input_row + processed * input_strides[ndim - 1], current);
loadToFloat<Tweight>(weight_float, weight_cache, weight + processed * weight_stride, current);
if (has_bias) {
loadToFloat<Tweight>(bias_float, bias_cache, bias + processed * bias_stride, current);
}

__bang_sub_scalar(input_float, input_float, mean, current);
__bang_mul_scalar(input_float, input_float, inv_std, current);
storeFromFloat<Tdata>(standardization_row + processed * standardization_strides[ndim - 1], std_cache, input_float, current);

// input_float still holds the standardized values; apply weight and bias.
__bang_mul(input_float, input_float, weight_float, current);
if (has_bias) {
__bang_add(input_float, input_float, bias_float, current);
}
storeFromFloat<Tdata>(output_row + processed * output_strides[ndim - 1], output_cache, input_float, current);
processed += current;
}
}
}

/**
 * Host-side launcher for layerNormKernel.
 *
 * Copies the tensor shape and stride tables from `info` into `workspace`,
 * launches the kernel on `queue`, and blocks until the queue has drained.
 *
 * @param core_per_cluster Used as the x extent of the launch dimensions.
 * @param cluster_count    Used as the y extent of the launch dimensions.
 * @param queue            Queue that receives the copies and the kernel.
 * @param output, standardization, std_deviation  Result tensors.
 * @param input, weight, bias                     Operand tensors; `bias` is
 *        only read by the kernel when info.bias_exist is set.
 * @param workspace Device buffer that must hold at least
 *        ndim * sizeof(size_t) + (4 * ndim - 1) * sizeof(ptrdiff_t) bytes.
 * @param info      Shape/stride description of the operation.
 *
 * NOTE(review): the layout uses `info.ndim - 1`, so ndim is assumed to be at
 * least 1 -- confirm LayerNormInfo enforces this.
 */
template <typename Tdata, typename Tweight>
void launchLayerNorm(
    int core_per_cluster,
    int cluster_count,
    cnrtQueue_t queue,
    void *output,
    void *standardization,
    void *std_deviation,
    const void *input,
    const void *weight,
    const void *bias,
    void *workspace,
    const op::layer_norm::LayerNormInfo &info) {
    cnrtDim3_t kernel_dim;
    kernel_dim.x = core_per_cluster;
    kernel_dim.y = cluster_count;
    kernel_dim.z = 1;

    // Workspace layout, back to back:
    //   shape[ndim] | output_strides[ndim] | standardization_strides[ndim] |
    //   std_deviation_strides[ndim - 1] | input_strides[ndim]
    char *tmp_device = reinterpret_cast<char *>(workspace);
    size_t *mlu_shape = reinterpret_cast<size_t *>(tmp_device);
    ptrdiff_t *mlu_output_strides = reinterpret_cast<ptrdiff_t *>(mlu_shape + info.ndim);
    ptrdiff_t *mlu_standardization_strides = mlu_output_strides + info.ndim;
    ptrdiff_t *mlu_std_deviation_strides = mlu_standardization_strides + info.ndim;
    ptrdiff_t *mlu_input_strides = mlu_std_deviation_strides + (info.ndim - 1);

    CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast<size_t *>(info.input_shape.data()), info.ndim * sizeof(size_t), queue, cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtMemcpyAsync(mlu_output_strides, const_cast<ptrdiff_t *>(info.output_strides.data()), info.ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtMemcpyAsync(mlu_standardization_strides, const_cast<ptrdiff_t *>(info.input_standardization_strides.data()), info.ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtMemcpyAsync(mlu_std_deviation_strides, const_cast<ptrdiff_t *>(info.input_std_deviation_strides.data()), (info.ndim - 1) * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtMemcpyAsync(mlu_input_strides, const_cast<ptrdiff_t *>(info.input_strides.data()), info.ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));

    layerNormKernel<Tdata, Tweight><<<kernel_dim, cnrtFuncTypeUnion1, queue>>>(
        reinterpret_cast<Tdata *>(output),
        reinterpret_cast<Tdata *>(standardization),
        reinterpret_cast<Tdata *>(std_deviation),
        reinterpret_cast<const Tdata *>(input),
        reinterpret_cast<const Tweight *>(weight),
        reinterpret_cast<const Tweight *>(bias),
        mlu_shape,
        mlu_output_strides,
        mlu_standardization_strides,
        mlu_std_deviation_strides,
        mlu_input_strides,
        info.othersize,
        info.normalized_size,
        static_cast<int>(info.ndim),
        info.weight_strides.back(),
        info.bias_exist ? info.bias_strides.back() : 0,
        info.eps,
        info.bias_exist);

    // Wait for the copies and the kernel to finish. The status is checked like
    // every other CNRT call here so that a failed queue is not silently ignored.
    CNRT_CHECK(cnrtQueueSync(queue));
}

namespace op::layer_norm::bang {

namespace {

// Bytes of device workspace needed to stage the kernel's metadata: one shape
// table of `ndim` size_t entries plus four stride tables holding
// ndim + ndim + (ndim - 1) + ndim ptrdiff_t entries in total.
inline size_t metadataWorkspaceSize(size_t ndim) {
    return ndim * sizeof(size_t) + (4 * ndim - 1) * sizeof(ptrdiff_t);
}

} // namespace

// Backend-private state: keeps the device handle internals alive for the
// lifetime of the descriptor.
struct Descriptor::Opaque {
    std::shared_ptr<device::bang::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

/**
 * Builds a layer_norm descriptor for this backend.
 *
 * Validates the tensor descriptors through LayerNormInfo::createLayerNormInfo,
 * accepts only F16, BF16 and F32 data, and records the workspace size needed
 * to stage shape/stride metadata on the device.
 *
 * @param handle_  Device handle; reinterpreted as device::bang::Handle.
 * @param desc_ptr Receives the newly allocated descriptor on success.
 * @param eps      Value added to the variance before the square root.
 * @return INFINI_STATUS_SUCCESS, or the failure reported by the info
 *         construction / dtype check.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t output_desc,
    infiniopTensorDescriptor_t input_standardization_desc,
    infiniopTensorDescriptor_t input_std_deviation_desc,
    infiniopTensorDescriptor_t input_desc,
    infiniopTensorDescriptor_t weight_desc,
    infiniopTensorDescriptor_t bias_desc,
    float eps) {
    auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
    auto result = LayerNormInfo::createLayerNormInfo(
        output_desc,
        input_standardization_desc,
        input_std_deviation_desc,
        input_desc,
        weight_desc,
        bias_desc,
        eps);
    CHECK_RESULT(result);
    auto info = result.take();

    CHECK_DTYPE(info.dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
    *desc_ptr = new Descriptor(
        info.dtype,
        info,
        metadataWorkspaceSize(info.ndim),
        new Descriptor::Opaque{handle->internal()},
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

/**
 * Runs layer normalization on `stream`.
 *
 * @param workspace      Device buffer used to stage shape/stride metadata.
 * @param workspace_size Size of `workspace` in bytes.
 * @return INFINI_STATUS_INSUFFICIENT_WORKSPACE when `workspace_size` is smaller
 *         than the metadata requires, INFINI_STATUS_BAD_TENSOR_DTYPE for an
 *         unsupported dtype, INFINI_STATUS_SUCCESS otherwise.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    void *input_standardization,
    void *input_std_deviation,
    const void *input,
    const void *weight,
    const void *bias,
    void *stream) const {
    // The launcher writes the metadata tables into `workspace` unconditionally,
    // so reject a buffer that is too small instead of overrunning it.
    if (workspace_size < metadataWorkspaceSize(_info.ndim)) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    auto queue = reinterpret_cast<cnrtQueue_t>(stream);
    int core_per_cluster = _opaque->internal->getCorePerCluster();
    int cluster_count = _opaque->internal->getClusterCount();

    if (_info.dtype == INFINI_DTYPE_F16) {
        launchLayerNorm<half, half>(core_per_cluster, cluster_count, queue, output, input_standardization, input_std_deviation, input, weight, bias, workspace, _info);
    } else if (_info.dtype == INFINI_DTYPE_BF16) {
        launchLayerNorm<bfloat16_t, bfloat16_t>(core_per_cluster, cluster_count, queue, output, input_standardization, input_std_deviation, input, weight, bias, workspace, _info);
    } else if (_info.dtype == INFINI_DTYPE_F32) {
        launchLayerNorm<float, float>(core_per_cluster, cluster_count, queue, output, input_standardization, input_std_deviation, input, weight, bias, workspace, _info);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::layer_norm::bang
15 changes: 15 additions & 0 deletions src/infiniop/ops/layer_norm/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
#ifdef ENABLE_MOORE_API
#include "moore/layer_norm_moore.h"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/layer_norm_bang.h"
#endif

__INFINI_C infiniStatus_t infiniopCreateLayerNormDescriptor(
infiniopHandle_t handle,
Expand Down Expand Up @@ -64,6 +67,9 @@ __INFINI_C infiniStatus_t infiniopCreateLayerNormDescriptor(
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -102,6 +108,9 @@ __INFINI_C infiniStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDes
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -162,6 +171,9 @@ __INFINI_C infiniStatus_t infiniopLayerNorm(
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -201,6 +213,9 @@ infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) {
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
Expand Down
6 changes: 4 additions & 2 deletions test/infiniop/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def torch_layer_norm(
var = input.var(dim=-1, correction=0)
std = torch.sqrt(var + eps)
input_standardization.copy_(
((input - mean) / std.unsqueeze(2)).type(input_standardization.dtype)
((input - mean) / std.unsqueeze(-1)).type(input_standardization.dtype)
)
input_std_deviation.copy_(std.type(input_standardization.dtype))
output.copy_(ln(input).detach().type(output.dtype))
Expand Down Expand Up @@ -179,8 +179,10 @@ def test(
else None
)

layer_norm(
torch_layer_norm(
output.torch_tensor(),
input_standardization.torch_tensor(),
input_std_deviation.torch_tensor(),
input.torch_tensor(),
weight.torch_tensor(),
bias.torch_tensor() if bias_exist else None,
Expand Down
Loading