Skip to content
Open
12 changes: 11 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, Abstract
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
import MatrixAlgebraKit: heevj!, heevd!, geev!
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback!, eig_pullback!
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback!, eig_pullback!, svd_pushforward!
using CUDA, CUDA.cuBLAS
using CUDA: i32
using LinearAlgebra
Expand Down Expand Up @@ -213,4 +213,14 @@ function eig_pullback!(ΔA::AnyCuMatrix, A, DV, ΔDV, ind::AnyCuVector; kwargs..
return eig_pullback!(ΔA, A, DV, ΔDV, collect(ind); kwargs...)
end

# GPUArrays lacks methods for the various views taken of a `Diagonal` ΔA inside the
# generic pushforward, so the diagonal tangent is expanded with `diagm` first and the
# call is then forwarded to the generic `MatrixAlgebraKit.svd_pushforward!`.
function svd_pushforward!(
        ΔA::Diagonal{T, <:CuVector{T}}, A, USVᴴ, ΔUSVᴴ, ind = Colon();
        rank_atol::Real = MatrixAlgebraKit.default_pullback_rank_atol(USVᴴ[2]),
        degeneracy_atol::Real = MatrixAlgebraKit.default_pullback_rank_atol(USVᴴ[2])
    ) where {T}
    ΔA_expanded = diagm(diagview(ΔA))
    return MatrixAlgebraKit.svd_pushforward!(
        ΔA_expanded, A, USVᴴ, ΔUSVᴴ, ind;
        rank_atol = rank_atol, degeneracy_atol = degeneracy_atol
    )
end

end
56 changes: 56 additions & 0 deletions ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback!
using MatrixAlgebraKit: eig_pushforward!, eigh_pushforward!, eig_vals_pushforward!, eigh_vals_pushforward!
using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback!
using MatrixAlgebraKit: svd_pushforward!, svd_vals_pushforward!
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
using Enzyme
Expand Down Expand Up @@ -264,6 +265,34 @@ for f in (:svd_compact!, :svd_full!)
!isa(USVᴴ, Const) && make_zero!(USVᴴ.dval)
return (nothing, nothing, nothing)
end
# Forward-mode Enzyme rule (width 1) for the in-place SVD `$f(A, USVᴴ, alg)`.
#
# The primal is computed first, overwriting `USVᴴ.val`. The tangent of the factors
# is then written into `USVᴴ.dval` by `svd_pushforward!`, and the input tangent
# `A.dval` is zeroed afterwards.
#
# Returns, depending on what `config` requests: the annotation `USVᴴ` (primal and
# shadow), `USVᴴ.val` (primal only), `USVᴴ.dval` (shadow only) or `nothing`.
function EnzymeRules.forward(
        config::EnzymeRules.FwdConfigWidth{1},
        func::Const{typeof($f)},
        ::Type{RT},
        A::Annotation{TA},
        USVᴴ::Annotation,
        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
    ) where {RT, TA}
    $f(A.val, USVᴴ.val, alg.val)
    # `Const` annotations carry no `dval`, so only touch shadows of active arguments.
    if !isa(USVᴴ, Const)
        if isa(A, Const)
            # A constant input cannot perturb the factors: the output tangent is zero.
            make_zero!(USVᴴ.dval)
        else
            # Clear any stale tangent of S before the pushforward fills it in.
            # `f` is the loop *symbol*, so it has to be compared against a symbol
            # (comparing against the function object is always `false`).
            if $(f === :svd_compact!)
                make_zero!(USVᴴ.dval[2].diag)
            else
                make_zero!(USVᴴ.dval[2])
            end
            svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval)
        end
    end
    !isa(A, Const) && make_zero!(A.dval)
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        return USVᴴ
    elseif EnzymeRules.needs_primal(config)
        return USVᴴ.val
    elseif EnzymeRules.needs_shadow(config)
        return USVᴴ.dval
    else
        return nothing
    end
end
end
end

Expand Down Expand Up @@ -502,5 +531,32 @@ function EnzymeRules.reverse(
!isa(S, Const) && !A_is_arg && make_zero!(S.dval)
return (nothing, nothing, nothing)
end
# Forward-mode Enzyme rule (width 1) for `svd_vals!(A, S, alg)`.
#
# The singular values are obtained from `svd_compact!` (which works in place on
# `A.val`) and copied into `S.val`; their tangent is written into `S.dval` by
# `svd_vals_pushforward!`.
#
# Returns, depending on what `config` requests: the annotation `S` (primal and
# shadow), `S.val` (primal only), `S.dval` (shadow only) or `nothing`.
function EnzymeRules.forward(
        config::EnzymeRules.FwdConfigWidth{1},
        func::Const{typeof(svd_vals!)},
        ::Type{RT},
        A::Annotation{TA},
        S::Annotation,
        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
    ) where {RT, TA}
    # `Const` annotations carry no `dval`; every shadow access below is guarded.
    A_active = !isa(A, Const)
    S_active = !isa(S, Const)
    # True when the output tangent aliases the diagonal of the (Diagonal) input tangent.
    A_is_arg = A_active && S_active && TA <: Diagonal && diagview(A.dval) === S.dval
    U, S_, Vᴴ = svd_compact!(A.val, alg.val)
    if S_active
        if A_active
            # When aliased, accumulate into a scratch buffer so the pushforward does not
            # overwrite the input tangent it is still reading from.
            ΔS = A_is_arg ? make_zero(S.dval) : S.dval
            svd_vals_pushforward!(A.dval, A.val, (U, Diagonal(diagview(S_)), Vᴴ), ΔS)
            A_is_arg && (S.dval .= ΔS)
        else
            # A constant input cannot perturb the singular values.
            make_zero!(S.dval)
        end
    end
    A_active && !A_is_arg && make_zero!(A.dval)
    copyto!(S.val, diagview(S_))
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        return S
    elseif EnzymeRules.needs_primal(config)
        return S.val
    elseif EnzymeRules.needs_shadow(config)
        return S.dval
    else
        return nothing
    end
end

end
52 changes: 48 additions & 4 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pul
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
using MatrixAlgebraKit: svd_pushforward!, svd_vals_pushforward!
using MatrixAlgebraKit: TruncatedAlgorithm
using LinearAlgebra

Expand Down Expand Up @@ -538,7 +539,7 @@ for (f!, f) in (
(:svd_compact!, :svd_compact),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
A, dA = arrayify(A_dA)
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
Expand All @@ -562,7 +563,18 @@ for (f!, f) in (
end
return USVᴴ_dUSVᴴ, svd_adjoint
end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
# Forward rule for the in-place `$f!(A, USVᴴ, alg)`: run the primal into the supplied
# factors, then push the tangent of `A` forward into the tangents of `U`, `S` and `Vᴴ`.
# The incoming `Dual` for the factors is returned.
function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    factors = Mooncake.primal(USVᴴ_dUSVᴴ)
    factor_tangents = Mooncake.tangent(USVᴴ_dUSVᴴ)
    U, dU = arrayify(factors[1], factor_tangents[1])
    S, dS = arrayify(factors[2], factor_tangents[2])
    Vᴴ, dVᴴ = arrayify(factors[3], factor_tangents[3])
    alg = Mooncake.primal(alg_dalg)
    $f!(A, factors, alg)
    svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
    return USVᴴ_dUSVᴴ
end
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
A, dA = arrayify(A_dA)
USVᴴ = $f(A, Mooncake.primal(alg_dalg))
Expand All @@ -585,10 +597,23 @@ for (f!, f) in (
end
return USVᴴ_codual, svd_adjoint
end
# Forward rule for the out-of-place `$f(A, alg)`: compute the factors, pair them with a
# zero tangent, and let `svd_pushforward!` fill that tangent in from the tangent of `A`.
# The freshly built `Dual` is returned.
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    factors = $f(A, Mooncake.primal(alg_dalg))
    USVᴴ_dual = Dual(factors, Mooncake.zero_tangent(factors))
    U_raw, S_raw, Vᴴ_raw = Mooncake.primal(USVᴴ_dual)
    tU, tS, tVᴴ = Mooncake.tangent(USVᴴ_dual)
    U, dU = arrayify(U_raw, tU)
    S, dS = arrayify(S_raw, tS)
    Vᴴ, dVᴴ = arrayify(Vᴴ_raw, tVᴴ)
    svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
    return USVᴴ_dual
end
end
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand All @@ -604,8 +629,17 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
end
return S_dS, svd_vals_adjoint
end
# Forward rule for `svd_vals!(A, S, alg)`: the singular values are taken from a compact
# SVD of `A` and copied into `S`; their tangent is written into the tangent of `S`.
# The incoming `Dual` for `S` is returned.
function Mooncake.frule!!(::Dual{typeof(svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    S, dS = arrayify(S_dS)
    # primal computation
    factors = svd_compact(A, Mooncake.primal(alg_dalg))
    copy!(S, diagview(factors[2]))
    # tangent of the singular values
    svd_vals_pushforward!(dA, A, factors, dS)
    return S_dS
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand All @@ -624,6 +658,16 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
end
return S_codual, svd_vals_adjoint
end
# Forward rule for `svd_vals(A, alg)`: the singular values are the diagonal of the S
# factor of a compact SVD; they are paired with a zero tangent that
# `svd_vals_pushforward!` then fills in. The freshly built `Dual` is returned.
function Mooncake.frule!!(::Dual{typeof(svd_vals)}, A_dA::Dual, alg_dalg::Dual)
    A, dA = arrayify(A_dA)
    # primal computation
    factors = svd_compact(A, Mooncake.primal(alg_dalg))
    svals = diagview(factors[2])
    S_dual = Dual(svals, Mooncake.zero_tangent(svals))
    # only the tangent array is needed from here on
    _, dS = arrayify(S_dual)
    svd_vals_pushforward!(dA, A, factors, dS)
    return S_dual
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
Expand Down
3 changes: 2 additions & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
using LinearAlgebra: sylvester, lu!, diagm
using LinearAlgebra: isposdef, issymmetric
using LinearAlgebra: Diagonal, Hermitian, diag, diagind, isdiag
using LinearAlgebra: UpperTriangular, LowerTriangular
using LinearAlgebra: UpperTriangular, LowerTriangular, UniformScaling
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt

export isisometric, isunitary, ishermitian, isantihermitian
Expand Down Expand Up @@ -132,6 +132,7 @@ include("pullbacks/polar.jl")
include("pushforwards/polar.jl")
include("pushforwards/eig.jl")
include("pushforwards/eigh.jl")
include("pushforwards/svd.jl")

include("precompile.jl")

Expand Down
7 changes: 6 additions & 1 deletion src/pushforwards/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ function eig_pushforward!(
end
if !iszerotangent(ΔV)
∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol)
mul!(ΔV, V, ∂K, 1, 0)
mul!(ΔV, V, ∂K)
if eltype(V) <: Complex # fix gauge for `gaugefix!` compatibility
_, I = findmax(abs, V; dims = 1)
infinitesimal_phases = imag.(ΔV[I] ./ V[I])
ΔV .-= im .* V .* infinitesimal_phases
end
end
return ΔDV
end
Expand Down
5 changes: 5 additions & 0 deletions src/pushforwards/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ function eigh_pushforward!(
if !iszerotangent(ΔV)
∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol)
ΔV = mul!(ΔV, V, ∂K)
if eltype(V) <: Complex # fix gauge for `gaugefix!` compatibility
_, I = findmax(abs, V; dims = 1)
infinitesimal_phases = imag.(ΔV[I] ./ V[I])
ΔV .-= im .* V .* infinitesimal_phases
end
end
return (ΔD, ΔV)
end
Expand Down
77 changes: 77 additions & 0 deletions src/pushforwards/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
    svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, [ind]; rank_atol, degeneracy_atol, kwargs...)

Push the tangent `ΔA` forward through the singular value decomposition
`USVᴴ = (U, S, Vᴴ)` and store the result in place in `ΔUSVᴴ = (ΔU, ΔS, ΔVᴴ)`.
Components of `ΔUSVᴴ` for which `iszerotangent` holds are skipped.

# Keywords
- `rank_atol`: tolerance handed to `svd_rank` to determine the number `r` of singular
  values that take part in the computation.
- `degeneracy_atol`: tolerance handed to `inv_safe` for the differences of singular
  values.

Both defaults are derived from the singular values `USVᴴ[2]` rather than from `A`:
callers that wrap the in-place decompositions pass an `A` that the decomposition has
already worked on in place.

# Returns
The tuple `(ΔU, ΔS, ΔVᴴ)`.

# Throws
`DimensionMismatch` if `size(ΔA)` differs from `(size(U, 1), size(Vᴴ, 2))`.
"""
function svd_pushforward!(
        ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon();
        rank_atol = default_pullback_rank_atol(USVᴴ[2]),
        degeneracy_atol = default_pullback_rank_atol(USVᴴ[2]),
        kwargs...
    )
    U, Smat, Vᴴ = USVᴴ
    m, n = size(U, 1), size(Vᴴ, 2)
    (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
    S = diagview(Smat)
    ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
    r = svd_rank(S; rank_atol)

    U₁ = view(U, :, 1:r)
    S₁ = view(S, 1:r)
    V₁ᴴ = view(Vᴴ, 1:r, :)

    # compact region
    V₁ = adjoint(V₁ᴴ)
    ΔAV₁ = ΔA * V₁
    UᴴΔAV₁ = U₁' * ΔAV₁
    if !iszerotangent(ΔS)
        # tangent of the leading singular values: real part of diag(U₁ᴴ ΔA V₁)
        ΔS₁ = view(diagview(ΔS), 1:r)
        ΔS₁ .= real.(diagview(UᴴΔAV₁))
    end
    if !iszerotangent(ΔU) || !iszerotangent(ΔVᴴ)
        # NOTE(review): only the sum and the difference of these two terms are used
        # below; they could be formed in place by a fused kernel to save two
        # allocations, which is not expected to matter much in practice.
        hUᴴΔAV₁ = inv_safe.(transpose(S₁) .- S₁, degeneracy_atol) .* project_hermitian(UᴴΔAV₁)
        aUᴴΔAV₁ = inv_safe.(transpose(S₁) .+ S₁) .* project_antihermitian(UᴴΔAV₁)
        if !iszerotangent(ΔU)
            ΔU₁ = view(ΔU, :, 1:r)
            K̇ = hUᴴΔAV₁ + aUᴴΔAV₁
            mul!(ΔU₁, U₁, K̇)
            if m > r
                # ΔAV₁ is reused as scratch space: ΔAV₁ - U₁ * (U₁ᴴ ΔA V₁)
                ΔAV₁ = mul!(ΔAV₁, U₁, UᴴΔAV₁, -1, 1)
                ΔU₁ .+= ΔAV₁ ./ transpose(S₁)
            end
            if size(U, 2) > r # these columns of U are undetermined, but U' * U̇ should be antihermitian
                U₂ = view(U, :, (r + 1):size(U, 2))
                ΔU₁ᴴU₂ = ΔU₁' * U₂
                ΔU₂ = view(ΔU, :, (r + 1):size(U, 2))
                mul!(ΔU₂, U₁, ΔU₁ᴴU₂, -1, 0)
            end
        end
        if !iszerotangent(ΔVᴴ)
            ΔV₁ᴴ = view(ΔVᴴ, 1:r, :)
            Ṁ = hUᴴΔAV₁ - aUᴴΔAV₁
            mul!(ΔV₁ᴴ, Ṁ', V₁ᴴ)
            if n > r
                # U₁ᴴ ΔA - (U₁ᴴ ΔA V₁) * V₁ᴴ
                UᴴΔA₁ = U₁' * ΔA
                UᴴΔA₁ = mul!(UᴴΔA₁, UᴴΔAV₁, V₁ᴴ, -1, 1)
                ΔV₁ᴴ .+= S₁ .\ UᴴΔA₁
            end
            if size(Vᴴ, 1) > r # these rows of Vᴴ are undetermined, but V * V̇ should be antihermitian
                V₂ᴴ = view(Vᴴ, (r + 1):size(Vᴴ, 1), :)
                V₂ᴴΔV₁ = V₂ᴴ * ΔV₁ᴴ'
                ΔV₂ᴴ = view(ΔVᴴ, (r + 1):size(Vᴴ, 1), :)
                mul!(ΔV₂ᴴ, V₂ᴴΔV₁, V₁ᴴ, -1, 0)
            end
        end
        if eltype(U) <: Complex && !iszerotangent(ΔU) && !iszerotangent(ΔVᴴ) # fix gauge for `gaugefix!` compatibility
            # per column of U₁, locate the entry of largest magnitude and remove the
            # imaginary part of ΔU₁ ./ U₁ there, compensating in ΔV₁ᴴ
            _, I = findmax(abs, U₁; dims = 1)
            infinitesimal_phases = imag.(ΔU₁[I] ./ U₁[I])
            ΔU₁ .-= im .* U₁ .* infinitesimal_phases
            ΔV₁ᴴ .+= im .* transpose(infinitesimal_phases) .* V₁ᴴ
        end
    end
    return (ΔU, ΔS, ΔVᴴ)
end

# TODO
#=function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...)
end=#

"""
    svd_vals_pushforward!(ΔA, A, USVᴴ, ΔS, [ind]; rank_atol, degeneracy_atol)

Push the tangent `ΔA` forward to the tangent `ΔS` of the singular values only.

This forwards to `svd_pushforward!` with `nothing` in place of the tangents of `U` and
`Vᴴ` and with `ΔS` wrapped by `diagonal`, and returns whatever that call returns.
"""
function svd_vals_pushforward!(
        ΔA, A, USVᴴ, ΔS, ind = Colon();
        rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
        degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
    )
    return svd_pushforward!(
        ΔA, A, USVᴴ, (nothing, diagonal(ΔS), nothing), ind;
        rank_atol = rank_atol, degeneracy_atol = degeneracy_atol
    )
end
4 changes: 2 additions & 2 deletions test/enzyme/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite
Expand All @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
if !is_buildkite
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
m == n && TestSuite.test_enzyme_svd(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
Loading
Loading