diff --git a/src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc b/src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc index d5d2649a4..c9532ac9c 100644 --- a/src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc +++ b/src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc @@ -1,16 +1,50 @@ #include "rms_norm_aclnn.h" +#include "../../../../utils/custom_types.h" #include "../../../devices/ascend/common_ascend.h" #include +#include +#include namespace op::rms_norm::ascend { +namespace { + +size_t alignOffset(size_t offset, size_t alignment) { + return (offset + alignment - 1) / alignment * alignment; +} + +bool needsWeightCast(infiniDtype_t atype, infiniDtype_t wtype) { + return (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) + && (wtype == INFINI_DTYPE_F16 || wtype == INFINI_DTYPE_BF16) + && atype != wtype; +} + +void castWeightToFloat(float *dst, const void *src, size_t count, infiniDtype_t wtype) { + if (wtype == INFINI_DTYPE_F16) { + auto src_t = reinterpret_cast(src); + for (size_t i = 0; i < count; ++i) { + dst[i] = utils::cast(src_t[i]); + } + } else if (wtype == INFINI_DTYPE_BF16) { + auto src_t = reinterpret_cast(src); + for (size_t i = 0; i < count; ++i) { + dst[i] = utils::cast(src_t[i]); + } + } +} + +} // namespace + struct Descriptor::Opaque { aclnnTensorDescriptor_t y; aclnnTensorDescriptor_t x; aclnnTensorDescriptor_t w; aclnnTensorDescriptor_t rstd; size_t workspaceSize; + size_t rstdOffset; + void *weightAddr; aclOpExecutor *executor; + bool cast_weight; ~Opaque() { delete y; @@ -18,6 +52,7 @@ struct Descriptor::Opaque { delete w; delete rstd; + aclrtFree(weightAddr); aclDestroyAclOpExecutor(executor); } }; @@ -44,12 +79,16 @@ infiniStatus_t Descriptor::create( aclnnTensorDescriptor_t x = nullptr; aclnnTensorDescriptor_t w = nullptr; aclnnTensorDescriptor_t rstd = nullptr; + void *weight_addr = nullptr; - std::vector slice_shape = {static_cast((info.shape)[1])}; + std::vector slice_shape = {static_cast(info.dim())}; auto slice_stride = std::vector(1, 1); y = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride); x = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride); - w = new aclnnTensorDescriptor(w_desc); + auto cast_weight = needsWeightCast(info.atype, info.wtype); + w = cast_weight + ? new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), slice_shape, slice_stride) + : new aclnnTensorDescriptor(w_desc); // Get AclTensor aclTensor *ty = y->tensor; @@ -68,9 +107,15 @@ infiniStatus_t Descriptor::create( aclSetAclOpExecutorRepeatable(executor); auto handle_ascend = reinterpret_cast(handle); - size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType); + auto rstd_size = rstd->numel() * aclDataTypeSize(rstd->dataType); + auto rstd_offset = alignOffset(workspace_size, 32); + auto weight_workspace = cast_weight ? info.dim() * infiniSizeOf(INFINI_DTYPE_F32) : 0; + if (cast_weight) { + CHECK_ACL(aclrtMalloc(&weight_addr, weight_workspace, ACL_MEM_MALLOC_HUGE_FIRST)); + } + size_t all_workspace_size = rstd_offset + rstd_size; *desc_ptr = new Descriptor( - new Opaque{y, x, w, rstd, workspace_size, executor}, + new Opaque{y, x, w, rstd, workspace_size, rstd_offset, weight_addr, executor, cast_weight}, std::move(info), all_workspace_size, handle_ascend->device, handle_ascend->device_id); @@ -91,14 +136,35 @@ infiniStatus_t Descriptor::calculate( auto ty = _opaque->y->tensor; auto trstd = _opaque->rstd->tensor; - void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize); - + void *rstdPtr = static_cast(static_cast(workspace) + _opaque->rstdOffset); auto unit = infiniSizeOf(_info.atype); - AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w); + void *weightPtr = const_cast(w); + + if (_opaque->cast_weight) { + auto weightBytesIn = _info.dim() * infiniSizeOf(_info.wtype); + auto weightBytesOut = _info.dim() * infiniSizeOf(INFINI_DTYPE_F32); + std::vector hostWeightIn(weightBytesIn); + std::vector hostWeightOut(_info.dim()); + CHECK_ACL(aclrtMemcpy(hostWeightIn.data(), weightBytesIn, w, weightBytesIn, ACL_MEMCPY_DEVICE_TO_HOST)); + castWeightToFloat(hostWeightOut.data(), hostWeightIn.data(), _info.dim(), _info.wtype); + weightPtr = _opaque->weightAddr; + CHECK_ACL(aclrtMemcpy(weightPtr, weightBytesOut, hostWeightOut.data(), weightBytesOut, ACL_MEMCPY_HOST_TO_DEVICE)); + } + + size_t outer = 1; + for (size_t i = 0; i + 1 < _info.ndim(); ++i) { + outer *= _info.shape[i]; + } + + AclSetTensorAddr(_opaque->executor, 1, tw, weightPtr); AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr); - for (size_t i = 0; i < (_info.shape)[0]; ++i) { - AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit); - AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit); + for (size_t i = 0; i < outer; ++i) { + size_t batch = _info.ndim() == 3 ? i / _info.shape[1] : i; + size_t head = _info.ndim() == 3 ? i % _info.shape[1] : 0; + auto x_offset = batch * _info.x_strides[0] + (_info.ndim() == 3 ? head * _info.x_strides[1] : 0); + auto y_offset = batch * _info.y_strides[0] + (_info.ndim() == 3 ? head * _info.y_strides[1] : 0); + AclSetTensorAddr(_opaque->executor, 0, tx, const_cast(static_cast(x) + x_offset * unit)); + AclSetTensorAddr(_opaque->executor, 2, ty, static_cast(y) + y_offset * unit); CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream)); } return INFINI_STATUS_SUCCESS;