From eddd57705b08f99568525d92b40499e902abb51f Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 12 Jun 2026 19:35:48 +0800 Subject: [PATCH] issue/1276 - support bf16 in ascend causal softmax --- .../ascend/causal_softmax_ascend.cc | 118 +++++++++++++++--- 1 file changed, 104 insertions(+), 14 deletions(-) diff --git a/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc b/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc index b37557da7..3b183ae0f 100644 --- a/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc +++ b/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc @@ -1,31 +1,58 @@ #include "causal_softmax_ascend.h" #include "../../../devices/ascend/common_ascend.h" +#include #include #include +#include namespace op::causal_softmax::ascend { +namespace { + +bool isCompact(const CausalSoftmaxInfo &info, ptrdiff_t stride_b, ptrdiff_t stride_i, ptrdiff_t stride_j) { + return stride_j == 1 + && stride_i == static_cast(info.total_seq_len) + && (info.batch_size == 1 || stride_b == static_cast(info.seq_len * info.total_seq_len)); +} + +} // namespace + struct Descriptor::Opaque { aclnnTensorDescriptor_t x; aclnnTensorDescriptor_t mask; aclnnTensorDescriptor_t y; aclnnTensorDescriptor_t value; + aclnnTensorDescriptor_t temp_x; + aclnnTensorDescriptor_t temp_y; void *mask_addr; void *value_addr; - uint64_t workspacesize; + void *temp_x_addr; + void *temp_y_addr; + size_t workspacesize; aclOpExecutor *executor; + aclOpExecutor *temp_executor; + aclOpExecutor *copy_in_executor; + aclOpExecutor *copy_out_executor; + bool use_temp; ~Opaque() { delete x; delete mask; delete y; delete value; + delete temp_x; + delete temp_y; aclrtFree(mask_addr); aclrtFree(value_addr); + aclrtFree(temp_x_addr); + aclrtFree(temp_y_addr); // Delete useless executor aclDestroyAclOpExecutor(executor); + aclDestroyAclOpExecutor(temp_executor); + aclDestroyAclOpExecutor(copy_in_executor); + aclDestroyAclOpExecutor(copy_out_executor); } }; @@ -44,25 
+71,38 @@ infiniStatus_t Descriptor::create( CausalSoftmaxInfo info = result.take(); aclOpExecutor *executor = nullptr; + aclOpExecutor *temp_executor = nullptr; aclOpExecutor *mask_executor = nullptr; + aclOpExecutor *copy_in_executor = nullptr; + aclOpExecutor *copy_out_executor = nullptr; aclnnTensorDescriptor_t y = nullptr; aclnnTensorDescriptor_t mask = nullptr; aclnnTensorDescriptor_t x = nullptr; aclnnTensorDescriptor_t value = nullptr; + aclnnTensorDescriptor_t temp_x = nullptr; + aclnnTensorDescriptor_t temp_y = nullptr; void *mask_addr = nullptr; void *value_addr = nullptr; + void *temp_x_addr = nullptr; + void *temp_y_addr = nullptr; size_t workspacesize_softmax = 0; + size_t workspacesize_temp_softmax = 0; size_t workspacesize_mask = 0; + size_t workspacesize_copy_in = 0; + size_t workspacesize_copy_out = 0; - // Create Aclnn Tensor Descriptors for input , mask and output + // Create Aclnn Tensor Descriptors for input, mask and output std::vector shape = {static_cast(info.batch_size), static_cast(info.seq_len), static_cast(info.total_seq_len)}; std::vector x_strides = {static_cast(info.x_stride_b), static_cast(info.x_stride_i), static_cast(info.x_stride_j)}; std::vector y_strides = {static_cast(info.y_stride_b), static_cast(info.y_stride_i), static_cast(info.y_stride_j)}; + std::vector compact_strides = {static_cast(info.seq_len * info.total_seq_len), static_cast(info.total_seq_len), 1}; y = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, y_strides); x = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, x_strides); + temp_x = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, compact_strides); + temp_y = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, compact_strides); mask = new aclnnTensorDescriptor(aclDataType::ACL_BOOL, {static_cast(info.seq_len), static_cast(info.total_seq_len)}, {static_cast(info.total_seq_len), 1}); - // Initialize the value tensor with -∞ + // Initialize the value tensor with -inf if 
(info.dtype == INFINI_DTYPE_F16) { uint16_t mask_value = 0xfc00; auto size = aclDataTypeSize(aclDataType::ACL_FLOAT16); @@ -93,21 +133,42 @@ infiniStatus_t Descriptor::create( // Get the workspace size for the op aclTensor *tx = x->tensor; aclTensor *ty = y->tensor; + aclTensor *ttemp_x = temp_x->tensor; + aclTensor *ttemp_y = temp_y->tensor; aclTensor *tmask = mask->tensor; aclTensor *tvalue = value->tensor; - CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); - - int64_t dim = 2; - CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor)); - // set executor reusable - aclSetAclOpExecutorRepeatable(executor); + bool use_temp = !isCompact(info, info.x_stride_b, info.x_stride_i, info.x_stride_j) + || !isCompact(info, info.y_stride_b, info.y_stride_i, info.y_stride_j); + + if (use_temp) { + CHECK_ACL(aclnnInplaceCopyGetWorkspaceSize(ttemp_x, tx, &workspacesize_copy_in, &copy_in_executor)); + aclSetAclOpExecutorRepeatable(copy_in_executor); + CHECK_ACL(aclnnInplaceCopyGetWorkspaceSize(ty, ttemp_y, &workspacesize_copy_out, &copy_out_executor)); + aclSetAclOpExecutorRepeatable(copy_out_executor); + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp_x, tmask, tvalue, &workspacesize_mask, &mask_executor)); + int64_t dim = 2; + CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(ttemp_x, dim, ttemp_y, &workspacesize_temp_softmax, &temp_executor)); + aclSetAclOpExecutorRepeatable(temp_executor); + } else { + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); + int64_t dim = 2; + CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor)); + // set executor reusable + aclSetAclOpExecutorRepeatable(executor); + } - // Create the descripto - size_t all_workspacesize = std::max(workspacesize_softmax, workspacesize_mask); + size_t op_workspace_size = std::max(std::max(workspacesize_softmax, 
workspacesize_temp_softmax), + std::max(workspacesize_mask, std::max(workspacesize_copy_in, workspacesize_copy_out))); + size_t all_workspacesize = op_workspace_size; + if (use_temp) { + size_t temp_bytes = temp_x->numel() * infiniSizeOf(info.dtype); + CHECK_ACL(aclrtMalloc(&temp_x_addr, temp_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + CHECK_ACL(aclrtMalloc(&temp_y_addr, temp_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + } - *desc_ptr = new Descriptor(new Opaque{x, mask, y, value, mask_addr, value_addr, - workspacesize_softmax, executor}, + *desc_ptr = new Descriptor(new Opaque{x, mask, y, value, temp_x, temp_y, mask_addr, value_addr, + temp_x_addr, temp_y_addr, op_workspace_size, executor, temp_executor, copy_in_executor, copy_out_executor, use_temp}, std::move(info), all_workspacesize, handle_ascend->device, handle_ascend->device_id); return INFINI_STATUS_SUCCESS; @@ -121,6 +182,35 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, voi auto ty = _opaque->y->tensor; auto tmask = _opaque->mask->tensor; auto tvalue = _opaque->value->tensor; + + if (_opaque->use_temp) { + auto ttemp_x = _opaque->temp_x->tensor; + auto ttemp_y = _opaque->temp_y->tensor; + void *temp_x = _opaque->temp_x_addr; + void *temp_y = _opaque->temp_y_addr; + + AclSetTensorAddr(_opaque->copy_in_executor, 0, ttemp_x, temp_x); + AclSetTensorAddr(_opaque->copy_in_executor, 1, tx, (void *)x); + CHECK_ACL(aclnnInplaceCopy(workspace, _opaque->workspacesize, _opaque->copy_in_executor, stream)); + + aclOpExecutor *mask_executor = nullptr; + size_t workspacesize_mask = 0; + AclSetTensorAddr(mask_executor, 0, ttemp_x, temp_x); + AclSetTensorAddr(mask_executor, 1, tmask, _opaque->mask_addr); + AclSetTensorAddr(mask_executor, 2, tvalue, _opaque->value_addr); + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp_x, tmask, tvalue, &workspacesize_mask, &mask_executor)); + CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, _opaque->workspacesize, mask_executor, stream)); + + 
AclSetTensorAddr(_opaque->temp_executor, 0, ttemp_x, temp_x); + AclSetTensorAddr(_opaque->temp_executor, 1, ttemp_y, temp_y); + CHECK_ACL(aclnnSoftmax(workspace, _opaque->workspacesize, _opaque->temp_executor, stream)); + + AclSetTensorAddr(_opaque->copy_out_executor, 0, ty, y); + AclSetTensorAddr(_opaque->copy_out_executor, 1, ttemp_y, temp_y); + CHECK_ACL(aclnnInplaceCopy(workspace, _opaque->workspacesize, _opaque->copy_out_executor, stream)); + return INFINI_STATUS_SUCCESS; + } + aclOpExecutor *mask_executor = nullptr; size_t workspacesize_mask = 0; @@ -128,7 +218,7 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, voi AclSetTensorAddr(mask_executor, 1, tmask, _opaque->mask_addr); AclSetTensorAddr(mask_executor, 2, tvalue, _opaque->value_addr); CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); - CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, workspacesize_mask, mask_executor, stream)); + CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, _opaque->workspacesize, mask_executor, stream)); AclSetTensorAddr(_opaque->executor, 0, tx, (void *)x); AclSetTensorAddr(_opaque->executor, 1, ty, y);