diff --git a/src/infinicore/nn/embedding.cc b/src/infinicore/nn/embedding.cc index 7b02f93ce..16965de8c 100644 --- a/src/infinicore/nn/embedding.cc +++ b/src/infinicore/nn/embedding.cc @@ -45,7 +45,7 @@ Embedding::Embedding(size_t num_embeddings, Tensor Embedding::forward(const Tensor &indices) const { // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach auto device_type = device_.getType(); - if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY) { + if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ASCEND || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY) { // Use op::embedding which supports device-side input and batch dimension return op::embedding(indices->contiguous()->to(device_), weight_); } @@ -72,10 +72,6 @@ Tensor Embedding::forward(const Tensor &indices) const { const size_t row_bytes = embedding_dim_ * dsize(weight_->dtype()); - // Source and destination base pointers - auto *weight_base = weight_->data(); - auto *out_base = out->data(); - // Helper lambda to read index based on dtype with bounds checking auto read_index = [&](size_t i) -> int64_t { auto dtype = indices_cpu->dtype(); @@ -103,6 +99,8 @@ Tensor Embedding::forward(const Tensor &indices) const { if (weight_->device().getType() == Device::Type::CPU) { // CPU path: memcpy row by row + const auto *weight_base = reinterpret_cast(weight_->data()); + auto *out_base = reinterpret_cast(out->data()); for (size_t i = 0; i < num_lookups; ++i) { int64_t idx = read_index(i); if (idx < 0 || idx >= static_cast(num_embeddings_)) { @@ -112,14 +110,17 @@ Tensor Embedding::forward(const Tensor &indices) const { std::memcpy(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes); } } else { - // Device path: use stream-ordered D2D copies + // Device fallback: copy rows through Tensor slices so device runtimes own stride/stream handling. + auto flat_out = out->view({num_lookups, embedding_dim_}); for (size_t i = 0; i < num_lookups; ++i) { int64_t idx = read_index(i); if (idx < 0 || idx >= static_cast(num_embeddings_)) { throw std::out_of_range( "Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")"); } - context::memcpyD2D(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes); + auto dst = flat_out->narrow({{0, i, 1}}); + auto src = weight_->narrow({{0, static_cast(idx), 1}}); + dst->copy_from(src); } } diff --git a/src/infiniop/ops/embedding/ascend/embedding_ascend.cc b/src/infiniop/ops/embedding/ascend/embedding_ascend.cc new file mode 100644 index 000000000..d54268d9e --- /dev/null +++ b/src/infiniop/ops/embedding/ascend/embedding_ascend.cc @@ -0,0 +1,105 @@ +#include "embedding_ascend.h" +#include "../../../devices/ascend/common_ascend.h" +#include + +namespace op::embedding::ascend { + +struct Descriptor::Opaque { + aclnnTensorDescriptor_t output; + aclnnTensorDescriptor_t input; + aclnnTensorDescriptor_t weight; + void *workspace; + uint64_t workspace_size; + aclOpExecutor *executor; + + ~Opaque() { + delete output; + delete input; + delete weight; + if (workspace != nullptr) { + aclrtFree(workspace); + } + aclDestroyAclOpExecutor(executor); + } +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + + auto handle = reinterpret_cast(handle_); + + CHECK_API_OR(input_desc->dtype() == INFINI_DTYPE_I32 || input_desc->dtype() == INFINI_DTYPE_I64, true, + return INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_API_OR(output_desc->dtype() == weight_desc->dtype(), true, return INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_API_OR(weight_desc->ndim() == 2, true, return INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_API_OR(output_desc->ndim() == input_desc->ndim() + 1, true, return INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto input_shape = input_desc->shape(); + auto output_shape = output_desc->shape(); + auto weight_shape = weight_desc->shape(); + for (size_t i = 0; i < input_desc->ndim(); ++i) { + CHECK_API_OR(output_shape[i] == input_shape[i], true, return INFINI_STATUS_BAD_TENSOR_SHAPE); + } + CHECK_API_OR(output_shape.back() == weight_shape[1], true, return INFINI_STATUS_BAD_TENSOR_SHAPE); + + size_t num_indices = 1; + for (auto dim : input_shape) { + num_indices *= dim; + } + + auto output = new aclnnTensorDescriptor(output_desc); + auto input = new aclnnTensorDescriptor(input_desc); + auto weight = new aclnnTensorDescriptor(weight_desc); + + uint64_t workspace_size = 0; + aclOpExecutor *executor = nullptr; + CHECK_ACL(aclnnEmbeddingGetWorkspaceSize(weight->tensor, input->tensor, output->tensor, + &workspace_size, &executor)); + aclSetAclOpExecutorRepeatable(executor); + + void *workspace = nullptr; + if (workspace_size != 0) { + CHECK_ACL(aclrtMalloc(&workspace, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST)); + } + + *desc_ptr = new Descriptor( + num_indices, + weight_shape[1], + weight_shape[0], + input_desc->dtype(), + weight_desc->dtype(), + new Opaque{output, input, weight, workspace, workspace_size, executor}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *output, + const void *input, + const void *weight, + void *stream) const { + + auto tweight = _opaque->weight->tensor; + auto tinput = _opaque->input->tensor; + auto toutput = _opaque->output->tensor; + + AclSetTensorAddr(_opaque->executor, 0, tweight, const_cast(weight)); + AclSetTensorAddr(_opaque->executor, 1, tinput, const_cast(input)); + AclSetTensorAddr(_opaque->executor, 2, toutput, output); + + CHECK_ACL(aclnnEmbedding(_opaque->workspace, _opaque->workspace_size, + _opaque->executor, stream)); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::embedding::ascend diff --git a/src/infiniop/ops/embedding/ascend/embedding_ascend.h b/src/infiniop/ops/embedding/ascend/embedding_ascend.h new file mode 100644 index 000000000..7653a5866 --- /dev/null +++ b/src/infiniop/ops/embedding/ascend/embedding_ascend.h @@ -0,0 +1,8 @@ +#ifndef __EMBEDDING_ASCEND_H__ +#define __EMBEDDING_ASCEND_H__ + +#include "../embedding.h" + +DESCRIPTOR(ascend) + +#endif // __EMBEDDING_ASCEND_H__ diff --git a/src/infiniop/ops/embedding/operator.cc b/src/infiniop/ops/embedding/operator.cc index 4741945c7..aa8e7fd42 100644 --- a/src/infiniop/ops/embedding/operator.cc +++ b/src/infiniop/ops/embedding/operator.cc @@ -8,6 +8,9 @@ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API) #include "nvidia/embedding_nvidia.cuh" #endif +#ifdef ENABLE_ASCEND_API +#include "ascend/embedding_ascend.h" +#endif #ifdef ENABLE_METAX_API #include "metax/embedding_metax.cuh" #endif @@ -51,6 +54,9 @@ __INFINI_C infiniStatus_t infiniopCreateEmbeddingDescriptor( #ifdef ENABLE_HYGON_API CREATE(INFINI_DEVICE_HYGON, nvidia); #endif +#ifdef ENABLE_ASCEND_API + CREATE(INFINI_DEVICE_ASCEND, ascend); +#endif #ifdef ENABLE_METAX_API CREATE(INFINI_DEVICE_METAX, metax); #endif @@ -97,6 +103,9 @@ __INFINI_C infiniStatus_t infiniopEmbedding( #ifdef ENABLE_HYGON_API CALCULATE(INFINI_DEVICE_HYGON, nvidia); #endif +#ifdef ENABLE_ASCEND_API + CALCULATE(INFINI_DEVICE_ASCEND, ascend); +#endif #ifdef ENABLE_METAX_API CALCULATE(INFINI_DEVICE_METAX, metax); #endif