From bc31e64464ca0c5aa5d91c0cd37534c84e3a1b59 Mon Sep 17 00:00:00 2001 From: ShaneWu Date: Thu, 11 Jun 2026 16:27:43 +0800 Subject: [PATCH] fix rms_norm op to adapt with ascend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改rms_norm算子在昇腾ascend上对跨半精度数据格式和3D张量进行适配。 修改内容: 1、修改Infinicore/src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc,当 w dtype 与 x dtype 不同且非 F32 时将 w cast 到 F32 再调用aclnnRmsNorm,并修复 3D 张量的 slice_shape 取值和循环遍历问题; 2、新增Infinicore/src/infiniop/ops/rms_norm/ascend/cast_kernel.cpp,实现 AscendC Cast 核函数(F16/BF16 → F32),用于跨半精度时 w dtype 的转换; 3、修改Infinicore/src/infiniop/devices/ascend/CMakeLists.txt,添加 cast_kernel.cpp 编译项。 现状:infiniop算子测试全部通过,infinicore算子接口测试跑通,103/108 Passed。 --- src/infiniop/devices/ascend/CMakeLists.txt | 1 + .../ops/rms_norm/ascend/cast_kernel.cpp | 108 +++++++++++++ .../ops/rms_norm/ascend/rms_norm_aclnn.cc | 142 ++++++++++++++---- 3 files changed, 219 insertions(+), 32 deletions(-) create mode 100644 src/infiniop/ops/rms_norm/ascend/cast_kernel.cpp diff --git a/src/infiniop/devices/ascend/CMakeLists.txt b/src/infiniop/devices/ascend/CMakeLists.txt index 12bfb3670..fad65d259 100644 --- a/src/infiniop/devices/ascend/CMakeLists.txt +++ b/src/infiniop/devices/ascend/CMakeLists.txt @@ -27,6 +27,7 @@ ascendc_library(ascend_kernels STATIC ../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp ../../ops/rope/ascend/rope_ascend_kernel.cpp ../../ops/random_sample/ascend/random_sample_kernel.cpp + ../../ops/rms_norm/ascend/cast_kernel.cpp ) target_include_directories(ascend_kernels PRIVATE ../../../../include) diff --git a/src/infiniop/ops/rms_norm/ascend/cast_kernel.cpp b/src/infiniop/ops/rms_norm/ascend/cast_kernel.cpp new file mode 100644 index 000000000..98808062e --- /dev/null +++ b/src/infiniop/ops/rms_norm/ascend/cast_kernel.cpp @@ -0,0 +1,108 @@ +#include "../../../devices/ascend/ascend_kernel_common.h" + +using namespace AscendC; + +template +class CastKernel { +public: + __aicore__ inline CastKernel() {} + __aicore__ inline void init(GM_ADDR dst, GM_ADDR src, size_t count); + __aicore__ inline void process(); + +private: + __aicore__ inline void copyIn(); + __aicore__ inline void compute(); + __aicore__ inline void copyOut(); + + GlobalTensor _src_gm; + GlobalTensor _dst_gm; + TQue _in_queue; + TQue _out_queue; + TPipe _pipe; + size_t _tile_len, _copy_len; +}; + +template +__aicore__ inline void CastKernel::init(GM_ADDR dst, GM_ADDR src, size_t count) { + _tile_len = count; + _copy_len = alignTileLen(_tile_len, BYTE_ALIGN); + + _dst_gm.SetGlobalBuffer((__gm__ DstT *)dst); + _src_gm.SetGlobalBuffer((__gm__ SrcT *)src); + + _pipe.InitBuffer(_in_queue, BUFFER_NUM, _copy_len * sizeof(SrcT)); + _pipe.InitBuffer(_out_queue, BUFFER_NUM, _copy_len * sizeof(DstT)); +} + +template +__aicore__ inline void CastKernel::copyIn() { + LocalTensor srcLocal = _in_queue.AllocTensor(); + DataCopy(srcLocal, _src_gm, _copy_len); + _in_queue.EnQue(srcLocal); +} + +template +__aicore__ inline void CastKernel::compute() { + LocalTensor srcLocal = _in_queue.DeQue(); + LocalTensor dstLocal = _out_queue.AllocTensor(); + Cast(dstLocal, srcLocal, RoundMode::CAST_NONE, _copy_len); + _out_queue.EnQue(dstLocal); + _in_queue.FreeTensor(srcLocal); +} + +template +__aicore__ inline void CastKernel::copyOut() { + LocalTensor dstLocal = _out_queue.DeQue(); + if (_tile_len * sizeof(DstT) % BYTE_ALIGN != 0) { + DataCopyExtParams dcep = {1, static_cast(_tile_len * sizeof(DstT)), 0, 0, 0}; + DataCopyPad(_dst_gm, dstLocal, dcep); + } else { + DataCopy(_dst_gm, dstLocal, _tile_len); + } + _out_queue.FreeTensor(dstLocal); +} + +template +__aicore__ inline void CastKernel::process() { + copyIn(); + compute(); + copyOut(); +} + +#define DEFINE_CAST_KERNEL(KERNEL_NAME, SRC_T, DST_T) \ + __global__ __aicore__ void KERNEL_NAME(GM_ADDR dst, GM_ADDR src, \ + size_t count) { \ + CastKernel op; \ + op.init(dst, src, count); \ + op.process(); \ + } + +DEFINE_CAST_KERNEL(cast_kernel_f16_to_f32, half, float) +DEFINE_CAST_KERNEL(cast_kernel_bf16_to_f32, bfloat16_t, float) + +#undef DEFINE_CAST_KERNEL + +extern "C" infiniStatus_t rms_norm_cast_w_launch( + void *dst, const void *src, + infiniDtype_t src_dtype, infiniDtype_t dst_dtype, + size_t count, void *stream) { + + if (dst_dtype != INFINI_DTYPE_F32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#define LAUNCH_CAST(DTYPE_ENUM, KERNEL_NAME) \ + case DTYPE_ENUM: \ + KERNEL_NAME<<<1, nullptr, stream>>>( \ + dst, (GM_ADDR)src, count); \ + return INFINI_STATUS_SUCCESS; + + switch (src_dtype) { + LAUNCH_CAST(INFINI_DTYPE_F16, cast_kernel_f16_to_f32) + LAUNCH_CAST(INFINI_DTYPE_BF16, cast_kernel_bf16_to_f32) + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_CAST +} diff --git a/src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc b/src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc index d5d2649a4..2e9eb8a25 100644 --- a/src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc +++ b/src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc @@ -1,6 +1,12 @@ #include "rms_norm_aclnn.h" #include "../../../devices/ascend/common_ascend.h" #include +#include + +extern "C" infiniStatus_t rms_norm_cast_w_launch( + void *dst, const void *src, + infiniDtype_t src_dtype, infiniDtype_t dst_dtype, + size_t count, void *stream); namespace op::rms_norm::ascend { @@ -11,13 +17,24 @@ struct Descriptor::Opaque { aclnnTensorDescriptor_t rstd; size_t workspaceSize; aclOpExecutor *executor; + bool needs_cast_w; + size_t cast_w_offset; + size_t w_padded_offset; + size_t w_padded_size; + + Opaque(aclnnTensorDescriptor_t y_, aclnnTensorDescriptor_t x_, + aclnnTensorDescriptor_t w_, aclnnTensorDescriptor_t rstd_, + size_t ws, aclOpExecutor *exec, + bool cast_w, size_t cast_off, size_t pad_off, size_t pad_sz) + : y(y_), x(x_), w(w_), rstd(rstd_), workspaceSize(ws), executor(exec), + needs_cast_w(cast_w), cast_w_offset(cast_off), + w_padded_offset(pad_off), w_padded_size(pad_sz) {} ~Opaque() { delete y; delete x; delete w; delete rstd; - aclDestroyAclOpExecutor(executor); } }; @@ -38,42 +55,72 @@ infiniStatus_t Descriptor::create( CHECK_RESULT(result); auto info = result.take(); - size_t workspace_size = 0; - aclOpExecutor *executor = nullptr; - aclnnTensorDescriptor_t y = nullptr; - aclnnTensorDescriptor_t x = nullptr; - aclnnTensorDescriptor_t w = nullptr; - aclnnTensorDescriptor_t rstd = nullptr; + auto handle_ascend = reinterpret_cast(handle); - std::vector slice_shape = {static_cast((info.shape)[1])}; + std::vector slice_shape = {static_cast(info.dim())}; auto slice_stride = std::vector(1, 1); - y = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride); - x = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride); - w = new aclnnTensorDescriptor(w_desc); - - // Get AclTensor - aclTensor *ty = y->tensor; - aclTensor *tx = x->tensor; - aclTensor *tw = w->tensor; - // Set rstdDesc - // See: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnRmsNorm.md - // rstdTensor cannot set nullptr in aclnn + + aclnnTensorDescriptor_t y = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride); + aclnnTensorDescriptor_t x = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride); + + // 仅在跨半精度组合时需要将 w cast 到 atype + // (F16 atype + BF16 w, 或 BF16 atype + F16 w) + bool needs_cast_w = (info.atype != info.wtype && info.wtype != INFINI_DTYPE_F32); + aclnnTensorDescriptor_t w = nullptr; + std::vector w_shape_i64_dbg; + std::vector w_strides_i64_dbg; + if (needs_cast_w) { + // 规避 constructor #2 的 ndim 内存 corruption 问题 + // 先用 constructor #1 从 w_desc 正确构造,再替换 tensor 为正确的 dtype + w = new aclnnTensorDescriptor(w_desc); + w_shape_i64_dbg = w->shape; + w_strides_i64_dbg = w->strides; + if (w->tensor) { + aclDestroyTensor(w->tensor); + } + w->dataType = toAclDataType(INFINI_DTYPE_F32); + w->tensor = aclCreateTensor(w->shape.data(), w->ndim, w->dataType, + w->strides.data(), w->offset, w->format, + w->storageShape.data(), w->storageNdim, nullptr); + } else { + w = new aclnnTensorDescriptor(w_desc); + } + auto rstd_shape = std::vector(1, 1); auto rstd_strides = std::vector(1, 1); - rstd = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), rstd_shape, rstd_strides); - aclTensor *trstd = rstd->tensor; + aclnnTensorDescriptor_t rstd = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), rstd_shape, rstd_strides); + + size_t workspace_size = 0; + aclOpExecutor *executor = nullptr; + + CHECK_ACL(aclnnRmsNormGetWorkspaceSize( + x->tensor, + w->tensor, + static_cast(epsilon), + y->tensor, + rstd->tensor, + &workspace_size, + &executor)); - // Get WorkspaceSize and set executor - CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast(epsilon), ty, trstd, &workspace_size, &executor)); aclSetAclOpExecutorRepeatable(executor); - auto handle_ascend = reinterpret_cast(handle); - size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType); + size_t rstd_size = rstd->numel() * aclDataTypeSize(rstd->dataType); + size_t cast_w_dst_size = needs_cast_w ? info.dim() * sizeof(float) : 0; + size_t w_padded_size = 0; + if (needs_cast_w) { + size_t w_raw_bytes = info.dim() * infiniSizeOf(info.wtype); + w_padded_size = ((w_raw_bytes + 31) / 32) * 32; + } + size_t all_workspace_size = workspace_size + rstd_size + cast_w_dst_size + w_padded_size; + size_t cast_w_offset = workspace_size + rstd_size; + size_t w_padded_offset = cast_w_offset + cast_w_dst_size; + *desc_ptr = new Descriptor( - new Opaque{y, x, w, rstd, workspace_size, executor}, + new Opaque{y, x, w, rstd, workspace_size, executor, needs_cast_w, cast_w_offset, w_padded_offset, w_padded_size}, std::move(info), all_workspace_size, - handle_ascend->device, handle_ascend->device_id); + handle_ascend->device, + handle_ascend->device_id); return INFINI_STATUS_SUCCESS; } @@ -86,21 +133,52 @@ infiniStatus_t Descriptor::calculate( if (workspace_size < workspaceSize()) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } + auto tw = _opaque->w->tensor; auto tx = _opaque->x->tensor; auto ty = _opaque->y->tensor; auto trstd = _opaque->rstd->tensor; void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize); + void *w_ptr = nullptr; + + if (_opaque->needs_cast_w) { + void *cast_w_ptr = (void *)((uint8_t *)workspace + _opaque->cast_w_offset); + void *w_padded_src = (void *)((uint8_t *)workspace + _opaque->w_padded_offset); + size_t w_bytes = _info.dim() * infiniSizeOf(_info.wtype); + aclrtMemcpyAsync(w_padded_src, _opaque->w_padded_size, (void *)w, w_bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE, (aclrtStream)stream); + rms_norm_cast_w_launch(cast_w_ptr, w_padded_src, _info.wtype, INFINI_DTYPE_F32, _info.dim(), stream); + w_ptr = cast_w_ptr; + } else { + w_ptr = (void *)w; + } auto unit = infiniSizeOf(_info.atype); - AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w); + + AclSetTensorAddr(_opaque->executor, 1, tw, w_ptr); AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr); - for (size_t i = 0; i < (_info.shape)[0]; ++i) { - AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit); - AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit); - CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream)); + + auto ndim = _info.ndim(); + size_t outer = ndim == 2 ? 1 : _info.shape[0]; + size_t inner = ndim == 2 ? _info.shape[0] : _info.shape[1]; + + for (size_t b = 0; b < outer; ++b) { + for (size_t s = 0; s < inner; ++s) { + ptrdiff_t x_offset, y_offset; + if (ndim == 2) { + x_offset = s * _info.x_strides[0]; + y_offset = s * _info.y_strides[0]; + } else { + x_offset = b * _info.x_strides[0] + s * _info.x_strides[1]; + y_offset = b * _info.y_strides[0] + s * _info.y_strides[1]; + } + AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + x_offset * unit); + AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + y_offset * unit); + CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream)); + } } + return INFINI_STATUS_SUCCESS; }