diff --git a/Project.toml b/Project.toml index 3d0abc9b1..d731bbaa6 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" @@ -31,6 +33,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" TensorKitAMDGPUExt = "AMDGPU" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" +TensorKitEnzymeExt = "Enzyme" +TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils" TensorKitFiniteDifferencesExt = "FiniteDifferences" TensorKitMooncakeExt = "Mooncake" @@ -43,10 +47,12 @@ AMDGPU = "2" CUDA = "6" ChainRulesCore = "1" Dictionaries = "0.4" +Enzyme = "0.13.157" +EnzymeTestUtils = "0.2.8" FiniteDifferences = "0.12" -LRUCache = "1.0.2" +LRUCache = "1.6" LinearAlgebra = "1" -MatrixAlgebraKit = "0.6.7" +MatrixAlgebraKit = "0.6.8" Mooncake = "0.5.27" OhMyThreads = "0.8.0" Printf = "1" @@ -54,8 +60,8 @@ Random = "1" ScopedValues = "1.3.0" Strided = "2" TensorKitSectors = "0.3.7" -TensorOperations = "5.5" +TensorOperations = "5.5.2" TupleTools = "1.5" -VectorInterface = "0.4.8, 0.5, 0.6" +VectorInterface = "0.4.8, 0.5" cuTENSOR = "6" julia = "1.10" diff --git a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl new file mode 100644 index 000000000..c2be2cdd9 --- /dev/null +++ b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl @@ -0,0 +1,16 @@ +module TensorKitEnzymeExt + +using Enzyme +using TensorKit +import TensorKit as TK +using VectorInterface +using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize +import TensorOperations as TO +using MatrixAlgebraKit +using TupleTools +using Random: AbstractRNG + +include("utility.jl") +include("factorizations.jl") + +end diff --git a/ext/TensorKitEnzymeExt/factorizations.jl b/ext/TensorKitEnzymeExt/factorizations.jl new file mode 100644 index 000000000..765ce694d --- /dev/null +++ b/ext/TensorKitEnzymeExt/factorizations.jl @@ -0,0 +1,275 @@ +# need these due to Enzyme choking on blocks + +for f in (:project_hermitian, :project_antihermitian) + f! = Symbol(f, :!) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + arg::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + $f!(A.val, arg.val, alg.val) + primal = EnzymeRules.needs_primal(config) ? arg.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing + cache = nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + arg::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + if !isa(A, Const) && !isa(arg, Const) + $f!(arg.dval, arg.dval, alg.val) + if A.dval !== arg.dval + A.dval .+= arg.dval + make_zero!(arg.dval) + end + end + return (nothing, nothing, nothing) + end + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + ret = $f(A.val, alg.val) + dret = make_zero(ret) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? dret : nothing + cache = dret + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + dret = cache + if !isa(A, Const) + $f!(dret, dret, alg.val) + add!(A.dval, dret) + end + make_zero!(dret) + return (nothing, nothing) + end + end +end + +for (f, pb) in ( + (:eig_full, :(MatrixAlgebraKit.eig_pullback!)), + (:eigh_full, :(MatrixAlgebraKit.eigh_pullback!)), + (:lq_compact, :(MatrixAlgebraKit.lq_pullback!)), + (:qr_compact, :(MatrixAlgebraKit.qr_pullback!)), + (:lq_full, :(MatrixAlgebraKit.lq_pullback!)), + (:qr_full, :(MatrixAlgebraKit.qr_pullback!)), + (:lq_null, :(MatrixAlgebraKit.lq_null_pullback!)), + (:qr_null, :(MatrixAlgebraKit.qr_null_pullback!)), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + ret = $f(A.val, alg.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing + cache = (ret, shadow) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + !isa(A, Const) && $pb(A.dval, A.val, cache...) + return (nothing, nothing) + end + end +end + +for (f, f_full, pb) in ( + (:eig_vals, :eig_full, :(MatrixAlgebraKit.eig_vals_pullback!)), + (:eigh_vals, :eigh_full, :(MatrixAlgebraKit.eigh_vals_pullback!)), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + ret_full = $f_full(A.val, alg.val) + ret = diagview(ret_full[1]) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing + cache = (ret, shadow, ret_full[2]) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + D, dD, V = cache + !isa(A, Const) && $pb(A.dval, A.val, (DiagonalTensorMap(D), V), dD) + return (nothing, nothing) + end + end +end + +for f in (:svd_compact, :svd_full) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + USVᴴ = $f(A.val, alg.val) + primal = EnzymeRules.needs_primal(config) ? USVᴴ : nothing + shadow = EnzymeRules.needs_shadow(config) ? make_zero(USVᴴ) : nothing + cache = (USVᴴ, shadow) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + !isa(A, Const) && MatrixAlgebraKit.svd_pullback!(A.dval, A.val, cache...) + return (nothing, nothing) + end + end + + # mutating version is not guaranteed to actually mutate + # so we can simply use the non-mutating version instead + f! = Symbol(f, :!) + #=@eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + USVᴴ::Annotation, + alg::Const, + ) where {RT} + EnzymeRules.augmented_primal(func, RT, A, alg) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + USVᴴ::Annotation, + alg::Const, + ) where {RT} + EnzymeRules.reverse(func, RT, A, alg) + end + end=# #hmmmm +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc_no_error)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + USVᴴ = svd_compact(A.val, alg.val.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc) + dUSVᴴtrunc = make_zero(USVᴴtrunc) + cache = (USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind) + return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc_no_error)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind = cache + MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴtrunc, ind) + return (nothing, nothing) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eig_trunc_no_error)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + DV = eig_full(A.val, alg.val.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV, alg.val.trunc) + dDVtrunc = make_zero(DVtrunc) + cache = (DV, DVtrunc, dDVtrunc, ind) + return EnzymeRules.AugmentedReturn(DVtrunc, dDVtrunc, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eig_trunc_no_error)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + DV, DVtrunc, dDVtrunc, ind = cache + MatrixAlgebraKit.eig_pullback!(A.dval, A.val, DV, dDVtrunc, ind) + return (nothing, nothing) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_trunc_no_error)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + DV = eigh_full(A.val, alg.val.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV, alg.val.trunc) + dDVtrunc = make_zero(DVtrunc) + cache = (DV, DVtrunc, dDVtrunc, ind) + return EnzymeRules.AugmentedReturn(DVtrunc, dDVtrunc, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_trunc_no_error)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + DV, DVtrunc, dDVtrunc, ind = cache + MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, DV, dDVtrunc, ind) + return (nothing, nothing) +end diff --git a/ext/TensorKitEnzymeExt/utility.jl b/ext/TensorKitEnzymeExt/utility.jl new file mode 100644 index 000000000..a84089b1c --- /dev/null +++ b/ext/TensorKitEnzymeExt/utility.jl @@ -0,0 +1,81 @@ +# Projection +# ---------- +pullback_dα(α::Const, C::Const, A) = nothing +pullback_dα(α::Const, C::Annotation, A) = nothing +pullback_dα(α::Annotation, C::Const, A) = zero(α.val) +pullback_dα(α::Annotation, C::Annotation, A) = project_scalar(α.val, inner(A, C.dval)) + +pullback_dβ(β::Const, C::Const, Ccache) = nothing +pullback_dβ(β::Const, C::Annotation, Ccache) = nothing +pullback_dβ(β::Annotation, C::Const, Ccache) = zero(β.val) +pullback_dβ(β::Annotation, C::Annotation, Ccache) = project_scalar(β.val, inner(Ccache, C.dval)) + +pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β)) + +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + +# in-place multiplication and accumulation which might project to (real) +# TODO: this could probably be done without allocating +function project_mul!(C, A, B, α) + TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(mul!(zerovector(C, TC), A, B, α))) + else + mul!(C, A, B, α, One()) + end +end +function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) + TA = TensorKit.promote_permute(A) + TB = TensorKit.promote_permute(B) + TC = TO.promote_contract(TA, TB, scalartype(α)) + + return if scalartype(C) <: Real && !(TC <: Real) + add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend, allocator))) + else + TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One(), backend, allocator) + end +end + +# IndexTuple utility +# ------------------ +trivtuple(N) = ntuple(identity, N) + +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) +end + +# Ignore derivatives +# ------------------ + +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.FusionTree}) = true +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.VectorSpace}) = true + +# needed for the ising bimodule case +@inline EnzymeRules.inactive(::typeof(TensorKit.sectorstructure), ::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.degeneracystructure), ::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.permute), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.braid), s::HomSpace, i::Index2Tuple, ::IndexTuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.compose), s1::HomSpace, s2::HomSpace) = nothing +@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorcontract), c::HomSpace, p::Index2Tuple, α::Bool, b::HomSpace, q::Index2Tuple, β::Bool, pq::Index2Tuple) = nothing diff --git a/ext/TensorKitEnzymeTestUtilsExt.jl b/ext/TensorKitEnzymeTestUtilsExt.jl new file mode 100644 index 000000000..4a1f393b1 --- /dev/null +++ b/ext/TensorKitEnzymeTestUtilsExt.jl @@ -0,0 +1,66 @@ +module TensorKitEnzymeTestUtilsExt + +using TensorKit +using EnzymeTestUtils +using EnzymeTestUtils: Enzyme +import EnzymeTestUtils: to_vec, from_vec, rand_tangent + +function EnzymeTestUtils.to_vec(x::TensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + has_seen = haskey(seen_vecs, x) + is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x)) + if has_seen || is_const + x_vec = Float32[] + else + vec_of_vecs = [b * TensorKit.sqrtdim(c) for (c, b) in blocks(x)] + x_vec, back = to_vec(vec_of_vecs) + seen_vecs[x] = x_vec + end + function TensorMap_from_vec(x_vec_new::AbstractVector, seen_xs::EnzymeTestUtils.AliasDict) + if xor(has_seen, haskey(seen_xs, x)) + throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized.")) + end + has_seen && return seen_xs[x] + is_const && return x + + x_new = similar(x) + xvec_of_vecs = back(x_vec_new) + for (i, (c, b)) in enumerate(blocks(x_new)) + scale!(b, xvec_of_vecs[i], TensorKit.invsqrtdim(c)) + end + if Core.Typeof(x_new) != Core.Typeof(x) + x_new = Core.Typeof(x)(x_new) + end + seen_xs[x] = x_new + return x_new + end + return x_vec, TensorMap_from_vec +end +function EnzymeTestUtils.to_vec(t::TensorKit.AdjointTensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + parent_vec, parent_t = to_vec(parent(t), seen_vecs) + return parent_vec, adjoint ∘ parent_t +end +function EnzymeTestUtils.to_vec(t::TensorKit.DiagonalTensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + parent_vec, parent_t = to_vec(TensorMap(t), seen_vecs) + return parent_vec, TensorKit.DiagonalTensorMap ∘ parent_t +end + +# generate random tangents for testing +function EnzymeTestUtils.rand_tangent(rng, t::TensorMap) + return TensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t)) +end + +function EnzymeTestUtils.rand_tangent(rng, t::TensorKit.AdjointTensorMap) + return adjoint(rand_tangent(rng, parent(t))) +end + +function EnzymeTestUtils.rand_tangent(rng, t::DiagonalTensorMap) + return DiagonalTensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t, 1)) +end + +function EnzymeTestUtils.map_fields_recursive(f::typeof(Base.copyto!), y::TensorKit.SortedVectorDict{K, V}, x::TensorKit.SortedVectorDict{K, V}) where {K, V} + copyto!(y.keys, x.keys) + copyto!(y.values, x.values) + return y +end + +end diff --git a/test/Project.toml b/test/Project.toml index 18af8af80..5252ff1f4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" diff --git a/test/enzyme-factorizations-eig/eig.jl b/test/enzyme-factorizations-eig/eig.jl new file mode 100644 index 000000000..1e910b8c9 --- /dev/null +++ b/test/enzyme-factorizations-eig/eig.jl @@ -0,0 +1,47 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +function remove_eiggauge_dependence!( + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) + ) + gaugepart = V' * ΔV + for (c, b) in blocks(gaugepart) + Dc = diagview(block(D, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end + +@timedtestset "Enzyme - Factorizations (EIG): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) + atol = default_tol(T) + rtol = default_tol(T) + DV = eig_full(t) + ΔDV = EnzymeTestUtils.rand_tangent(DV) + remove_eiggauge_dependence!(ΔDV[2], DV...) + EnzymeTestUtils.test_reverse(eig_full, Duplicated, (t, Duplicated); output_tangent = ΔDV, atol, rtol) + + #D = eig_vals(t) + #EnzymeTestUtils.test_reverse(eig_vals, Duplicated, (t, Duplicated); atol, rtol) + + V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) + alg = MatrixAlgebraKit.select_algorithm(eig_trunc_no_error, t, nothing; trunc) + DVtrunc = eig_trunc_no_error(t, alg) + ΔDVtrunc = EnzymeTestUtils.rand_tangent(DVtrunc) + remove_eiggauge_dependence!(ΔDVtrunc[2], DVtrunc...) + EnzymeTestUtils.test_reverse(eig_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔDVtrunc, atol, rtol) +end diff --git a/test/enzyme-factorizations-eig/eigh.jl b/test/enzyme-factorizations-eig/eigh.jl new file mode 100644 index 000000000..8cbe425c2 --- /dev/null +++ b/test/enzyme-factorizations-eig/eigh.jl @@ -0,0 +1,49 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float32,) # ComplexF64) + +function remove_eighgauge_dependence!( + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) + ) + gaugepart = project_antihermitian!(V' * ΔV) + for (c, b) in blocks(gaugepart) + Dc = diagview(block(D, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔV, V, gaugepart, -1, 1) + return ΔV +end + +@timedtestset "Enzyme - Factorizations (EIGH): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) + atol = default_tol(T) + rtol = default_tol(T) + th = project_hermitian(t) + DV = eigh_full(th) + ΔDV = EnzymeTestUtils.rand_tangent(DV) + remove_eighgauge_dependence!(ΔDV[2], DV...) + proj_eigh_full(t) = eigh_full(project_hermitian(t)) + EnzymeTestUtils.test_reverse(proj_eigh_full, Duplicated, (th, Duplicated); output_tangent = ΔDV, atol, rtol) + + #D = eigh_vals(th) + #EnzymeTestUtils.test_reverse(eigh_vals ∘ project_hermitian, Duplicated, (th, Duplicated); atol, rtol) + + V_trunc = spacetype(th)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) + alg = MatrixAlgebraKit.select_algorithm(eigh_trunc_no_error, th, nothing; trunc) + DVtrunc = eigh_trunc_no_error(th, alg) + ΔDVtrunc = EnzymeTestUtils.rand_tangent(DVtrunc) + remove_eighgauge_dependence!(ΔDVtrunc[2], DVtrunc...) + proj_eigh(t, alg) = eigh_trunc_no_error(project_hermitian(t), alg) + EnzymeTestUtils.test_reverse(proj_eigh, Duplicated, (th, Duplicated), (alg, Const); output_tangent = ΔDVtrunc, atol, rtol) +end diff --git a/test/enzyme-factorizations-eig/projections.jl b/test/enzyme-factorizations-eig/projections.jl new file mode 100644 index 000000000..1b5b3a28f --- /dev/null +++ b/test/enzyme-factorizations-eig/projections.jl @@ -0,0 +1,18 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Factorizations (PROJECTIONS): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) + atol = default_tol(T) + rtol = default_tol(T) + EnzymeTestUtils.test_reverse(project_hermitian, Duplicated, (t, Duplicated); atol, rtol) + EnzymeTestUtils.test_reverse(project_antihermitian, Duplicated, (t, Duplicated); atol, rtol) + EnzymeTestUtils.test_reverse(project_hermitian!, Duplicated, (t, Duplicated); atol, rtol) + EnzymeTestUtils.test_reverse(project_antihermitian!, Duplicated, (t, Duplicated); atol, rtol) +end diff --git a/test/enzyme-factorizations-lqqr/lq.jl b/test/enzyme-factorizations-lqqr/lq.jl new file mode 100644 index 000000000..9623cdf62 --- /dev/null +++ b/test/enzyme-factorizations-lqqr/lq.jl @@ -0,0 +1,37 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using MatrixAlgebraKit +using MatrixAlgebraKit: remove_lq_gauge_dependence! +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +function remove_lq_null_gauge_dependence!(ΔNᴴ, Q) + for (c, b) in blocks(ΔNᴴ) + Qc = block(Q, c) + ΔNᴴQᴴ = b * Qc' + mul!(b, ΔNᴴQᴴ, Qc) + end + return ΔNᴴ +end + +@timedtestset "Enzyme - Factorizations (LQ): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, A in (randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')) + atol = default_tol(T) + rtol = default_tol(T) + EnzymeTestUtils.test_reverse(lq_compact, Duplicated, (A, Duplicated); atol, rtol) + + # lq_full/lq_null requires being careful with gauges + LQ = lq_full(A) + ΔLQ = EnzymeTestUtils.rand_tangent(LQ) + remove_lq_gauge_dependence!(ΔLQ[1], ΔLQ[2], A, LQ[1], LQ[2]) + EnzymeTestUtils.test_reverse(lq_full, Duplicated, (A, Duplicated); output_tangent = ΔLQ, atol, rtol) + + Nᴴ = lq_null(A) + Q = lq_compact(A)[2] + ΔNᴴ = EnzymeTestUtils.rand_tangent(Nᴴ) + remove_lq_null_gauge_dependence!(ΔNᴴ, Q) + EnzymeTestUtils.test_reverse(lq_null, Duplicated, (A, Duplicated); output_tangent = ΔNᴴ, atol, rtol) +end diff --git a/test/enzyme-factorizations-lqqr/qr.jl b/test/enzyme-factorizations-lqqr/qr.jl new file mode 100644 index 000000000..880c332c1 --- /dev/null +++ b/test/enzyme-factorizations-lqqr/qr.jl @@ -0,0 +1,25 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using MatrixAlgebraKit: remove_qr_gauge_dependence! +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Factorizations (QR): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, A in (randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]), randn(T, V[1] ⊗ V[2] ⊗ V[3] ← (V[4] ⊗ V[5])')) + atol = default_tol(T) + rtol = default_tol(T) + + EnzymeTestUtils.test_reverse(qr_compact, Duplicated, (A, Duplicated); atol, rtol) + + # qr_full/qr_null requires being careful with gauges + QR = qr_full(A) + ΔQR = EnzymeTestUtils.rand_tangent(QR) + remove_qr_gauge_dependence!(ΔQR[1], ΔQR[2], A, QR[1], QR[2]) + EnzymeTestUtils.test_reverse(qr_full, Duplicated, (A, Duplicated); output_tangent = ΔQR, atol, rtol) + #EnzymeTestUtils.test_reverse(qr_null, Duplicated, (A, Duplicated); atol, rtol) +end diff --git a/test/enzyme-factorizations-svd/svd.jl b/test/enzyme-factorizations-svd/svd.jl new file mode 100644 index 000000000..5bcf72cb9 --- /dev/null +++ b/test/enzyme-factorizations-svd/svd.jl @@ -0,0 +1,36 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using MatrixAlgebraKit +using MatrixAlgebraKit: remove_svd_gauge_dependence! +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Factorizations (SVD): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')) + atol = default_tol(T) + rtol = default_tol(T) + + #S = svd_vals(t) + #EnzymeTestUtils.test_reverse(svd_vals, Duplicated, (t, Duplicated); atol, rtol) + + USVᴴ = svd_compact(t) + ΔUSVᴴ = EnzymeTestUtils.rand_tangent.(USVᴴ) + remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + EnzymeTestUtils.test_reverse(svd_compact, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol) + + USVᴴ = svd_full(t) + ΔUSVᴴ = (TensorMap(randn!(similar(USVᴴ[1].data)), space(USVᴴ[1])), TensorMap(randn!(similar(USVᴴ[2].data)), space(USVᴴ[2])), TensorMap(randn!(similar(USVᴴ[3].data)), space(USVᴴ[3]))) + remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + EnzymeTestUtils.test_reverse(svd_full, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol) + + V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) + alg = MatrixAlgebraKit.select_algorithm(svd_trunc_no_error, t, nothing; trunc) + USVᴴtrunc = svd_trunc_no_error(t, alg) + ΔUSVᴴtrunc = EnzymeTestUtils.rand_tangent(USVᴴtrunc) + remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], USVᴴtrunc...) + EnzymeTestUtils.test_reverse(svd_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔUSVᴴtrunc, atol, rtol) +end