From 2d66dc359b0346bcaac5db7c0393561e4aa71c69 Mon Sep 17 00:00:00 2001 From: pengcheng888 <1033693766@qq.com> Date: Wed, 3 Jun 2026 13:16:18 +0800 Subject: [PATCH] issue/1193 - Add a forward_ function to the Embedding module. --- include/infinicore/nn/embedding.hpp | 2 ++ src/infinicore/nn/embedding.cc | 25 +++++++++++++++++-------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/include/infinicore/nn/embedding.hpp b/include/infinicore/nn/embedding.hpp index 50a387325..7a0d3e116 100644 --- a/include/infinicore/nn/embedding.hpp +++ b/include/infinicore/nn/embedding.hpp @@ -60,6 +60,8 @@ class Embedding : public Module { * Input shape: [10] -> Output shape: [10, embedding_dim] */ Tensor forward(const Tensor &indices) const; + + Tensor forward_(Tensor &output, const Tensor &indices) const; // Module information size_t num_embeddings() const { return num_embeddings_; } diff --git a/src/infinicore/nn/embedding.cc b/src/infinicore/nn/embedding.cc index 7b02f93ce..8105f6fc2 100644 --- a/src/infinicore/nn/embedding.cc +++ b/src/infinicore/nn/embedding.cc @@ -43,13 +43,6 @@ Embedding::Embedding(size_t num_embeddings, } Tensor Embedding::forward(const Tensor &indices) const { - // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach - auto device_type = device_.getType(); - if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY) { - // Use op::embedding which supports device-side input and batch dimension - return op::embedding(indices->contiguous()->to(device_), weight_); - } - // Get the shape of indices auto indices_shape = indices->shape(); @@ -58,12 +51,28 @@ Tensor Embedding::forward(const Tensor &indices) const { output_shape.push_back(embedding_dim_); // Create output tensor on the same device as weight - auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device()); + auto output_embeds = Tensor::empty(output_shape, weight_->dtype(), weight_->device()); + + return this->forward_(output_embeds, indices); +} +Tensor Embedding::forward_(Tensor &output_embeds, const Tensor &indices) const { + // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach + auto device_type = device_.getType(); + if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY) { + // Use op::embedding which supports device-side input and batch dimension + op::embedding_(output_embeds, indices->contiguous()->to(device_), weight_); + return output_embeds; + } + + auto out = output_embeds; // Flatten indices for sequential row copies auto cpu_device = Device(Device::Type::CPU, 0); auto indices_cpu = indices->to(cpu_device)->contiguous(); + // Get the shape of indices + auto indices_shape = indices->shape(); + // Calculate total number of lookups size_t num_lookups = 1; for (auto dim : indices_shape) {