Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infinicore/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "ops/flash_attention.hpp"
#include "ops/fmin.hpp"
#include "ops/fmod.hpp"
#include "ops/fused_gated_delta_net_gating.hpp"
#include "ops/gelu.hpp"
#include "ops/gelutanh.hpp"
#include "ops/hardswish.hpp"
Expand Down
37 changes: 37 additions & 0 deletions include/infinicore/ops/fused_gated_delta_net_gating.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

#include <utility>

namespace infinicore::op {

INFINICORE_GRAPH_OP_CLASS(FusedGatedDeltaNetGating,
Tensor,
Tensor,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
float,
float);

std::pair<Tensor, Tensor> fused_gated_delta_net_gating(const Tensor &A_log,
const Tensor &a,
const Tensor &b,
const Tensor &dt_bias,
float beta = 1.0f,
float threshold = 20.0f);

void fused_gated_delta_net_gating_(Tensor g,
Tensor beta_output,
const Tensor &A_log,
const Tensor &a,
const Tensor &b,
const Tensor &dt_bias,
float beta = 1.0f,
float threshold = 20.0f);

} // namespace infinicore::op
1 change: 1 addition & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "infiniop/ops/floor_divide.h"
#include "infiniop/ops/fmin.h"
#include "infiniop/ops/fmod.h"
#include "infiniop/ops/fused_gated_delta_net_gating.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gelutanh.h"
#include "infiniop/ops/gemm.h"
Expand Down
43 changes: 43 additions & 0 deletions include/infiniop/ops/fused_gated_delta_net_gating.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef __INFINIOP_FUSED_GATED_DELTA_NET_GATING_API_H__
#define __INFINIOP_FUSED_GATED_DELTA_NET_GATING_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopFusedGatedDeltaNetGatingDescriptor_t;

__INFINI_C __export infiniStatus_t
infiniopCreateFusedGatedDeltaNetGatingDescriptor(
infiniopHandle_t handle,
infiniopFusedGatedDeltaNetGatingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t g_desc,
infiniopTensorDescriptor_t beta_output_desc,
infiniopTensorDescriptor_t A_log_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t dt_bias_desc,
float beta,
float threshold);

__INFINI_C __export infiniStatus_t
infiniopGetFusedGatedDeltaNetGatingWorkspaceSize(
infiniopFusedGatedDeltaNetGatingDescriptor_t desc,
size_t *size);

__INFINI_C __export infiniStatus_t
infiniopFusedGatedDeltaNetGating(
infiniopFusedGatedDeltaNetGatingDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *g,
void *beta_output,
const void *A_log,
const void *a,
const void *b,
const void *dt_bias,
void *stream);

__INFINI_C __export infiniStatus_t
infiniopDestroyFusedGatedDeltaNetGatingDescriptor(
infiniopFusedGatedDeltaNetGatingDescriptor_t desc);

#endif
2 changes: 2 additions & 0 deletions python/infinicore/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .causal_softmax import causal_softmax
from .embedding import embedding
from .flash_attention import flash_attention
from .fused_gated_delta_net_gating import fused_gated_delta_net_gating
from .gaussian_nll_loss import gaussian_nll_loss
from .hardswish import hardswish
from .hardtanh import hardtanh
Expand Down Expand Up @@ -43,6 +44,7 @@
"causal_softmax",
"embedding",
"flash_attention",
"fused_gated_delta_net_gating",
"gaussian_nll_loss",
"interpolate",
"linear",
Expand Down
37 changes: 37 additions & 0 deletions python/infinicore/nn/functional/fused_gated_delta_net_gating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def fused_gated_delta_net_gating(
A_log: Tensor,
a: Tensor,
b: Tensor,
dt_bias: Tensor,
beta: float = 1.0,
threshold: float = 20.0,
*,
out: tuple[Tensor, Tensor] | None = None,
) -> tuple[Tensor, Tensor]:
if out is None:
g, beta_output = _infinicore.fused_gated_delta_net_gating(
A_log._underlying,
a._underlying,
b._underlying,
dt_bias._underlying,
beta,
threshold,
)
return Tensor(g), Tensor(beta_output)

g, beta_output = out
_infinicore.fused_gated_delta_net_gating_(
g._underlying,
beta_output._underlying,
A_log._underlying,
a._underlying,
b._underlying,
dt_bias._underlying,
beta,
threshold,
)
return g, beta_output
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include "infinicore/ops/fused_gated_delta_net_gating.hpp"

#include "../../utils.hpp"

#include <stdexcept>

namespace infinicore::op {

INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FusedGatedDeltaNetGating);

FusedGatedDeltaNetGating::FusedGatedDeltaNetGating(Tensor g,
Tensor beta_output,
const Tensor &A_log,
const Tensor &a,
const Tensor &b,
const Tensor &dt_bias,
float beta,
float threshold) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(g, beta_output, A_log, a, b, dt_bias);
INFINICORE_GRAPH_OP_DISPATCH(g->device().getType(), g, beta_output, A_log, a, b, dt_bias, beta, threshold);
}

void FusedGatedDeltaNetGating::execute(Tensor g,
Tensor beta_output,
const Tensor &A_log,
const Tensor &a,
const Tensor &b,
const Tensor &dt_bias,
float beta,
float threshold) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(FusedGatedDeltaNetGating, g, beta_output, A_log, a, b, dt_bias, beta, threshold);
}

static void validate_inputs(const Tensor &A_log,
const Tensor &a,
const Tensor &b,
const Tensor &dt_bias) {
if (a->shape().size() != 3 || b->shape().size() != 3) {
throw std::runtime_error("fused_gated_delta_net_gating expects a and b with shape [batch_size, seq_len, hidden]");
}
if (a->shape() != b->shape()) {
throw std::runtime_error("fused_gated_delta_net_gating expects a and b to have the same shape");
}
if (A_log->shape().size() != 1 || dt_bias->shape().size() != 1) {
throw std::runtime_error("fused_gated_delta_net_gating expects A_log and dt_bias with shape [hidden]");
}
if (A_log->shape()[0] != a->shape()[2] || dt_bias->shape()[0] != a->shape()[2]) {
throw std::runtime_error("fused_gated_delta_net_gating hidden dimension mismatch");
}
}

std::pair<Tensor, Tensor> fused_gated_delta_net_gating(const Tensor &A_log,
const Tensor &a,
const Tensor &b,
const Tensor &dt_bias,
float beta,
float threshold) {
validate_inputs(A_log, a, b, dt_bias);

Tensor g = Tensor::empty(a->shape(), DataType::F32, a->device());
Tensor beta_output = Tensor::empty(a->shape(), DataType::F32, a->device());
fused_gated_delta_net_gating_(g, beta_output, A_log, a, b, dt_bias, beta, threshold);
return {g, beta_output};
}

void fused_gated_delta_net_gating_(Tensor g,
Tensor beta_output,
const Tensor &A_log,
const Tensor &a,
const Tensor &b,
const Tensor &dt_bias,
float beta,
float threshold) {
validate_inputs(A_log, a, b, dt_bias);
if (g->shape() != a->shape() || beta_output->shape() != a->shape()) {
throw std::runtime_error("fused_gated_delta_net_gating_ expects outputs with shape [batch_size, seq_len, hidden]");
}
if (g->dtype() != DataType::F32 || beta_output->dtype() != DataType::F32) {
throw std::runtime_error("fused_gated_delta_net_gating_ expects float32 outputs");
}

FusedGatedDeltaNetGating::execute(g, beta_output, A_log, a, b, dt_bias, beta, threshold);
}

} // namespace infinicore::op
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#include "infinicore/ops/fused_gated_delta_net_gating.hpp"

#include "../infiniop_impl.hpp"

namespace infinicore::op::fused_gated_delta_net_gating_impl::infiniop {

INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FusedGatedDeltaNetGating, 100);

struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, g, beta_output, A_log, a, b, dt_bias;
};

void *plan(Tensor g,
Tensor beta_output,
const Tensor &A_log,
const Tensor &a,
const Tensor &b,
const Tensor &dt_bias,
float beta,
float threshold) {
size_t seed = hash_combine(g, beta_output, A_log, a, b, dt_bias, beta, threshold);

INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, FusedGatedDeltaNetGating, seed,
g->desc(),
beta_output->desc(),
A_log->desc(),
a->desc(),
b->desc(),
dt_bias->desc(),
beta,
threshold);

INFINIOP_WORKSPACE_TENSOR(workspace, FusedGatedDeltaNetGating, descriptor);

return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(g),
graph::GraphTensor(beta_output),
graph::GraphTensor(A_log),
graph::GraphTensor(a),
graph::GraphTensor(b),
graph::GraphTensor(dt_bias)};
}

void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);

INFINICORE_CHECK_ERROR(infiniopFusedGatedDeltaNetGating(
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->g->data(),
planned->beta_output->data(),
planned->A_log->data(),
planned->a->data(),
planned->b->data(),
planned->dt_bias->data(),
context::getStream()));
}

void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}

INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(FusedGatedDeltaNetGating, &plan, &run, &cleanup);

} // namespace infinicore::op::fused_gated_delta_net_gating_impl::infiniop
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "ops/floor_divide.hpp"
#include "ops/fmin.hpp"
#include "ops/fmod.hpp"
#include "ops/fused_gated_delta_net_gating.hpp"
#include "ops/gaussian_nll_loss.hpp"
#include "ops/hardswish.hpp"
#include "ops/hardtanh.hpp"
Expand Down Expand Up @@ -160,6 +161,7 @@ inline void bind(py::module &m) {
bind_hinge_embedding_loss(m);
bind_kv_caching(m);
bind_fmod(m);
bind_fused_gated_delta_net_gating(m);
bind_fmin(m);
bind_cat(m);
bind_causal_softmax(m);
Expand Down
44 changes: 44 additions & 0 deletions src/infinicore/pybind11/ops/fused_gated_delta_net_gating.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/fused_gated_delta_net_gating.hpp"

namespace py = pybind11;

namespace infinicore::ops {

inline void bind_fused_gated_delta_net_gating(py::module &m) {
m.def(
"fused_gated_delta_net_gating",
[](const Tensor &A_log,
const Tensor &a,
const Tensor &b,
const Tensor &dt_bias,
float beta,
float threshold) {
auto result = op::fused_gated_delta_net_gating(A_log, a, b, dt_bias, beta, threshold);
return py::make_tuple(result.first, result.second);
},
py::arg("A_log"),
py::arg("a"),
py::arg("b"),
py::arg("dt_bias"),
py::arg("beta") = 1.0f,
py::arg("threshold") = 20.0f,
R"doc(Fused GatedDeltaNet gating out-of-place.)doc");

m.def("fused_gated_delta_net_gating_",
&op::fused_gated_delta_net_gating_,
py::arg("g"),
py::arg("beta_output"),
py::arg("A_log"),
py::arg("a"),
py::arg("b"),
py::arg("dt_bias"),
py::arg("beta") = 1.0f,
py::arg("threshold") = 20.0f,
R"doc(Fused GatedDeltaNet gating writing to provided outputs.)doc");
}

} // namespace infinicore::ops
Loading
Loading