diff --git a/src/infiniop/ops/softmax/bang/softmax_bang.cc b/src/infiniop/ops/softmax/bang/softmax_bang.cc new file mode 100644 index 000000000..a92916734 --- /dev/null +++ b/src/infiniop/ops/softmax/bang/softmax_bang.cc @@ -0,0 +1,96 @@ +#include "softmax_bang.h" +#include "../../../devices/bang/common_bang.h" + +namespace op::softmax::bang { + +struct Descriptor::Opaque { + std::shared_ptr internal; + cnnlTensorDescriptor_t x_desc = nullptr; + cnnlTensorDescriptor_t y_desc = nullptr; + + ~Opaque() { + if (x_desc) { + cnnlDestroyTensorDescriptor(x_desc); + } + if (y_desc) { + cnnlDestroyTensorDescriptor(y_desc); + } + } +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +static infiniStatus_t setSoftmaxTensor(cnnlTensorDescriptor_t desc, const SoftmaxInfo &info) { + int dims[3] = { + static_cast(info.othersize / info.stride), + static_cast(info.dimsize), + static_cast(info.stride), + }; + CHECK_BANG(cnnlSetTensorDescriptor( + desc, + CNNL_LAYOUT_ARRAY, + device::bang::getCnnlDtype(info.dtype), + 3, + dims)); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + int axis) { + auto result = SoftmaxInfo::create(y_desc, x_desc, axis); + CHECK_RESULT(result); + auto info = result.take(); + + cnnlTensorDescriptor_t cnnl_x = nullptr; + cnnlTensorDescriptor_t cnnl_y = nullptr; + CHECK_BANG(cnnlCreateTensorDescriptor(&cnnl_x)); + CHECK_BANG(cnnlCreateTensorDescriptor(&cnnl_y)); + CHECK_STATUS(setSoftmaxTensor(cnnl_x, info)); + CHECK_STATUS(setSoftmaxTensor(cnnl_y, info)); + + auto handle_bang = reinterpret_cast(handle); + *desc_ptr = new Descriptor( + new Opaque{handle_bang->internal(), cnnl_x, cnnl_y}, + std::move(info), + 0, + handle->device, + handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, + const void *x, + void *stream) const { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + auto queue = reinterpret_cast(stream); + CHECK_STATUS(_opaque->internal->useCnnl( + queue, + [&](cnnlHandle_t handle) { + CHECK_BANG(cnnlSoftmaxForward( + handle, + CNNL_SOFTMAX_ACCURATE, + CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION, + nullptr, + _opaque->x_desc, + x, + nullptr, + _opaque->y_desc, + y)); + return INFINI_STATUS_SUCCESS; + })); + cnrtQueueSync(queue); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::softmax::bang diff --git a/src/infiniop/ops/softmax/bang/softmax_bang.h b/src/infiniop/ops/softmax/bang/softmax_bang.h new file mode 100644 index 000000000..5cf909ab5 --- /dev/null +++ b/src/infiniop/ops/softmax/bang/softmax_bang.h @@ -0,0 +1,8 @@ +#ifndef __SOFTMAX_BANG_H__ +#define __SOFTMAX_BANG_H__ + +#include "../softmax.h" + +DESCRIPTOR(bang) + +#endif // __SOFTMAX_BANG_H__ diff --git a/src/infiniop/ops/softmax/operator.cc b/src/infiniop/ops/softmax/operator.cc index 54916f121..1664e257f 100644 --- a/src/infiniop/ops/softmax/operator.cc +++ b/src/infiniop/ops/softmax/operator.cc @@ -2,6 +2,10 @@ #include "../../handle.h" #include "infiniop/ops/softmax.h" +#ifdef ENABLE_CAMBRICON_API +#include "bang/softmax_bang.h" +#endif + #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API) #include "nvidia/softmax_nvidia.cuh" #endif @@ -36,6 +40,9 @@ __INFINI_C infiniStatus_t infiniopCreateSoftmaxDescriptor( #endif #ifdef ENABLE_ALI_API CREATE(INFINI_DEVICE_ALI, nvidia); +#endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -64,6 +71,9 @@ __INFINI_C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescrip #endif #ifdef ENABLE_ALI_API GET(INFINI_DEVICE_ALI, nvidia); +#endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -97,6 +107,9 @@ __INFINI_C infiniStatus_t infiniopSoftmax( #endif #ifdef ENABLE_ALI_API CALCULATE(INFINI_DEVICE_ALI, nvidia); +#endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -125,6 +138,9 @@ __INFINI_C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescri #endif #ifdef ENABLE_ALI_API DESTROY(INFINI_DEVICE_ALI, nvidia); +#endif +#ifdef ENABLE_CAMBRICON_API + DESTROY(INFINI_DEVICE_CAMBRICON, bang); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;