Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/infiniop/ops/add/ascend/add_ascend.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "add_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_add.h>
#include <cstdint>
#include <cstdlib>

namespace op::add::ascend {

Expand All @@ -10,18 +12,20 @@ struct Descriptor::Opaque {
aclnnTensorDescriptor_t b;
aclnnTensorDescriptor_t c;
aclnnScalarDescriptor_t alpha;
void *alpha_value;
size_t workspaceSize;
aclOpExecutor *executor;

Opaque(aclnnTensorDescriptor_t a_, aclnnTensorDescriptor_t b_, aclnnTensorDescriptor_t c_,
aclnnScalarDescriptor_t alpha_, size_t ws, aclOpExecutor *exec)
: a(a_), b(b_), c(c_), alpha(alpha_), workspaceSize(ws), executor(exec) {}
aclnnScalarDescriptor_t alpha_, void *alpha_value_, size_t ws, aclOpExecutor *exec)
: a(a_), b(b_), c(c_), alpha(alpha_), alpha_value(alpha_value_), workspaceSize(ws), executor(exec) {}

~Opaque() {
delete a;
delete b;
delete c;
delete alpha;
std::free(alpha_value);
aclDestroyAclOpExecutor(executor);
}
};
Expand Down Expand Up @@ -53,10 +57,37 @@ infiniStatus_t Descriptor::create(
aclnnTensorDescriptor_t b = new aclnnTensorDescriptor(b_desc);
aclnnTensorDescriptor_t c = new aclnnTensorDescriptor(c_desc);

// Default alpha = 1.0
float alpha_value = 1.0f;
void *alpha_value = nullptr;
size_t alpha_value_size = 0;
infiniDtype_t alpha_dtype = INFINI_DTYPE_F32;

// SET_ALPHA(TYPE, DTYPE, VALUE): allocate heap storage for the `alpha` scalar
// and record its size and dtype.
//
// Writes three locals of the enclosing function:
//   alpha_value      - malloc'ed buffer holding VALUE converted to TYPE
//   alpha_value_size - sizeof(TYPE)
//   alpha_dtype      - DTYPE
// The buffer lives on the heap because its pointer is later handed to Opaque,
// which releases it with std::free.
//
// On allocation failure the macro returns from the enclosing function. The
// tensor descriptors `a`, `b` and `c` (created with `new` earlier in that
// function) are deleted first so the early return does not leak them.
//
// NOTE(review): INFINI_STATUS_INSUFFICIENT_WORKSPACE is kept so callers see
// the same status as before, but it names workspace exhaustion rather than a
// host allocation failure -- consider a dedicated status code.
#define SET_ALPHA(TYPE, DTYPE, VALUE)                                 \
    do {                                                              \
        alpha_value_size = sizeof(TYPE);                              \
        alpha_value = std::malloc(alpha_value_size);                  \
        if (alpha_value == nullptr) {                                 \
            delete a;                                                 \
            delete b;                                                 \
            delete c;                                                 \
            return INFINI_STATUS_INSUFFICIENT_WORKSPACE;              \
        }                                                             \
        *static_cast<TYPE *>(alpha_value) = static_cast<TYPE>(VALUE); \
        alpha_dtype = DTYPE;                                          \
    } while (0)

switch (c_desc->dtype()) {
case INFINI_DTYPE_I32:
SET_ALPHA(int32_t, INFINI_DTYPE_I32, 1);
break;
case INFINI_DTYPE_I64:
SET_ALPHA(int64_t, INFINI_DTYPE_I64, 1);
break;
default:
SET_ALPHA(float, INFINI_DTYPE_F32, 1.0f);
break;
}

#undef SET_ALPHA

aclnnScalarDescriptor_t alpha = new aclnnScalarDescriptor(
INFINI_DTYPE_F32, &alpha_value, sizeof(float));
alpha_dtype, alpha_value, alpha_value_size);

size_t workspace_size = 0;
aclOpExecutor *executor = nullptr;
Expand All @@ -72,7 +103,7 @@ infiniStatus_t Descriptor::create(
aclSetAclOpExecutorRepeatable(executor);

*desc_ptr = new Descriptor(
new Opaque{a, b, c, alpha, workspace_size, executor},
new Opaque{a, b, c, alpha, alpha_value, workspace_size, executor},
result.take(),
workspace_size,
handle_ascend->device,
Expand Down
2 changes: 1 addition & 1 deletion src/infiniop/ops/add/info.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class AddInfo {
auto dtype = c_desc->dtype();

// Check dtype compatibility
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);

// Check shape compatibility (broadcast)
auto c_shape = c_desc->shape();
Expand Down
Loading