diff --git a/Project.toml b/Project.toml index 965514d..b0e318a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.9.5" +version = "0.9.6" authors = ["ITensor developers and contributors"] [workspace] diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index 8d4f085..2bbff57 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -84,6 +84,21 @@ function bipermutedimsopadd!( return dest end +_permuteddims_perm(::PermutedDimsArray{<:Any, <:Any, perm}) where {perm} = perm + +function bipermutedimsopadd!( + dest::AbstractArray, op, src::PermutedDimsArray, + perm_codomain, perm_domain, + α::Number, β::Number + ) + w = _permuteddims_perm(src) + return bipermutedimsopadd!( + dest, op, parent(src), + map(j -> w[j], perm_codomain), map(j -> w[j], perm_domain), + α, β + ) +end + # ---------------------------------------------------------------------------- # # permutedimsopadd! — flat-permutation interface # ---------------------------------------------------------------------------- # diff --git a/test/test_permutedimsadd.jl b/test/test_permutedimsadd.jl index c098129..38fd6d8 100644 --- a/test/test_permutedimsadd.jl +++ b/test/test_permutedimsadd.jl @@ -47,6 +47,26 @@ using Test: @test, @testset @test b′ ≈ β * b + α * permutedims(a, perm) end end + @testset "bipermutedimsopadd! unwraps PermutedDimsArray src (arraytype=$arrayt)" for arrayt in + ( + Array, + JLArray, + ) + dev = adapt(arrayt) + parent = dev(randn(2, 3, 4, 5)) + w = (3, 1, 4, 2) + src = PermutedDimsArray(parent, w) + for (pc, pd) in (((1, 2, 3, 4), ()), ((2, 4), (1, 3)), ((3, 1), (2, 4))) + perm = (pc..., pd...) + ref = permutedims(permutedims(parent, w), perm) + for β in (0, 3) + dest = dev(randn(size(ref)...)) + dest′ = copy(dest) + bipermutedimsopadd!(dest′, identity, src, pc, pd, 2, β) + @test dest′ ≈ β * dest + 2 * ref + end + end + end @testset "bipermutedimsopadd! 0-dim with β=0 must not read dest (eltype=$T)" for T in ( Float64,