Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 76 additions & 10 deletions src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc
Original file line number Diff line number Diff line change
@@ -1,23 +1,58 @@
#include "rms_norm_aclnn.h"
#include "../../../../utils/custom_types.h"
#include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_rms_norm.h>
#include <cstdint>
#include <vector>

namespace op::rms_norm::ascend {

namespace {

/// Rounds `offset` up to the nearest multiple of `alignment`.
///
/// @param offset     Byte offset to round up.
/// @param alignment  Required granularity; must be non-zero.
/// @return The smallest multiple of `alignment` that is >= `offset`.
size_t alignOffset(size_t offset, size_t alignment) {
    // Bump past the next boundary, then drop the remainder to land on it.
    const size_t bumped = offset + alignment - 1;
    return bumped - bumped % alignment;
}

/// Reports whether the weight tensor has to be converted before the kernel runs.
///
/// A cast is required only when the activation and weight dtypes are both
/// 16-bit floating formats (F16 or BF16) but are not the same format.
///
/// @param atype  Activation dtype.
/// @param wtype  Weight dtype.
/// @return true for the F16/BF16 and BF16/F16 combinations, false otherwise.
bool needsWeightCast(infiniDtype_t atype, infiniDtype_t wtype) {
    // Matching dtypes never need a conversion.
    if (atype == wtype) {
        return false;
    }

    const bool activation_is_half = atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16;
    const bool weight_is_half = wtype == INFINI_DTYPE_F16 || wtype == INFINI_DTYPE_BF16;
    return activation_is_half && weight_is_half;
}

void castWeightToFloat(float *dst, const void *src, size_t count, infiniDtype_t wtype) {
if (wtype == INFINI_DTYPE_F16) {
auto src_t = reinterpret_cast<const fp16_t *>(src);
for (size_t i = 0; i < count; ++i) {
dst[i] = utils::cast<float>(src_t[i]);
}
} else if (wtype == INFINI_DTYPE_BF16) {
auto src_t = reinterpret_cast<const bf16_t *>(src);
for (size_t i = 0; i < count; ++i) {
dst[i] = utils::cast<float>(src_t[i]);
}
}
}

} // namespace

/// Backend-specific state owned by an Ascend RMSNorm descriptor.
///
/// The struct is initialized as an aggregate, so the member order below is
/// part of its interface and must not change.
struct Descriptor::Opaque {
    aclnnTensorDescriptor_t y;    ///< Output tensor descriptor (owned).
    aclnnTensorDescriptor_t x;    ///< Input tensor descriptor (owned).
    aclnnTensorDescriptor_t w;    ///< Weight tensor descriptor (owned).
    aclnnTensorDescriptor_t rstd; ///< Descriptor for the rstd output (owned).
    size_t workspaceSize;         ///< Workspace size handed to aclnnRmsNorm.
    size_t rstdOffset;            ///< Byte offset of the rstd buffer inside the caller's workspace.
    void *weightAddr;             ///< Device buffer for F32-converted weights; nullptr when no cast is needed.
    aclOpExecutor *executor;      ///< Executor reused across calculate() calls.
    bool cast_weight;             ///< Whether weights are converted to F32 before the kernel runs.

    ~Opaque() {
        delete y;
        delete x;
        delete w;
        delete rstd;

        // weightAddr is only allocated when a weight cast is required, so do
        // not hand a null pointer to the runtime on the common path.
        if (weightAddr != nullptr) {
            aclrtFree(weightAddr);
        }
        aclDestroyAclOpExecutor(executor);
    }
};
Expand All @@ -44,12 +79,16 @@ infiniStatus_t Descriptor::create(
aclnnTensorDescriptor_t x = nullptr;
aclnnTensorDescriptor_t w = nullptr;
aclnnTensorDescriptor_t rstd = nullptr;
void *weight_addr = nullptr;

std::vector<int64_t> slice_shape = {static_cast<int64_t>((info.shape)[1])};
std::vector<int64_t> slice_shape = {static_cast<int64_t>(info.dim())};
auto slice_stride = std::vector<int64_t>(1, 1);
y = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride);
x = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride);
w = new aclnnTensorDescriptor(w_desc);
auto cast_weight = needsWeightCast(info.atype, info.wtype);
w = cast_weight
? new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), slice_shape, slice_stride)
: new aclnnTensorDescriptor(w_desc);

// Get AclTensor
aclTensor *ty = y->tensor;
Expand All @@ -68,9 +107,15 @@ infiniStatus_t Descriptor::create(
aclSetAclOpExecutorRepeatable(executor);

auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);
auto rstd_size = rstd->numel() * aclDataTypeSize(rstd->dataType);
auto rstd_offset = alignOffset(workspace_size, 32);
auto weight_workspace = cast_weight ? info.dim() * infiniSizeOf(INFINI_DTYPE_F32) : 0;
if (cast_weight) {
CHECK_ACL(aclrtMalloc(&weight_addr, weight_workspace, ACL_MEM_MALLOC_HUGE_FIRST));
}
size_t all_workspace_size = rstd_offset + rstd_size;
*desc_ptr = new Descriptor(
new Opaque{y, x, w, rstd, workspace_size, executor},
new Opaque{y, x, w, rstd, workspace_size, rstd_offset, weight_addr, executor, cast_weight},
std::move(info),
all_workspace_size,
handle_ascend->device, handle_ascend->device_id);
Expand All @@ -91,14 +136,35 @@ infiniStatus_t Descriptor::calculate(
auto ty = _opaque->y->tensor;
auto trstd = _opaque->rstd->tensor;

void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize);

void *rstdPtr = static_cast<void *>(static_cast<uint8_t *>(workspace) + _opaque->rstdOffset);
auto unit = infiniSizeOf(_info.atype);
AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w);
void *weightPtr = const_cast<void *>(w);

if (_opaque->cast_weight) {
auto weightBytesIn = _info.dim() * infiniSizeOf(_info.wtype);
auto weightBytesOut = _info.dim() * infiniSizeOf(INFINI_DTYPE_F32);
std::vector<uint8_t> hostWeightIn(weightBytesIn);
std::vector<float> hostWeightOut(_info.dim());
CHECK_ACL(aclrtMemcpy(hostWeightIn.data(), weightBytesIn, w, weightBytesIn, ACL_MEMCPY_DEVICE_TO_HOST));
castWeightToFloat(hostWeightOut.data(), hostWeightIn.data(), _info.dim(), _info.wtype);
weightPtr = _opaque->weightAddr;
CHECK_ACL(aclrtMemcpy(weightPtr, weightBytesOut, hostWeightOut.data(), weightBytesOut, ACL_MEMCPY_HOST_TO_DEVICE));
}

size_t outer = 1;
for (size_t i = 0; i + 1 < _info.ndim(); ++i) {
outer *= _info.shape[i];
}

AclSetTensorAddr(_opaque->executor, 1, tw, weightPtr);
AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr);
for (size_t i = 0; i < (_info.shape)[0]; ++i) {
AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit);
AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit);
for (size_t i = 0; i < outer; ++i) {
size_t batch = _info.ndim() == 3 ? i / _info.shape[1] : i;
size_t head = _info.ndim() == 3 ? i % _info.shape[1] : 0;
auto x_offset = batch * _info.x_strides[0] + (_info.ndim() == 3 ? head * _info.x_strides[1] : 0);
auto y_offset = batch * _info.y_strides[0] + (_info.ndim() == 3 ? head * _info.y_strides[1] : 0);
AclSetTensorAddr(_opaque->executor, 0, tx, const_cast<char *>(static_cast<const char *>(x) + x_offset * unit));
AclSetTensorAddr(_opaque->executor, 2, ty, static_cast<char *>(y) + y_offset * unit);
CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream));
}
return INFINI_STATUS_SUCCESS;
Expand Down
Loading