diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 747735bcd..a06c1cc7e 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -26,6 +26,7 @@ #include "ops/flash_attention.hpp" #include "ops/fmin.hpp" #include "ops/fmod.hpp" +#include "ops/fused_gated_delta_net_gating.hpp" #include "ops/gelu.hpp" #include "ops/gelutanh.hpp" #include "ops/hardswish.hpp" diff --git a/include/infinicore/ops/fused_gated_delta_net_gating.hpp b/include/infinicore/ops/fused_gated_delta_net_gating.hpp new file mode 100644 index 000000000..a61fd9937 --- /dev/null +++ b/include/infinicore/ops/fused_gated_delta_net_gating.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" + +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(FusedGatedDeltaNetGating, + Tensor, + Tensor, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + float, + float); + +std::pair fused_gated_delta_net_gating(const Tensor &A_log, + const Tensor &a, + const Tensor &b, + const Tensor &dt_bias, + float beta = 1.0f, + float threshold = 20.0f); + +void fused_gated_delta_net_gating_(Tensor g, + Tensor beta_output, + const Tensor &A_log, + const Tensor &a, + const Tensor &b, + const Tensor &dt_bias, + float beta = 1.0f, + float threshold = 20.0f); + +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index de33a7a4b..2d0c9bf75 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -52,6 +52,7 @@ #include "infiniop/ops/floor_divide.h" #include "infiniop/ops/fmin.h" #include "infiniop/ops/fmod.h" +#include "infiniop/ops/fused_gated_delta_net_gating.h" #include "infiniop/ops/gelu.h" #include "infiniop/ops/gelutanh.h" #include "infiniop/ops/gemm.h" diff --git a/include/infiniop/ops/fused_gated_delta_net_gating.h b/include/infiniop/ops/fused_gated_delta_net_gating.h new file mode 100644 index 000000000..dafea6710 --- /dev/null +++ b/include/infiniop/ops/fused_gated_delta_net_gating.h @@ -0,0 +1,43 @@ +#ifndef __INFINIOP_FUSED_GATED_DELTA_NET_GATING_API_H__ +#define __INFINIOP_FUSED_GATED_DELTA_NET_GATING_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopFusedGatedDeltaNetGatingDescriptor_t; + +__INFINI_C __export infiniStatus_t +infiniopCreateFusedGatedDeltaNetGatingDescriptor( + infiniopHandle_t handle, + infiniopFusedGatedDeltaNetGatingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_output_desc, + infiniopTensorDescriptor_t A_log_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t dt_bias_desc, + float beta, + float threshold); + +__INFINI_C __export infiniStatus_t +infiniopGetFusedGatedDeltaNetGatingWorkspaceSize( + infiniopFusedGatedDeltaNetGatingDescriptor_t desc, + size_t *size); + +__INFINI_C __export infiniStatus_t +infiniopFusedGatedDeltaNetGating( + infiniopFusedGatedDeltaNetGatingDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *g, + void *beta_output, + const void *A_log, + const void *a, + const void *b, + const void *dt_bias, + void *stream); + +__INFINI_C __export infiniStatus_t +infiniopDestroyFusedGatedDeltaNetGatingDescriptor( + infiniopFusedGatedDeltaNetGatingDescriptor_t desc); + +#endif diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index a28128a1d..6f91e997f 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -7,6 +7,7 @@ from .causal_softmax import causal_softmax from .embedding import embedding from .flash_attention import flash_attention +from .fused_gated_delta_net_gating import fused_gated_delta_net_gating from .gaussian_nll_loss import gaussian_nll_loss from .hardswish import hardswish from .hardtanh import hardtanh @@ -43,6 +44,7 @@ "causal_softmax", "embedding", "flash_attention", + "fused_gated_delta_net_gating", "gaussian_nll_loss", "interpolate", "linear", diff --git a/python/infinicore/nn/functional/fused_gated_delta_net_gating.py b/python/infinicore/nn/functional/fused_gated_delta_net_gating.py new file mode 100644 index 000000000..faa6a339f --- /dev/null +++ b/python/infinicore/nn/functional/fused_gated_delta_net_gating.py @@ -0,0 +1,37 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def fused_gated_delta_net_gating( + A_log: Tensor, + a: Tensor, + b: Tensor, + dt_bias: Tensor, + beta: float = 1.0, + threshold: float = 20.0, + *, + out: tuple[Tensor, Tensor] | None = None, +) -> tuple[Tensor, Tensor]: + if out is None: + g, beta_output = _infinicore.fused_gated_delta_net_gating( + A_log._underlying, + a._underlying, + b._underlying, + dt_bias._underlying, + beta, + threshold, + ) + return Tensor(g), Tensor(beta_output) + + g, beta_output = out + _infinicore.fused_gated_delta_net_gating_( + g._underlying, + beta_output._underlying, + A_log._underlying, + a._underlying, + b._underlying, + dt_bias._underlying, + beta, + threshold, + ) + return g, beta_output diff --git a/src/infinicore/ops/fused_gated_delta_net_gating/fused_gated_delta_net_gating.cc b/src/infinicore/ops/fused_gated_delta_net_gating/fused_gated_delta_net_gating.cc new file mode 100644 index 000000000..6ce94925b --- /dev/null +++ b/src/infinicore/ops/fused_gated_delta_net_gating/fused_gated_delta_net_gating.cc @@ -0,0 +1,85 @@ +#include "infinicore/ops/fused_gated_delta_net_gating.hpp" + +#include "../../utils.hpp" + +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FusedGatedDeltaNetGating); + +FusedGatedDeltaNetGating::FusedGatedDeltaNetGating(Tensor g, + Tensor beta_output, + const Tensor &A_log, + const Tensor &a, + const Tensor &b, + const Tensor &dt_bias, + float beta, + float threshold) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(g, beta_output, A_log, a, b, dt_bias); + INFINICORE_GRAPH_OP_DISPATCH(g->device().getType(), g, beta_output, A_log, a, b, dt_bias, beta, threshold); +} + +void FusedGatedDeltaNetGating::execute(Tensor g, + Tensor beta_output, + const Tensor &A_log, + const Tensor &a, + const Tensor &b, + const Tensor &dt_bias, + float beta, + float threshold) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(FusedGatedDeltaNetGating, g, beta_output, A_log, a, b, dt_bias, beta, threshold); +} + +static void validate_inputs(const Tensor &A_log, + const Tensor &a, + const Tensor &b, + const Tensor &dt_bias) { + if (a->shape().size() != 3 || b->shape().size() != 3) { + throw std::runtime_error("fused_gated_delta_net_gating expects a and b with shape [batch_size, seq_len, hidden]"); + } + if (a->shape() != b->shape()) { + throw std::runtime_error("fused_gated_delta_net_gating expects a and b to have the same shape"); + } + if (A_log->shape().size() != 1 || dt_bias->shape().size() != 1) { + throw std::runtime_error("fused_gated_delta_net_gating expects A_log and dt_bias with shape [hidden]"); + } + if (A_log->shape()[0] != a->shape()[2] || dt_bias->shape()[0] != a->shape()[2]) { + throw std::runtime_error("fused_gated_delta_net_gating hidden dimension mismatch"); + } +} + +std::pair fused_gated_delta_net_gating(const Tensor &A_log, + const Tensor &a, + const Tensor &b, + const Tensor &dt_bias, + float beta, + float threshold) { + validate_inputs(A_log, a, b, dt_bias); + + Tensor g = Tensor::empty(a->shape(), DataType::F32, a->device()); + Tensor beta_output = Tensor::empty(a->shape(), DataType::F32, a->device()); + fused_gated_delta_net_gating_(g, beta_output, A_log, a, b, dt_bias, beta, threshold); + return {g, beta_output}; +} + +void fused_gated_delta_net_gating_(Tensor g, + Tensor beta_output, + const Tensor &A_log, + const Tensor &a, + const Tensor &b, + const Tensor &dt_bias, + float beta, + float threshold) { + validate_inputs(A_log, a, b, dt_bias); + if (g->shape() != a->shape() || beta_output->shape() != a->shape()) { + throw std::runtime_error("fused_gated_delta_net_gating_ expects outputs with shape [batch_size, seq_len, hidden]"); + } + if (g->dtype() != DataType::F32 || beta_output->dtype() != DataType::F32) { + throw std::runtime_error("fused_gated_delta_net_gating_ expects float32 outputs"); + } + + FusedGatedDeltaNetGating::execute(g, beta_output, A_log, a, b, dt_bias, beta, threshold); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/fused_gated_delta_net_gating/fused_gated_delta_net_gating_infiniop.cc b/src/infinicore/ops/fused_gated_delta_net_gating/fused_gated_delta_net_gating_infiniop.cc new file mode 100644 index 000000000..66aeeaa18 --- /dev/null +++ b/src/infinicore/ops/fused_gated_delta_net_gating/fused_gated_delta_net_gating_infiniop.cc @@ -0,0 +1,71 @@ +#include "infinicore/ops/fused_gated_delta_net_gating.hpp" + +#include "../infiniop_impl.hpp" + +namespace infinicore::op::fused_gated_delta_net_gating_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FusedGatedDeltaNetGating, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, g, beta_output, A_log, a, b, dt_bias; +}; + +void *plan(Tensor g, + Tensor beta_output, + const Tensor &A_log, + const Tensor &a, + const Tensor &b, + const Tensor &dt_bias, + float beta, + float threshold) { + size_t seed = hash_combine(g, beta_output, A_log, a, b, dt_bias, beta, threshold); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, FusedGatedDeltaNetGating, seed, + g->desc(), + beta_output->desc(), + A_log->desc(), + a->desc(), + b->desc(), + dt_bias->desc(), + beta, + threshold); + + INFINIOP_WORKSPACE_TENSOR(workspace, FusedGatedDeltaNetGating, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(g), + graph::GraphTensor(beta_output), + graph::GraphTensor(A_log), + graph::GraphTensor(a), + graph::GraphTensor(b), + graph::GraphTensor(dt_bias)}; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR(infiniopFusedGatedDeltaNetGating( + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->g->data(), + planned->beta_output->data(), + planned->A_log->data(), + planned->a->data(), + planned->b->data(), + planned->dt_bias->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(FusedGatedDeltaNetGating, &plan, &run, &cleanup); + +} // namespace infinicore::op::fused_gated_delta_net_gating_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 0eb4fef98..a02973a45 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -47,6 +47,7 @@ #include "ops/floor_divide.hpp" #include "ops/fmin.hpp" #include "ops/fmod.hpp" +#include "ops/fused_gated_delta_net_gating.hpp" #include "ops/gaussian_nll_loss.hpp" #include "ops/hardswish.hpp" #include "ops/hardtanh.hpp" @@ -160,6 +161,7 @@ inline void bind(py::module &m) { bind_hinge_embedding_loss(m); bind_kv_caching(m); bind_fmod(m); + bind_fused_gated_delta_net_gating(m); bind_fmin(m); bind_cat(m); bind_causal_softmax(m); diff --git a/src/infinicore/pybind11/ops/fused_gated_delta_net_gating.hpp b/src/infinicore/pybind11/ops/fused_gated_delta_net_gating.hpp new file mode 100644 index 000000000..e43566977 --- /dev/null +++ b/src/infinicore/pybind11/ops/fused_gated_delta_net_gating.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include + +#include "infinicore/ops/fused_gated_delta_net_gating.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_fused_gated_delta_net_gating(py::module &m) { + m.def( + "fused_gated_delta_net_gating", + [](const Tensor &A_log, + const Tensor &a, + const Tensor &b, + const Tensor &dt_bias, + float beta, + float threshold) { + auto result = op::fused_gated_delta_net_gating(A_log, a, b, dt_bias, beta, threshold); + return py::make_tuple(result.first, result.second); + }, + py::arg("A_log"), + py::arg("a"), + py::arg("b"), + py::arg("dt_bias"), + py::arg("beta") = 1.0f, + py::arg("threshold") = 20.0f, + R"doc(Fused GatedDeltaNet gating out-of-place.)doc"); + + m.def("fused_gated_delta_net_gating_", + &op::fused_gated_delta_net_gating_, + py::arg("g"), + py::arg("beta_output"), + py::arg("A_log"), + py::arg("a"), + py::arg("b"), + py::arg("dt_bias"), + py::arg("beta") = 1.0f, + py::arg("threshold") = 20.0f, + R"doc(Fused GatedDeltaNet gating writing to provided outputs.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/fused_gated_delta_net_gating/fused_gated_delta_net_gating.h b/src/infiniop/ops/fused_gated_delta_net_gating/fused_gated_delta_net_gating.h new file mode 100644 index 000000000..6be2670e2 --- /dev/null +++ b/src/infiniop/ops/fused_gated_delta_net_gating/fused_gated_delta_net_gating.h @@ -0,0 +1,35 @@ +#ifndef __FUSED_GATED_DELTA_NET_GATING_H__ +#define __FUSED_GATED_DELTA_NET_GATING_H__ + +#include "../../operator.h" +#include "info.h" + +#include + +#define DESCRIPTOR(NAMESPACE) \ + namespace op::fused_gated_delta_net_gating::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + FusedGatedDeltaNetGatingInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor(Opaque *opaque, FusedGatedDeltaNetGatingInfo info, size_t workspace_size, infiniDevice_t device_type, int device_id) \ + : InfiniopDescriptor{device_type, device_id}, _opaque(opaque), _info(std::move(info)), _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t g_desc, infiniopTensorDescriptor_t beta_output_desc, \ + infiniopTensorDescriptor_t A_log_desc, infiniopTensorDescriptor_t a_desc, \ + infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t dt_bias_desc, \ + float beta, float threshold); \ + \ + infiniStatus_t calculate(void *workspace, size_t workspace_size, void *g, void *beta_output, \ + const void *A_log, const void *a, const void *b, const void *dt_bias, void *stream) const; \ + }; \ + } + +#endif diff --git a/src/infiniop/ops/fused_gated_delta_net_gating/info.h b/src/infiniop/ops/fused_gated_delta_net_gating/info.h new file mode 100644 index 000000000..541751353 --- /dev/null +++ b/src/infiniop/ops/fused_gated_delta_net_gating/info.h @@ -0,0 +1,90 @@ +#ifndef __FUSED_GATED_DELTA_NET_GATING_INFO_H__ +#define __FUSED_GATED_DELTA_NET_GATING_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +#include + +namespace op::fused_gated_delta_net_gating { + +class FusedGatedDeltaNetGatingInfo { + FusedGatedDeltaNetGatingInfo() = default; + +public: + infiniDtype_t input_dtype; + size_t batch_size; + size_t seq_len; + size_t hidden; + std::vector g_strides; + std::vector beta_output_strides; + std::vector A_log_strides; + std::vector a_strides; + std::vector b_strides; + std::vector dt_bias_strides; + float beta; + float threshold; + + size_t numel() const { + return batch_size * seq_len * hidden; + } + + static utils::Result create( + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_output_desc, + infiniopTensorDescriptor_t A_log_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t dt_bias_desc, + float beta, + float threshold) { + + if (g_desc->dtype() != INFINI_DTYPE_F32 || beta_output_desc->dtype() != INFINI_DTYPE_F32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + auto input_dtype = a_desc->dtype(); + if (input_dtype != INFINI_DTYPE_F32 && input_dtype != INFINI_DTYPE_F16 && input_dtype != INFINI_DTYPE_BF16) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (b_desc->dtype() != input_dtype || A_log_desc->dtype() != input_dtype || dt_bias_desc->dtype() != input_dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (g_desc->ndim() != 3 || beta_output_desc->ndim() != 3 || a_desc->ndim() != 3 || b_desc->ndim() != 3) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (A_log_desc->ndim() != 1 || dt_bias_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + const auto &shape = a_desc->shape(); + if (shape != b_desc->shape() || shape != g_desc->shape() || shape != beta_output_desc->shape()) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + size_t hidden = shape[2]; + if (A_log_desc->shape()[0] != hidden || dt_bias_desc->shape()[0] != hidden) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + FusedGatedDeltaNetGatingInfo info; + info.input_dtype = input_dtype; + info.batch_size = shape[0]; + info.seq_len = shape[1]; + info.hidden = hidden; + info.g_strides = g_desc->strides(); + info.beta_output_strides = beta_output_desc->strides(); + info.A_log_strides = A_log_desc->strides(); + info.a_strides = a_desc->strides(); + info.b_strides = b_desc->strides(); + info.dt_bias_strides = dt_bias_desc->strides(); + info.beta = beta; + info.threshold = threshold; + return utils::Result(info); + } +}; + +} // namespace op::fused_gated_delta_net_gating + +#endif diff --git a/src/infiniop/ops/fused_gated_delta_net_gating/nvidia/fused_gated_delta_net_gating_nvidia.cu b/src/infiniop/ops/fused_gated_delta_net_gating/nvidia/fused_gated_delta_net_gating_nvidia.cu new file mode 100644 index 000000000..627f026bf --- /dev/null +++ b/src/infiniop/ops/fused_gated_delta_net_gating/nvidia/fused_gated_delta_net_gating_nvidia.cu @@ -0,0 +1,196 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "fused_gated_delta_net_gating_nvidia.cuh" + +#include +#include + +namespace op::fused_gated_delta_net_gating::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_output_desc, + infiniopTensorDescriptor_t A_log_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t dt_bias_desc, + float beta, + float threshold) { + + auto result = FusedGatedDeltaNetGatingInfo::create( + g_desc, beta_output_desc, A_log_desc, a_desc, b_desc, dt_bias_desc, beta, threshold); + CHECK_RESULT(result); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + result.take(), + 0, + handle->device, + handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +namespace { + +template +__device__ __forceinline__ float load_as_float(const T *ptr, ptrdiff_t offset) { + return static_cast(ptr[offset]); +} + +template <> +__device__ __forceinline__ float load_as_float(const half *ptr, ptrdiff_t offset) { + return __half2float(ptr[offset]); +} + +template <> +__device__ __forceinline__ float load_as_float<__nv_bfloat16>(const __nv_bfloat16 *ptr, ptrdiff_t offset) { + return __bfloat162float(ptr[offset]); +} + +__device__ __forceinline__ float sigmoidf_stable(float x) { + if (x >= 0.0f) { + float z = expf(-x); + return 1.0f / (1.0f + z); + } + float z = expf(x); + return z / (1.0f + z); +} + +__device__ __forceinline__ float softplus_beta_threshold(float x, float beta, float threshold) { + float bx = beta * x; + return bx <= threshold ? log1pf(expf(bx)) / beta : x; +} + +template +__global__ void fused_gated_delta_net_gating_kernel( + float *g, + float *beta_output, + const T *A_log, + const T *a, + const T *b, + const T *dt_bias, + size_t total, + size_t seq_len, + size_t hidden, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2, + ptrdiff_t A_log_s0, + ptrdiff_t a_s0, + ptrdiff_t a_s1, + ptrdiff_t a_s2, + ptrdiff_t b_s0, + ptrdiff_t b_s1, + ptrdiff_t b_s2, + ptrdiff_t dt_bias_s0, + float beta, + float threshold) { + + size_t linear = blockIdx.x * blockDim.x + threadIdx.x; + if (linear >= total) { + return; + } + + size_t h = linear % hidden; + size_t tmp = linear / hidden; + size_t s = tmp % seq_len; + size_t batch = tmp / seq_len; + + ptrdiff_t g_off = static_cast(batch) * g_s0 + static_cast(s) * g_s1 + static_cast(h) * g_s2; + ptrdiff_t beta_off = static_cast(batch) * beta_s0 + static_cast(s) * beta_s1 + static_cast(h) * beta_s2; + ptrdiff_t a_off = static_cast(batch) * a_s0 + static_cast(s) * a_s1 + static_cast(h) * a_s2; + ptrdiff_t b_off = static_cast(batch) * b_s0 + static_cast(s) * b_s1 + static_cast(h) * b_s2; + + float x = load_as_float(a, a_off) + load_as_float(dt_bias, static_cast(h) * dt_bias_s0); + g[g_off] = -expf(load_as_float(A_log, static_cast(h) * A_log_s0)) * softplus_beta_threshold(x, beta, threshold); + beta_output[beta_off] = sigmoidf_stable(load_as_float(b, b_off)); +} + +template +infiniStatus_t launch_kernel(const FusedGatedDeltaNetGatingInfo &info, + float *g, + float *beta_output, + const void *A_log, + const void *a, + const void *b, + const void *dt_bias, + cudaStream_t stream) { + size_t total = info.numel(); + if (total == 0) { + return INFINI_STATUS_SUCCESS; + } + + constexpr int block = 256; + int grid = static_cast((total + block - 1) / block); + fused_gated_delta_net_gating_kernel<<>>( + g, + beta_output, + static_cast(A_log), + static_cast(a), + static_cast(b), + static_cast(dt_bias), + total, + info.seq_len, + info.hidden, + info.g_strides[0], + info.g_strides[1], + info.g_strides[2], + info.beta_output_strides[0], + info.beta_output_strides[1], + info.beta_output_strides[2], + info.A_log_strides[0], + info.a_strides[0], + info.a_strides[1], + info.a_strides[2], + info.b_strides[0], + info.b_strides[1], + info.b_strides[2], + info.dt_bias_strides[0], + info.beta, + info.threshold); + return INFINI_STATUS_SUCCESS; +} + +} // namespace + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *g, + void *beta_output, + const void *A_log, + const void *a, + const void *b, + const void *dt_bias, + void *stream) const { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + cudaStream_t cuda_stream = reinterpret_cast(stream); + switch (_info.input_dtype) { + case INFINI_DTYPE_F32: + return launch_kernel(_info, static_cast(g), static_cast(beta_output), A_log, a, b, dt_bias, cuda_stream); + case INFINI_DTYPE_F16: + return launch_kernel(_info, static_cast(g), static_cast(beta_output), A_log, a, b, dt_bias, cuda_stream); + case INFINI_DTYPE_BF16: + return launch_kernel<__nv_bfloat16>(_info, static_cast(g), static_cast(beta_output), A_log, a, b, dt_bias, cuda_stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::fused_gated_delta_net_gating::nvidia diff --git a/src/infiniop/ops/fused_gated_delta_net_gating/nvidia/fused_gated_delta_net_gating_nvidia.cuh b/src/infiniop/ops/fused_gated_delta_net_gating/nvidia/fused_gated_delta_net_gating_nvidia.cuh new file mode 100644 index 000000000..a427fde6a --- /dev/null +++ b/src/infiniop/ops/fused_gated_delta_net_gating/nvidia/fused_gated_delta_net_gating_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __FUSED_GATED_DELTA_NET_GATING_NVIDIA_CUH__ +#define __FUSED_GATED_DELTA_NET_GATING_NVIDIA_CUH__ + +#include "../fused_gated_delta_net_gating.h" + +DESCRIPTOR(nvidia) + +#endif diff --git a/src/infiniop/ops/fused_gated_delta_net_gating/operator.cc b/src/infiniop/ops/fused_gated_delta_net_gating/operator.cc new file mode 100644 index 000000000..0a992c874 --- /dev/null +++ b/src/infiniop/ops/fused_gated_delta_net_gating/operator.cc @@ -0,0 +1,155 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/fused_gated_delta_net_gating.h" + +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) +#include "nvidia/fused_gated_delta_net_gating_nvidia.cuh" +#endif + +__INFINI_C __export infiniStatus_t +infiniopCreateFusedGatedDeltaNetGatingDescriptor( + infiniopHandle_t handle, + infiniopFusedGatedDeltaNetGatingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_output_desc, + infiniopTensorDescriptor_t A_log_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t dt_bias_desc, + float beta, + float threshold) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::fused_gated_delta_net_gating::NAMESPACE::Descriptor::create( \ + handle, reinterpret_cast(desc_ptr), \ + g_desc, beta_output_desc, A_log_desc, a_desc, b_desc, dt_bias_desc, beta, threshold) + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + CREATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_ALI_API + CREATE(INFINI_DEVICE_ALI, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C __export infiniStatus_t +infiniopGetFusedGatedDeltaNetGatingWorkspaceSize( + infiniopFusedGatedDeltaNetGatingDescriptor_t desc, + size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_ALI_API + GET(INFINI_DEVICE_ALI, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__INFINI_C __export infiniStatus_t +infiniopFusedGatedDeltaNetGating( + infiniopFusedGatedDeltaNetGatingDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *g, + void *beta_output, + const void *A_log, + const void *a, + const void *b, + const void *dt_bias, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, g, beta_output, A_log, a, b, dt_bias, stream) + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + CALCULATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_ALI_API + CALCULATE(INFINI_DEVICE_ALI, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C __export infiniStatus_t +infiniopDestroyFusedGatedDeltaNetGatingDescriptor( + infiniopFusedGatedDeltaNetGatingDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + DELETE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_ALI_API + DELETE(INFINI_DEVICE_ALI, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DELETE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_HYGON_API + DELETE(INFINI_DEVICE_HYGON, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/test/infinicore/ops/fused_gated_delta_net_gating.py b/test/infinicore/ops/fused_gated_delta_net_gating.py new file mode 100644 index 000000000..ec8283d40 --- /dev/null +++ b/test/infinicore/ops/fused_gated_delta_net_gating.py @@ -0,0 +1,109 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import torch.nn.functional as F +import infinicore +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorSpec, + TestCase, +) + + +_TEST_CASES_DATA = [ + ((2, 1, 8), None, None), + ((2, 3, 17), None, None), + ((2, 3, 17), (80, 20, 1), (2,)), +] + +_TENSOR_DTYPES = [infinicore.float16, infinicore.float32, infinicore.bfloat16] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 2e-3, "rtol": 2e-3}, + infinicore.float32: {"atol": 1e-6, "rtol": 1e-6}, + infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2}, +} + + +def torch_fused_gdn_gating(A_log, a, b, dt_bias, beta=1.0, threshold=20.0, out=None): + x = a.float() + dt_bias.float().view(1, 1, -1) + softplus_x = torch.where( + beta * x <= threshold, + F.softplus(x, beta=beta, threshold=threshold), + x, + ) + g = -A_log.float().exp().view(1, 1, -1) * softplus_x + beta_output = b.float().sigmoid() + if out is not None: + out_g, out_beta = out + out_g.copy_(g) + out_beta.copy_(beta_output) + return out + return g, beta_output + + +def parse_test_cases(): + tests = [] + for shape, tensor_strides, hidden_strides in _TEST_CASES_DATA: + hidden = shape[-1] + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP[dtype] + A_log = TensorSpec.from_tensor((hidden,), hidden_strides, dtype) + a = TensorSpec.from_tensor(shape, tensor_strides, dtype) + b = TensorSpec.from_tensor(shape, tensor_strides, dtype) + dt_bias = TensorSpec.from_tensor((hidden,), hidden_strides, dtype) + kwargs = {"beta": 1.0, "threshold": 20.0} + + tests.append( + TestCase( + inputs=[A_log, a, b, dt_bias], + kwargs=kwargs, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="FusedGatedDeltaNetGating - OUT_OF_PLACE", + output_count=2, + ) + ) + + out_g = TensorSpec.from_tensor(shape, None, infinicore.float32) + out_beta = TensorSpec.from_tensor(shape, None, infinicore.float32) + tests.append( + TestCase( + inputs=[A_log, a, b, dt_bias], + kwargs=kwargs.copy(), + output_specs=[out_g, out_beta], + comparison_target="out", + tolerance=tol, + description="FusedGatedDeltaNetGating - INPLACE(out)", + output_count=2, + ) + ) + return tests + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("FusedGatedDeltaNetGating") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_fused_gdn_gating(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore.nn.functional.fused_gated_delta_net_gating(*args, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infiniop/fused_gated_delta_net_gating.py b/test/infiniop/fused_gated_delta_net_gating.py new file mode 100644 index 000000000..3c5d820d5 --- /dev/null +++ b/test/infiniop/fused_gated_delta_net_gating.py @@ -0,0 +1,159 @@ +import ctypes +from ctypes import c_float, c_uint64 + +import torch +import torch.nn.functional as F +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + test_operator, +) + + +_TEST_CASES = [ + ((2, 1, 8), None, None, None), + ((2, 3, 17), None, None, None), + ((2, 3, 17), (80, 20, 1), (80, 20, 1), (2,)), +] + +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3}, + InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6}, + InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2}, +} + +DEBUG = False + + +def torch_fused_gdn_gating(g, beta_output, A_log, a, b, dt_bias, beta, threshold): + x = a.float() + dt_bias.float().view(1, 1, -1) + softplus_x = torch.where( + beta * x <= threshold, + F.softplus(x, beta=beta, threshold=threshold), + x, + ) + g.copy_(-A_log.float().exp().view(1, 1, -1) * softplus_x) + beta_output.copy_(b.float().sigmoid()) + + +def test( + handle, + device, + shape, + tensor_stride=None, + out_stride=None, + hidden_stride=None, + dtype=torch.float16, + sync=None, +): + beta = 1.0 + threshold = 20.0 + hidden = shape[-1] + + a = TestTensor(shape, tensor_stride, dtype, device) + b = TestTensor(shape, tensor_stride, dtype, device) + A_log = TestTensor((hidden,), hidden_stride, dtype, device) + dt_bias = TestTensor((hidden,), hidden_stride, dtype, device) + g = TestTensor(shape, out_stride, InfiniDtype.F32, device, mode="ones") + beta_output = TestTensor(shape, out_stride, InfiniDtype.F32, device, mode="ones") + + if g.is_broadcast() or beta_output.is_broadcast(): + return + + print( + f"Testing FusedGatedDeltaNetGating on {InfiniDeviceNames[device]} " + f"shape:{shape} dtype:{InfiniDtypeNames[dtype]}" + ) + + torch_fused_gdn_gating( + g.torch_tensor(), + beta_output.torch_tensor(), + A_log.torch_tensor(), + a.torch_tensor(), + b.torch_tensor(), + dt_bias.torch_tensor(), + beta, + threshold, + ) + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateFusedGatedDeltaNetGatingDescriptor( + handle, + ctypes.byref(descriptor), + g.descriptor, + beta_output.descriptor, + A_log.descriptor, + a.descriptor, + b.descriptor, + dt_bias.descriptor, + c_float(beta), + c_float(threshold), + ) + ) + + for tensor in [g, beta_output, A_log, a, b, dt_bias]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetFusedGatedDeltaNetGatingWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + check_error( + LIBINFINIOP.infiniopFusedGatedDeltaNetGating( + descriptor, + workspace.data(), + workspace.size(), + g.data(), + beta_output.data(), + A_log.data(), + a.data(), + b.data(), + dt_bias.data(), + None, + ) + ) + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(g.actual_tensor(), g.torch_tensor(), atol=atol, rtol=rtol) + debug( + beta_output.actual_tensor(), + beta_output.torch_tensor(), + atol=atol, + rtol=rtol, + ) + + assert torch.allclose(g.actual_tensor(), g.torch_tensor(), atol=atol, rtol=rtol) + assert torch.allclose( + beta_output.actual_tensor(), beta_output.torch_tensor(), atol=atol, rtol=rtol + ) + check_error( + LIBINFINIOP.infiniopDestroyFusedGatedDeltaNetGatingDescriptor(descriptor) + ) + + +if __name__ == "__main__": + args = get_args() + DEBUG = args.debug + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + print("\033[92m Test passed! \033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index e43f3a18d..d04b6c138 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -134,6 +134,48 @@ def addcmul_(lib): ] +@OpRegister.operator +def fused_gated_delta_net_gating_(lib): + lib.infiniopCreateFusedGatedDeltaNetGatingDescriptor.restype = c_int32 + lib.infiniopCreateFusedGatedDeltaNetGatingDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_float, + c_float, + ] + + lib.infiniopGetFusedGatedDeltaNetGatingWorkspaceSize.restype = c_int32 + lib.infiniopGetFusedGatedDeltaNetGatingWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopFusedGatedDeltaNetGating.restype = c_int32 + lib.infiniopFusedGatedDeltaNetGating.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyFusedGatedDeltaNetGatingDescriptor.restype = c_int32 + lib.infiniopDestroyFusedGatedDeltaNetGatingDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def cdist_(lib): # 1. 创建描述符接口