diff --git a/ext/VectorInterfaceEnzymeExt.jl b/ext/VectorInterfaceEnzymeExt.jl index a61c775..4d4045f 100644 --- a/ext/VectorInterfaceEnzymeExt.jl +++ b/ext/VectorInterfaceEnzymeExt.jl @@ -3,28 +3,11 @@ module VectorInterfaceEnzymeExt # COV_EXCL_START # Enzyme rules aren't reachable by coverage using VectorInterface +using VectorInterface: project_scalar, project_add! using Enzyme using Enzyme.EnzymeCore using Enzyme.EnzymeCore: EnzymeRules -function project_add!(C, A, α) - TC = Base.promote_op(+, scalartype(A), scalartype(α)) - return if !(TC <: Real) && scalartype(C) <: Real - add!(C, real(add!(zerovector(C, TC), A, α))) - else - add!(C, A, α) - end -end - -""" - project_scalar(x::Number, dx::Number) - -Project a computed tangent `dx` onto the correct tangent type for `x`. -For example, we might compute a complex `dx` but only require the real part. -""" -project_scalar(x::Number, dx::Number) = oftype(x, dx) -project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) - function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(scale!)}, diff --git a/ext/VectorInterfaceMooncakeExt.jl b/ext/VectorInterfaceMooncakeExt.jl index 00ad349..084c7b1 100644 --- a/ext/VectorInterfaceMooncakeExt.jl +++ b/ext/VectorInterfaceMooncakeExt.jl @@ -1,31 +1,12 @@ module VectorInterfaceMooncakeExt using VectorInterface +using VectorInterface: project_scalar, project_add! using Mooncake using Mooncake: @is_primitive, DefaultCtx, NoFData, NoRData, NoTangent, CoDual, Dual, arrayify, primal, extract -# Projection -# ---------- -""" - project_scalar(x::Number, dx::Number) - -Project a computed tangent `dx` onto the correct tangent type for `x`. -For example, we might compute a complex `dx` but only require the real part. -""" -project_scalar(x::Number, dx::Number) = oftype(x, dx) -project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) - -function project_add!(C, A, α) - TC = Base.promote_op(+, scalartype(A), scalartype(α)) - return if !(TC <: Real) && scalartype(C) <: Real - add!(C, real(add!(zerovector(C, TC), A, α))) - else - add!(C, A, α) - end -end - _needs_tangent(x) = _needs_tangent(typeof(x)) _needs_tangent(::Type{T}) where {T <: Number} = Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData diff --git a/src/VectorInterface.jl b/src/VectorInterface.jl index df0e684..85088a1 100644 --- a/src/VectorInterface.jl +++ b/src/VectorInterface.jl @@ -48,4 +48,7 @@ include("fallbacks.jl") # Minimal vector type for testing include("minimalvec.jl") +# Common AD helper methods +include("ad.jl") + end diff --git a/src/ad.jl b/src/ad.jl new file mode 100644 index 0000000..e5f86e1 --- /dev/null +++ b/src/ad.jl @@ -0,0 +1,17 @@ +function project_add!(C, A, α) + TC = Base.promote_op(+, scalartype(A), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(add!(zerovector(C, TC), A, α))) + else + add!(C, A, α) + end +end + +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx))