diff --git a/Project.toml b/Project.toml index b2aeba2..68fd3cf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "StridedViews" uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" authors = ["Lukas Devos ", "Jutho Haegeman "] -version = "0.5.1" +version = "0.5.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -28,6 +29,7 @@ CUDACore = "6" JLArrays = "0.3.1" LinearAlgebra = "1" Metal = "1" +PrecompileTools = "1.1" PtrArrays = "1.2.0" julia = "1.10" diff --git a/src/StridedViews.jl b/src/StridedViews.jl index cfa01a9..dc2e1c2 100644 --- a/src/StridedViews.jl +++ b/src/StridedViews.jl @@ -9,5 +9,6 @@ export StridedView, sreshape, sview, isstrided include("auxiliary.jl") include("stridedview.jl") +include("precompile.jl") end diff --git a/src/precompile.jl b/src/precompile.jl new file mode 100644 index 0000000..2aae20c --- /dev/null +++ b/src/precompile.jl @@ -0,0 +1,42 @@ +# Precompilation workload +# ------------------------ +# Cache the core `StridedView` specializations for the BLAS element types over a small +# range of dimensionalities and the op-wrappers a `StridedView` can carry. These are the +# specializations that downstream packages (e.g. TensorOperations / Strided) hit on their +# first call, so warming them here removes that first-call latency. 
+using PrecompileTools: @setup_workload, @compile_workload + +@setup_workload begin + @compile_workload begin + for T in (Float32, Float64, ComplexF32, ComplexF64) + # construction + property queries + core ops for ndims 1:6 + for N in 1:6 + A = Array{T, N}(undef, ntuple(_ -> 2, N)) + sv = StridedView(A) + size(sv) + strides(sv) + offset(sv) + csv = conj(sv) + # permute through the identity permutation (exercises the per-N path) + permutedims(sv, ntuple(identity, N)) + permutedims(csv, ntuple(identity, N)) + # flatten with sreshape, then take an sview of the flat view (for both sv and csv) + flat = sreshape(sv, (length(sv),)) + sview(flat, 1:length(sv)) + getindex(sv, ntuple(_ -> 1, N)...) + flat = sreshape(csv, (length(sv),)) + sview(flat, 1:length(csv)) + getindex(csv, ntuple(_ -> 1, N)...) + end + # 2D matrix wrappers: transpose / adjoint + M = Array{T, 2}(undef, 2, 2) + svM = StridedView(M) + transpose(svM) + adjoint(svM) + # a representative 4D slice view (the SliceIndex `getindex` construction path) + A4 = Array{T, 4}(undef, 2, 2, 2, 2) + sv4 = StridedView(A4) + getindex(sv4, :, 1:2, 1, 1:2) + end + end +end diff --git a/src/stridedview.jl b/src/stridedview.jl index b745574..b7e93c0 100644 --- a/src/stridedview.jl +++ b/src/stridedview.jl @@ -138,8 +138,8 @@ end return a end -# Indexing with slice indices to create a new view -@inline function Base.getindex(a::StridedView{T, N}, I::Vararg{SliceIndex, N}) where {T, N} +# Indexing with slice indices to create a new view. 
+function Base.getindex(a::StridedView{T, N}, I::Vararg{SliceIndex, N}) where {T, N} return StridedView{T}( a.parent, _computeviewsize(a.size, I), @@ -179,7 +179,7 @@ function Base.conj(a::StridedView) return StridedView{T}(a.parent, a.size, a.strides, a.offset, newop) end -@inline function Base.permutedims(a::StridedView{T, N}, p) where {T, N} +function Base.permutedims(a::StridedView{T, N}, p) where {T, N} _isperm(N, p) || throw(ArgumentError("Invalid permutation of length $N: $p")) newsize = ntuple(n -> size(a, p[n]), Val(N)) newstrides = ntuple(n -> stride(a, p[n]), Val(N)) @@ -228,10 +228,10 @@ sview(a::StridedView, I::SliceIndex) = getindex(sreshape(a, (length(a),)), I) Base.view(a::StridedView{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} = getindex(a, I...) # `sview` can be used as a constructor when acting on `AbstractArray` objects -@inline function sview(a::AbstractArray{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} +function sview(a::AbstractArray{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} return getindex(StridedView(a), I...) 
end -@inline function sview(a::AbstractArray, I::SliceIndex) +function sview(a::AbstractArray, I::SliceIndex) return getindex(sreshape(StridedView(a), (length(a),)), I) end @@ -248,10 +248,9 @@ function Base.show(io::IO, e::ReshapeException) return print(io, msg) end -# we cannot use Base.reshape, as this also accepts indices that might not preserve -# stridedness +# we cannot use Base.reshape, as this also accepts indices that might not preserve stridedness sreshape(a, args::Vararg{Int}) = sreshape(a, args) -@inline function sreshape(a::StridedView{T}, newsize::Dims) where {T} +function sreshape(a::StridedView{T}, newsize::Dims) where {T} if any(isequal(0), newsize) any(isequal(0), size(a)) || throw(DimensionMismatch()) newstrides = one.(newsize) diff --git a/test/jet/Project.toml b/test/jet/Project.toml index 0c19e5a..68090c6 100644 --- a/test/jet/Project.toml +++ b/test/jet/Project.toml @@ -4,8 +4,5 @@ name = "StridedViewsJETTest" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" -[sources] -StridedViews = {path = "../.."} - [compat] JET = "0.9, 0.10, 0.11" diff --git a/test/jet/jet.jl b/test/jet/jet.jl index 2477de8..4955531 100644 --- a/test/jet/jet.jl +++ b/test/jet/jet.jl @@ -2,6 +2,7 @@ import Pkg try Pkg.activate(joinpath(@__DIR__); io = devnull) + Pkg.develop(Pkg.PackageSpec(path = joinpath(@__DIR__, "..", "..")); io = devnull) Pkg.instantiate(; io = devnull) @eval import JET JET.test_package(StridedViews; target_modules = (StridedViews,))