diff --git a/src/infiniop/ops/layer_norm/bang/layer_norm_bang.h b/src/infiniop/ops/layer_norm/bang/layer_norm_bang.h new file mode 100644 index 000000000..7c686bfb8 --- /dev/null +++ b/src/infiniop/ops/layer_norm/bang/layer_norm_bang.h @@ -0,0 +1,8 @@ +#ifndef __LAYER_NORM_BANG_H__ +#define __LAYER_NORM_BANG_H__ + +#include "../layer_norm.h" + +DESCRIPTOR(bang) + +#endif // __LAYER_NORM_BANG_H__ diff --git a/src/infiniop/ops/layer_norm/bang/layer_norm_bang.mlu b/src/infiniop/ops/layer_norm/bang/layer_norm_bang.mlu new file mode 100644 index 000000000..d0379dce9 --- /dev/null +++ b/src/infiniop/ops/layer_norm/bang/layer_norm_bang.mlu @@ -0,0 +1,288 @@ +#include "../../../devices/bang/common_bang.h" +#include "../../../reduce/bang/reduce_bang.h" +#include "layer_norm_bang.h" + +#include +#include + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_device__ void loadToFloat(float *dst, T *cache, const T *src, size_t n) { + __memcpy(cache, src, n * sizeof(T), GDRAM2NRAM); + if constexpr (std::is_same::value) { + __bang_half2float(dst, cache, n); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(dst, cache, n); + } else { + __memcpy(dst, cache, n * sizeof(float), NRAM2NRAM); + } +} + +template +__mlu_device__ void storeFromFloat(T *dst, T *cache, float *src, size_t n) { + if constexpr (std::is_same::value) { + __bang_float2half(cache, src, n); + } else if constexpr (std::is_same::value) { + __bang_float2bfloat16(cache, src, n); + } else { + __memcpy(cache, src, n * sizeof(float), NRAM2NRAM); + } + __memcpy(dst, cache, n * sizeof(T), NRAM2GDRAM); +} + +template +__mlu_device__ void storeScalarFromFloat(T *dst, T *cache, float value) { + float *float_cache = reinterpret_cast(cache); + float_cache[0] = value; + storeFromFloat(dst, cache, float_cache, 1); +} + +__mlu_device__ int getMaxBatchSize(size_t dim, size_t data_size, size_t weight_size) { + constexpr int reduce_buffer_size = 128 / sizeof(float); + int max_batch_size = (NRAM_MAX_SIZE - 256 - reduce_buffer_size * sizeof(float)) / (4 * data_size + 2 * weight_size + 4 * sizeof(float)); + max_batch_size = std::min(max_batch_size, static_cast(dim)); + if (max_batch_size > 64) { + max_batch_size = (max_batch_size / 64) * 64; + } + return std::max(max_batch_size, 1); +} + +template +__mlu_global__ void layerNormKernel( + Tdata *__restrict__ output, + Tdata *__restrict__ standardization, + Tdata *__restrict__ std_deviation, + const Tdata *__restrict__ input, + const Tweight *__restrict__ weight, + const Tweight *__restrict__ bias, + const size_t *__restrict__ shape, + const ptrdiff_t *__restrict__ output_strides, + const ptrdiff_t *__restrict__ standardization_strides, + const ptrdiff_t *__restrict__ std_deviation_strides, + const ptrdiff_t *__restrict__ input_strides, + size_t rows, + size_t dim, + int ndim, + ptrdiff_t weight_stride, + ptrdiff_t bias_stride, + float eps, + bool has_bias) { + + constexpr int reduce_buffer_size = 128 / sizeof(float); + int max_batch_size = getMaxBatchSize(dim, sizeof(Tdata), sizeof(Tweight)); + + float *reduction_buffer = reinterpret_cast(nram_buffer); + Tdata *input_cache = reinterpret_cast(reduction_buffer + reduce_buffer_size); + Tdata *output_cache = input_cache + max_batch_size; + Tdata *std_cache = output_cache + max_batch_size; + Tweight *weight_cache = reinterpret_cast(std_cache + max_batch_size); + Tweight *bias_cache = weight_cache + max_batch_size; + float *input_float = reinterpret_cast(bias_cache + max_batch_size); + float *weight_float = input_float + max_batch_size; + float *bias_float = weight_float + max_batch_size; + + for (size_t row = taskId; row < rows; row += taskDim) { + ptrdiff_t input_offset = 0; + ptrdiff_t output_offset = 0; + ptrdiff_t standardization_offset = 0; + ptrdiff_t std_deviation_offset = 0; + size_t tmp = row; + for (int axis = ndim - 2; axis >= 0; --axis) { + size_t coord = tmp % shape[axis]; + tmp /= shape[axis]; + input_offset += static_cast(coord) * input_strides[axis]; + output_offset += static_cast(coord) * output_strides[axis]; + standardization_offset += static_cast(coord) * standardization_strides[axis]; + std_deviation_offset += static_cast(coord) * std_deviation_strides[axis]; + } + + const Tdata *input_row = input + input_offset; + Tdata *output_row = output + output_offset; + Tdata *standardization_row = standardization + standardization_offset; + + float sum = 0.0f; + size_t processed = 0; + while (processed < dim) { + size_t current = std::min(static_cast(max_batch_size), dim - processed); + loadToFloat(input_float, input_cache, input_row + processed * input_strides[ndim - 1], current); + if (current >= 128) { + op::common_bang::reduce_op::sumInternal(reduction_buffer, input_float, current); + sum += reduction_buffer[0]; + } else { + for (size_t i = 0; i < current; ++i) { + sum += input_float[i]; + } + } + processed += current; + } + + float mean = sum / static_cast(dim); + float variance_sum = 0.0f; + processed = 0; + while (processed < dim) { + size_t current = std::min(static_cast(max_batch_size), dim - processed); + loadToFloat(input_float, input_cache, input_row + processed * input_strides[ndim - 1], current); + __bang_sub_scalar(input_float, input_float, mean, current); + __bang_mul(input_float, input_float, input_float, current); + if (current >= 128) { + op::common_bang::reduce_op::sumInternal(reduction_buffer, input_float, current); + variance_sum += reduction_buffer[0]; + } else { + for (size_t i = 0; i < current; ++i) { + variance_sum += input_float[i]; + } + } + processed += current; + } + + float std_value = sqrtf(variance_sum / static_cast(dim) + eps); + float inv_std = 1.0f / std_value; + storeScalarFromFloat(std_deviation + std_deviation_offset, std_cache, std_value); + + processed = 0; + while (processed < dim) { + size_t current = std::min(static_cast(max_batch_size), dim - processed); + loadToFloat(input_float, input_cache, input_row + processed * input_strides[ndim - 1], current); + loadToFloat(weight_float, weight_cache, weight + processed * weight_stride, current); + if (has_bias) { + loadToFloat(bias_float, bias_cache, bias + processed * bias_stride, current); + } + + __bang_sub_scalar(input_float, input_float, mean, current); + __bang_mul_scalar(input_float, input_float, inv_std, current); + storeFromFloat(standardization_row + processed * standardization_strides[ndim - 1], std_cache, input_float, current); + + __bang_mul(input_float, input_float, weight_float, current); + if (has_bias) { + __bang_add(input_float, input_float, bias_float, current); + } + storeFromFloat(output_row + processed * output_strides[ndim - 1], output_cache, input_float, current); + processed += current; + } + } +} + +template +void launchLayerNorm( + int core_per_cluster, + int cluster_count, + cnrtQueue_t queue, + void *output, + void *standardization, + void *std_deviation, + const void *input, + const void *weight, + const void *bias, + void *workspace, + const op::layer_norm::LayerNormInfo &info) { + cnrtDim3_t kernel_dim; + kernel_dim.x = core_per_cluster; + kernel_dim.y = cluster_count; + kernel_dim.z = 1; + + char *tmp_device = reinterpret_cast(workspace); + size_t *mlu_shape = reinterpret_cast(tmp_device); + ptrdiff_t *mlu_output_strides = reinterpret_cast(mlu_shape + info.ndim); + ptrdiff_t *mlu_standardization_strides = mlu_output_strides + info.ndim; + ptrdiff_t *mlu_std_deviation_strides = mlu_standardization_strides + info.ndim; + ptrdiff_t *mlu_input_strides = mlu_std_deviation_strides + (info.ndim - 1); + + CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast(info.input_shape.data()), info.ndim * sizeof(size_t), queue, cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_output_strides, const_cast(info.output_strides.data()), info.ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_standardization_strides, const_cast(info.input_standardization_strides.data()), info.ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_std_deviation_strides, const_cast(info.input_std_deviation_strides.data()), (info.ndim - 1) * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_input_strides, const_cast(info.input_strides.data()), info.ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); + + layerNormKernel<<>>( + reinterpret_cast(output), + reinterpret_cast(standardization), + reinterpret_cast(std_deviation), + reinterpret_cast(input), + reinterpret_cast(weight), + reinterpret_cast(bias), + mlu_shape, + mlu_output_strides, + mlu_standardization_strides, + mlu_std_deviation_strides, + mlu_input_strides, + info.othersize, + info.normalized_size, + static_cast(info.ndim), + info.weight_strides.back(), + info.bias_exist ? info.bias_strides.back() : 0, + info.eps, + info.bias_exist); + cnrtQueueSync(queue); +} + +namespace op::layer_norm::bang { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_standardization_desc, + infiniopTensorDescriptor_t input_std_deviation_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc, + infiniopTensorDescriptor_t bias_desc, + float eps) { + auto handle = reinterpret_cast(handle_); + auto result = LayerNormInfo::createLayerNormInfo( + output_desc, + input_standardization_desc, + input_std_deviation_desc, + input_desc, + weight_desc, + bias_desc, + eps); + CHECK_RESULT(result); + auto info = result.take(); + + CHECK_DTYPE(info.dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + *desc_ptr = new Descriptor( + info.dtype, + info, + info.ndim * sizeof(size_t) + (4 * info.ndim - 1) * sizeof(ptrdiff_t), + new Descriptor::Opaque{handle->internal()}, + handle->device, + handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + void *input_standardization, + void *input_std_deviation, + const void *input, + const void *weight, + const void *bias, + void *stream) const { + auto queue = reinterpret_cast(stream); + int core_per_cluster = _opaque->internal->getCorePerCluster(); + int cluster_count = _opaque->internal->getClusterCount(); + + if (_info.dtype == INFINI_DTYPE_F16) { + launchLayerNorm(core_per_cluster, cluster_count, queue, output, input_standardization, input_std_deviation, input, weight, bias, workspace, _info); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + launchLayerNorm(core_per_cluster, cluster_count, queue, output, input_standardization, input_std_deviation, input, weight, bias, workspace, _info); + } else if (_info.dtype == INFINI_DTYPE_F32) { + launchLayerNorm(core_per_cluster, cluster_count, queue, output, input_standardization, input_std_deviation, input, weight, bias, workspace, _info); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::layer_norm::bang diff --git a/src/infiniop/ops/layer_norm/operator.cc b/src/infiniop/ops/layer_norm/operator.cc index 65e7aecf0..3925e845a 100644 --- a/src/infiniop/ops/layer_norm/operator.cc +++ b/src/infiniop/ops/layer_norm/operator.cc @@ -14,6 +14,9 @@ #ifdef ENABLE_MOORE_API #include "moore/layer_norm_moore.h" #endif +#ifdef ENABLE_CAMBRICON_API +#include "bang/layer_norm_bang.h" +#endif __INFINI_C infiniStatus_t infiniopCreateLayerNormDescriptor( infiniopHandle_t handle, @@ -64,6 +67,9 @@ __INFINI_C infiniStatus_t infiniopCreateLayerNormDescriptor( #ifdef ENABLE_MOORE_API CREATE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -102,6 +108,9 @@ __INFINI_C infiniStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDes #endif #ifdef ENABLE_MOORE_API GET(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -162,6 +171,9 @@ __INFINI_C infiniStatus_t infiniopLayerNorm( #ifdef ENABLE_MOORE_API CALCULATE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -201,6 +213,9 @@ infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) { #ifdef ENABLE_MOORE_API DELETE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_CAMBRICON_API + DELETE(INFINI_DEVICE_CAMBRICON, bang); +#endif #ifdef ENABLE_ILUVATAR_API DELETE(INFINI_DEVICE_ILUVATAR, nvidia); #endif diff --git a/test/infiniop/layer_norm.py b/test/infiniop/layer_norm.py index 85a360248..b675e9224 100644 --- a/test/infiniop/layer_norm.py +++ b/test/infiniop/layer_norm.py @@ -95,7 +95,7 @@ def torch_layer_norm( var = input.var(dim=-1, correction=0) std = torch.sqrt(var + eps) input_standardization.copy_( - ((input - mean) / std.unsqueeze(2)).type(input_standardization.dtype) + ((input - mean) / std.unsqueeze(-1)).type(input_standardization.dtype) ) input_std_deviation.copy_(std.type(input_standardization.dtype)) output.copy_(ln(input).detach().type(output.dtype)) @@ -179,8 +179,10 @@ def test( else None ) - layer_norm( + torch_layer_norm( output.torch_tensor(), + input_standardization.torch_tensor(), + input_std_deviation.torch_tensor(), input.torch_tensor(), weight.torch_tensor(), bias.torch_tensor() if bias_exist else None,