From 43e1bf869e989dc8a94248116a075b216cc5bc1d Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 17 Jun 2026 11:52:08 -0400 Subject: [PATCH 1/2] Unwrap PermutedDimsArray sources in bipermutedimsopadd! Adds a bipermutedimsopadd! method for PermutedDimsArray sources that rewrites each entry of the bipartitioned permutation through the wrapper's own permutation and recurses on parent(src), so the underlying array dispatches its own bipermutedimsopadd! rather than the generic body falling through on the wrapper. Also adds permuteddims(a, perm), defaulting to PermutedDimsArray and overloadable downstream, replacing the FunctionImplementations.permuteddims used internally. Co-Authored-By: Claude Opus 4.8 --- Project.toml | 2 +- src/permutedimsadd.jl | 24 +++++++++++++++++++++++- test/test_permutedimsadd.jl | 20 ++++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 965514d..b0e318a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.9.5" +version = "0.9.6" authors = ["ITensor developers and contributors"] [workspace] diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index 8d4f085..7d927d7 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -1,7 +1,14 @@ -using FunctionImplementations: permuteddims using Strided: Strided using StridedViews: StridedViews as SV +""" + permuteddims(a, perm) + +Lazily permute the dimensions of `a` by `perm`. Defaults to `PermutedDimsArray`; +overload for array types with a more efficient lazy permuted representation. +""" +permuteddims(a::AbstractArray, perm) = PermutedDimsArray(a, perm) + # Specify if an array is on CPU. This is helpful for backends that don't support # operations on GPU, such as Strided.jl. iscpu(::AbstractArray) = true @@ -84,6 +91,21 @@ function bipermutedimsopadd!( return dest end +_permuteddims_perm(::PermutedDimsArray{<:Any, <:Any, perm}) where {perm} = perm + +function bipermutedimsopadd!( + dest::AbstractArray, op, src::PermutedDimsArray, + perm_codomain, perm_domain, + α::Number, β::Number + ) + w = _permuteddims_perm(src) + return bipermutedimsopadd!( + dest, op, parent(src), + map(j -> w[j], perm_codomain), map(j -> w[j], perm_domain), + α, β + ) +end + # ---------------------------------------------------------------------------- # # permutedimsopadd! — flat-permutation interface # ---------------------------------------------------------------------------- # diff --git a/test/test_permutedimsadd.jl b/test/test_permutedimsadd.jl index c098129..38fd6d8 100644 --- a/test/test_permutedimsadd.jl +++ b/test/test_permutedimsadd.jl @@ -47,6 +47,26 @@ using Test: @test, @testset @test b′ ≈ β * b + α * permutedims(a, perm) end end + @testset "bipermutedimsopadd! unwraps PermutedDimsArray src (arraytype=$arrayt)" for arrayt in + ( + Array, + JLArray, + ) + dev = adapt(arrayt) + parent = dev(randn(2, 3, 4, 5)) + w = (3, 1, 4, 2) + src = PermutedDimsArray(parent, w) + for (pc, pd) in (((1, 2, 3, 4), ()), ((2, 4), (1, 3)), ((3, 1), (2, 4))) + perm = (pc..., pd...) + ref = permutedims(permutedims(parent, w), perm) + for β in (0, 3) + dest = dev(randn(size(ref)...)) + dest′ = copy(dest) + bipermutedimsopadd!(dest′, identity, src, pc, pd, 2, β) + @test dest′ ≈ β * dest + 2 * ref + end + end + end @testset "bipermutedimsopadd! 0-dim with β=0 must not read dest (eltype=$T)" for T in ( Float64, From da0f4a68f50a87cd9cb8f08f5f20abee5a391f7d Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 17 Jun 2026 15:10:02 -0400 Subject: [PATCH 2/2] Keep using FunctionImplementations.permuteddims in the generic body permuteddims is a FunctionImplementations verb that backends overload to supply a more efficient lazy permute (KroneckerArrays permutes each Kronecker factor and recombines). Owning it in TensorAlgebra and calling that owned version in the bipermutedimsopadd! body bypassed those overloads, so a KroneckerArray reached the generic body wrapped in a plain PermutedDimsArray instead of its native permute. Import it from FunctionImplementations again so the body respects the overloads. The PermutedDimsArray unwrap method is unaffected. Co-Authored-By: Claude Opus 4.8 --- src/permutedimsadd.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index 7d927d7..2bbff57 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -1,14 +1,7 @@ +using FunctionImplementations: permuteddims using Strided: Strided using StridedViews: StridedViews as SV -""" - permuteddims(a, perm) - -Lazily permute the dimensions of `a` by `perm`. Defaults to `PermutedDimsArray`; -overload for array types with a more efficient lazy permuted representation. -""" -permuteddims(a::AbstractArray, perm) = PermutedDimsArray(a, perm) - # Specify if an array is on CPU. This is helpful for backends that don't support # operations on GPU, such as Strided.jl. iscpu(::AbstractArray) = true