Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/infiniop/devices/ascend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ ascendc_library(ascend_kernels STATIC
../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
../../ops/rope/ascend/rope_ascend_kernel.cpp
../../ops/random_sample/ascend/random_sample_kernel.cpp
../../ops/rms_norm/ascend/cast_kernel.cpp
)

target_include_directories(ascend_kernels PRIVATE ../../../../include)
108 changes: 108 additions & 0 deletions src/infiniop/ops/rms_norm/ascend/cast_kernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#include "../../../devices/ascend/ascend_kernel_common.h"

using namespace AscendC;

// Element-wise SrcT -> DstT conversion kernel.
// The whole input is handled as a single tile: init() sizes the queue buffers
// for all `count` elements and process() runs each pipeline stage exactly once.
template <typename SrcT, typename DstT>
class CastKernel {
public:
__aicore__ inline CastKernel() {}
// Binds the global-memory operands and reserves the queue buffers.
__aicore__ inline void init(GM_ADDR dst, GM_ADDR src, size_t count);
// Runs copyIn() -> compute() -> copyOut() once.
__aicore__ inline void process();

private:
// Loads _copy_len source elements from _src_gm into the input queue.
__aicore__ inline void copyIn();
// Casts the queued input tensor into a tensor of the output queue.
__aicore__ inline void compute();
// Stores _tile_len converted elements from the output queue to _dst_gm.
__aicore__ inline void copyOut();

// Global-memory views of the source and destination buffers.
GlobalTensor<SrcT> _src_gm;
GlobalTensor<DstT> _dst_gm;
// Queues that hand tensors from one stage to the next.
TQue<QuePosition::VECIN, BUFFER_NUM> _in_queue;
TQue<QuePosition::VECOUT, BUFFER_NUM> _out_queue;
// Pipe used by init() to allocate the buffers of both queues.
TPipe _pipe;
// _tile_len: element count requested by the caller.
// _copy_len: _tile_len after alignTileLen<SrcT>(..., BYTE_ALIGN); used for the
//            load and the cast (presumably rounded up to a BYTE_ALIGN-byte
//            multiple -- confirm against ascend_kernel_common.h).
size_t _tile_len, _copy_len;
};

/**
 * Binds the kernel to its global-memory operands and reserves the buffers of
 * the input and output queues.
 *
 * @param dst   Global-memory address the converted DstT elements are written to.
 * @param src   Global-memory address the SrcT elements are read from.
 * @param count Number of elements to convert.
 */
template <typename SrcT, typename DstT>
__aicore__ inline void CastKernel<SrcT, DstT>::init(GM_ADDR dst, GM_ADDR src, size_t count) {
    // Logical length: copyOut() stores exactly this many elements.
    _tile_len = count;
    // Length used for the load and the cast. alignTileLen is declared in the
    // common header; presumably it rounds up so the SrcT byte size is a
    // multiple of BYTE_ALIGN -- TODO confirm.
    _copy_len = alignTileLen<SrcT>(count, BYTE_ALIGN);

    _src_gm.SetGlobalBuffer((__gm__ SrcT *)src);
    _dst_gm.SetGlobalBuffer((__gm__ DstT *)dst);

    // Both queues hold BUFFER_NUM buffers of _copy_len elements each, so the
    // output buffer is large enough for the full (aligned) cast result.
    const size_t in_bytes = _copy_len * sizeof(SrcT);
    const size_t out_bytes = _copy_len * sizeof(DstT);
    _pipe.InitBuffer(_in_queue, BUFFER_NUM, in_bytes);
    _pipe.InitBuffer(_out_queue, BUFFER_NUM, out_bytes);
}

/**
 * Load stage: takes a tensor from the input queue, fills it with _copy_len
 * elements read from _src_gm and enqueues it for compute().
 */
template <typename SrcT, typename DstT>
__aicore__ inline void CastKernel<SrcT, DstT>::copyIn() {
    auto staged = _in_queue.AllocTensor<SrcT>();
    // Reads the aligned length, which may exceed the caller's element count;
    // the source buffer must be readable for _copy_len elements.
    DataCopy(staged, _src_gm, _copy_len);
    _in_queue.EnQue(staged);
}

/**
 * Compute stage: dequeues the loaded SrcT tensor, casts _copy_len elements
 * into a freshly allocated DstT tensor with RoundMode::CAST_NONE, enqueues the
 * result for copyOut() and releases the input tensor.
 */
template <typename SrcT, typename DstT>
__aicore__ inline void CastKernel<SrcT, DstT>::compute() {
    auto input = _in_queue.DeQue<SrcT>();
    auto output = _out_queue.AllocTensor<DstT>();
    Cast(output, input, RoundMode::CAST_NONE, _copy_len);
    // Publish the result first, then hand the input buffer back to its queue.
    _out_queue.EnQue(output);
    _in_queue.FreeTensor(input);
}

/**
 * Store stage: dequeues the converted tensor and writes _tile_len elements to
 * _dst_gm, then releases the tensor back to the output queue.
 *
 * A plain DataCopy is used when the output byte size is a multiple of
 * BYTE_ALIGN; otherwise DataCopyPad is used with the exact byte length.
 */
template <typename SrcT, typename DstT>
__aicore__ inline void CastKernel<SrcT, DstT>::copyOut() {
    auto result = _out_queue.DeQue<DstT>();
    const size_t out_bytes = _tile_len * sizeof(DstT);
    if (out_bytes % BYTE_ALIGN == 0) {
        DataCopy(_dst_gm, result, _tile_len);
    } else {
        // One block of out_bytes bytes, all remaining fields zero.
        // NOTE(review): field order presumably {blockCount, blockLen, srcStride,
        // dstStride, rsv} -- confirm against the AscendC DataCopyPad docs.
        DataCopyExtParams pad_params = {1, static_cast<uint32_t>(out_bytes), 0, 0, 0};
        DataCopyPad(_dst_gm, result, pad_params);
    }
    _out_queue.FreeTensor(result);
}

// Runs the three pipeline stages once, in order: load the source tile,
// cast it, and store the converted tile. There is no tiling loop, so the
// whole input is processed in this single pass.
template <typename SrcT, typename DstT>
__aicore__ inline void CastKernel<SrcT, DstT>::process() {
copyIn();
compute();
copyOut();
}

// Generates a __global__ kernel entry point named KERNEL_NAME that builds a
// CastKernel<SRC_T, DST_T>, initializes it with (dst, src, count) and runs
// process(). The macro is local to this file and is undefined after use.
#define DEFINE_CAST_KERNEL(KERNEL_NAME, SRC_T, DST_T) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR dst, GM_ADDR src, \
size_t count) { \
CastKernel<SRC_T, DST_T> op; \
op.init(dst, src, count); \
op.process(); \
}

// Instantiated conversions: half -> float and bfloat16_t -> float.
DEFINE_CAST_KERNEL(cast_kernel_f16_to_f32, half, float)
DEFINE_CAST_KERNEL(cast_kernel_bf16_to_f32, bfloat16_t, float)

#undef DEFINE_CAST_KERNEL

/**
 * Host-side launcher for the cast kernels defined in this file.
 *
 * @param dst        Destination buffer passed to the kernel as its `dst` argument.
 * @param src        Source buffer passed to the kernel as its `src` argument.
 * @param src_dtype  Element type of `src`; INFINI_DTYPE_F16 and INFINI_DTYPE_BF16 are handled.
 * @param dst_dtype  Element type of `dst`; must be INFINI_DTYPE_F32.
 * @param count      Number of elements to convert.
 * @param stream     Stream the kernel is launched on.
 * @return INFINI_STATUS_SUCCESS once a kernel has been launched, or
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for any other dtype combination.
 */
extern "C" infiniStatus_t rms_norm_cast_w_launch(
    void *dst, const void *src,
    infiniDtype_t src_dtype, infiniDtype_t dst_dtype,
    size_t count, void *stream) {

    // Only conversions into F32 are implemented.
    if (dst_dtype != INFINI_DTYPE_F32) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // Each kernel is launched with a launch dimension of 1; the cast on `src`
    // drops its const qualifier to match the kernel's GM_ADDR parameter.
    if (src_dtype == INFINI_DTYPE_F16) {
        cast_kernel_f16_to_f32<<<1, nullptr, stream>>>(
            dst, (GM_ADDR)src, count);
        return INFINI_STATUS_SUCCESS;
    }
    if (src_dtype == INFINI_DTYPE_BF16) {
        cast_kernel_bf16_to_f32<<<1, nullptr, stream>>>(
            dst, (GM_ADDR)src, count);
        return INFINI_STATUS_SUCCESS;
    }

    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
142 changes: 110 additions & 32 deletions src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
#include "rms_norm_aclnn.h"
#include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_rms_norm.h>
#include <cstdio>

// Launcher for the weight-cast kernel implemented in cast_kernel.cpp.
// Converts `count` elements from `src` (F16 or BF16) into `dst` (F32) on
// `stream`; any other dtype combination yields INFINI_STATUS_BAD_TENSOR_DTYPE.
extern "C" infiniStatus_t rms_norm_cast_w_launch(
void *dst, const void *src,
infiniDtype_t src_dtype, infiniDtype_t dst_dtype,
size_t count, void *stream);

namespace op::rms_norm::ascend {

Expand All @@ -11,13 +17,24 @@ struct Descriptor::Opaque {
aclnnTensorDescriptor_t rstd;
size_t workspaceSize;
aclOpExecutor *executor;
bool needs_cast_w;
size_t cast_w_offset;
size_t w_padded_offset;
size_t w_padded_size;

Opaque(aclnnTensorDescriptor_t y_, aclnnTensorDescriptor_t x_,
aclnnTensorDescriptor_t w_, aclnnTensorDescriptor_t rstd_,
size_t ws, aclOpExecutor *exec,
bool cast_w, size_t cast_off, size_t pad_off, size_t pad_sz)
: y(y_), x(x_), w(w_), rstd(rstd_), workspaceSize(ws), executor(exec),
needs_cast_w(cast_w), cast_w_offset(cast_off),
w_padded_offset(pad_off), w_padded_size(pad_sz) {}

~Opaque() {
delete y;
delete x;
delete w;
delete rstd;

aclDestroyAclOpExecutor(executor);
}
};
Expand All @@ -38,42 +55,72 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result);
auto info = result.take();

size_t workspace_size = 0;
aclOpExecutor *executor = nullptr;
aclnnTensorDescriptor_t y = nullptr;
aclnnTensorDescriptor_t x = nullptr;
aclnnTensorDescriptor_t w = nullptr;
aclnnTensorDescriptor_t rstd = nullptr;
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);

std::vector<int64_t> slice_shape = {static_cast<int64_t>((info.shape)[1])};
std::vector<int64_t> slice_shape = {static_cast<int64_t>(info.dim())};
auto slice_stride = std::vector<int64_t>(1, 1);
y = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride);
x = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride);
w = new aclnnTensorDescriptor(w_desc);

// Get AclTensor
aclTensor *ty = y->tensor;
aclTensor *tx = x->tensor;
aclTensor *tw = w->tensor;
// Set rstdDesc
// See: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnRmsNorm.md
// rstdTensor cannot set nullptr in aclnn

aclnnTensorDescriptor_t y = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride);
aclnnTensorDescriptor_t x = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride);

// w needs a cast only for mixed half-precision combinations
// (F16 atype + BF16 w, or BF16 atype + F16 w); in that case w is converted to F32.
bool needs_cast_w = (info.atype != info.wtype && info.wtype != INFINI_DTYPE_F32);
aclnnTensorDescriptor_t w = nullptr;
std::vector<int64_t> w_shape_i64_dbg;
std::vector<int64_t> w_strides_i64_dbg;
if (needs_cast_w) {
// Work around the ndim memory-corruption problem in constructor #2:
// build the descriptor correctly from w_desc with constructor #1 first, then replace its tensor with one of the correct dtype.
w = new aclnnTensorDescriptor(w_desc);
w_shape_i64_dbg = w->shape;
w_strides_i64_dbg = w->strides;
if (w->tensor) {
aclDestroyTensor(w->tensor);
}
w->dataType = toAclDataType(INFINI_DTYPE_F32);
w->tensor = aclCreateTensor(w->shape.data(), w->ndim, w->dataType,
w->strides.data(), w->offset, w->format,
w->storageShape.data(), w->storageNdim, nullptr);
} else {
w = new aclnnTensorDescriptor(w_desc);
}

auto rstd_shape = std::vector<int64_t>(1, 1);
auto rstd_strides = std::vector<int64_t>(1, 1);
rstd = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), rstd_shape, rstd_strides);
aclTensor *trstd = rstd->tensor;
aclnnTensorDescriptor_t rstd = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), rstd_shape, rstd_strides);

size_t workspace_size = 0;
aclOpExecutor *executor = nullptr;

CHECK_ACL(aclnnRmsNormGetWorkspaceSize(
x->tensor,
w->tensor,
static_cast<double>(epsilon),
y->tensor,
rstd->tensor,
&workspace_size,
&executor));

// Get WorkspaceSize and set executor
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor));
aclSetAclOpExecutorRepeatable(executor);

auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);
size_t rstd_size = rstd->numel() * aclDataTypeSize(rstd->dataType);
size_t cast_w_dst_size = needs_cast_w ? info.dim() * sizeof(float) : 0;
size_t w_padded_size = 0;
if (needs_cast_w) {
size_t w_raw_bytes = info.dim() * infiniSizeOf(info.wtype);
w_padded_size = ((w_raw_bytes + 31) / 32) * 32;
}
size_t all_workspace_size = workspace_size + rstd_size + cast_w_dst_size + w_padded_size;
size_t cast_w_offset = workspace_size + rstd_size;
size_t w_padded_offset = cast_w_offset + cast_w_dst_size;

*desc_ptr = new Descriptor(
new Opaque{y, x, w, rstd, workspace_size, executor},
new Opaque{y, x, w, rstd, workspace_size, executor, needs_cast_w, cast_w_offset, w_padded_offset, w_padded_size},
std::move(info),
all_workspace_size,
handle_ascend->device, handle_ascend->device_id);
handle_ascend->device,
handle_ascend->device_id);

return INFINI_STATUS_SUCCESS;
}
Expand All @@ -86,21 +133,52 @@ infiniStatus_t Descriptor::calculate(
if (workspace_size < workspaceSize()) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}

auto tw = _opaque->w->tensor;
auto tx = _opaque->x->tensor;
auto ty = _opaque->y->tensor;
auto trstd = _opaque->rstd->tensor;

void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize);
void *w_ptr = nullptr;

if (_opaque->needs_cast_w) {
void *cast_w_ptr = (void *)((uint8_t *)workspace + _opaque->cast_w_offset);
void *w_padded_src = (void *)((uint8_t *)workspace + _opaque->w_padded_offset);
size_t w_bytes = _info.dim() * infiniSizeOf(_info.wtype);
aclrtMemcpyAsync(w_padded_src, _opaque->w_padded_size, (void *)w, w_bytes,
ACL_MEMCPY_DEVICE_TO_DEVICE, (aclrtStream)stream);
rms_norm_cast_w_launch(cast_w_ptr, w_padded_src, _info.wtype, INFINI_DTYPE_F32, _info.dim(), stream);
w_ptr = cast_w_ptr;
} else {
w_ptr = (void *)w;
}

auto unit = infiniSizeOf(_info.atype);
AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w);

AclSetTensorAddr(_opaque->executor, 1, tw, w_ptr);
AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr);
for (size_t i = 0; i < (_info.shape)[0]; ++i) {
AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit);
AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit);
CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream));

auto ndim = _info.ndim();
size_t outer = ndim == 2 ? 1 : _info.shape[0];
size_t inner = ndim == 2 ? _info.shape[0] : _info.shape[1];

for (size_t b = 0; b < outer; ++b) {
for (size_t s = 0; s < inner; ++s) {
ptrdiff_t x_offset, y_offset;
if (ndim == 2) {
x_offset = s * _info.x_strides[0];
y_offset = s * _info.y_strides[0];
} else {
x_offset = b * _info.x_strides[0] + s * _info.x_strides[1];
y_offset = b * _info.y_strides[0] + s * _info.y_strides[1];
}
AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + x_offset * unit);
AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + y_offset * unit);
CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream));
}
}

return INFINI_STATUS_SUCCESS;
}

Expand Down
Loading