Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "StridedViews"
uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
version = "0.5.1"
version = "0.5.2"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Expand All @@ -28,6 +29,7 @@ CUDACore = "6"
JLArrays = "0.3.1"
LinearAlgebra = "1"
Metal = "1"
PrecompileTools = "1.1"
PtrArrays = "1.2.0"
julia = "1.10"

Expand Down
1 change: 1 addition & 0 deletions src/StridedViews.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ export StridedView, sreshape, sview, isstrided

include("auxiliary.jl")
include("stridedview.jl")
include("precompile.jl")

end
42 changes: 42 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Precompilation workload
# ------------------------
# Cache the core `StridedView` specializations for the BLAS element types over a small
# range of dimensionalities and the op-wrappers a `StridedView` can carry. These are the
# specializations that downstream packages (e.g. TensorOperations / Strided) hit on their
# first call, so warming them here removes that first-call latency.
using PrecompileTools: @setup_workload, @compile_workload

@setup_workload begin
@compile_workload begin
for T in (Float32, Float64, ComplexF32, ComplexF64)
# construction + property queries + core ops for ndims 1:4
for N in 1:6
A = Array{T, N}(undef, ntuple(_ -> 2, N))
sv = StridedView(A)
size(sv)
strides(sv)
offset(sv)
csv = conj(sv)
# permute through the identity permutation (exercises the per-N path)
permutedims(sv, ntuple(identity, N))
permutedims(csv, ntuple(identity, N))
# reshape to a flat vector and back (also exercises sview on the flat view)
flat = sreshape(sv, (length(sv),))
sview(flat, 1:length(sv))
getindex(sv, ntuple(_ -> 1, N)...)
flat = sreshape(csv, (length(sv),))
sview(flat, 1:length(csv))
getindex(csv, ntuple(_ -> 1, N)...)
end
# 2D matrix wrappers: transpose / adjoint
M = Array{T, 2}(undef, 2, 2)
svM = StridedView(M)
transpose(svM)
adjoint(svM)
# a representative 4D slice view (the SliceIndex `getindex` construction path)
A4 = Array{T, 4}(undef, 2, 2, 2, 2)
sv4 = StridedView(A4)
getindex(sv4, :, 1:2, 1, 1:2)
end
end
end
15 changes: 7 additions & 8 deletions src/stridedview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ end
return a
end

# Indexing with slice indices to create a new view
@inline function Base.getindex(a::StridedView{T, N}, I::Vararg{SliceIndex, N}) where {T, N}
# Indexing with slice indices to create a new view.
function Base.getindex(a::StridedView{T, N}, I::Vararg{SliceIndex, N}) where {T, N}
return StridedView{T}(
a.parent,
_computeviewsize(a.size, I),
Expand Down Expand Up @@ -179,7 +179,7 @@ function Base.conj(a::StridedView)
return StridedView{T}(a.parent, a.size, a.strides, a.offset, newop)
end

@inline function Base.permutedims(a::StridedView{T, N}, p) where {T, N}
function Base.permutedims(a::StridedView{T, N}, p) where {T, N}
_isperm(N, p) || throw(ArgumentError("Invalid permutation of length $N: $p"))
newsize = ntuple(n -> size(a, p[n]), Val(N))
newstrides = ntuple(n -> stride(a, p[n]), Val(N))
Expand Down Expand Up @@ -228,10 +228,10 @@ sview(a::StridedView, I::SliceIndex) = getindex(sreshape(a, (length(a),)), I)
Base.view(a::StridedView{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} = getindex(a, I...)

# `sview` can be used as a constructor when acting on `AbstractArray` objects
@inline function sview(a::AbstractArray{<:Any, N}, I::Vararg{SliceIndex, N}) where {N}
function sview(a::AbstractArray{<:Any, N}, I::Vararg{SliceIndex, N}) where {N}
return getindex(StridedView(a), I...)
end
@inline function sview(a::AbstractArray, I::SliceIndex)
function sview(a::AbstractArray, I::SliceIndex)
Comment thread
lkdvos marked this conversation as resolved.
return getindex(sreshape(StridedView(a), (length(a),)), I)
end

Expand All @@ -248,10 +248,9 @@ function Base.show(io::IO, e::ReshapeException)
return print(io, msg)
end

# we cannot use Base.reshape, as this also accepts indices that might not preserve
# stridedness
# we cannot use Base.reshape, as this also accepts indices that might not preserve stridedness
sreshape(a, args::Vararg{Int}) = sreshape(a, args)
@inline function sreshape(a::StridedView{T}, newsize::Dims) where {T}
function sreshape(a::StridedView{T}, newsize::Dims) where {T}
if any(isequal(0), newsize)
any(isequal(0), size(a)) || throw(DimensionMismatch())
newstrides = one.(newsize)
Expand Down
3 changes: 0 additions & 3 deletions test/jet/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,5 @@ name = "StridedViewsJETTest"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"

[sources]
StridedViews = {path = "../.."}

[compat]
JET = "0.9, 0.10, 0.11"
1 change: 1 addition & 0 deletions test/jet/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import Pkg
try
Pkg.activate(joinpath(@__DIR__); io = devnull)
Pkg.develop(Pkg.PackageSpec(path = joinpath(@__DIR__, "..", "..")); io = devnull)
Pkg.instantiate(; io = devnull)
@eval import JET
JET.test_package(StridedViews; target_modules = (StridedViews,))
Expand Down
Loading