diff --git a/ext/StridedAMDGPUExt.jl b/ext/StridedAMDGPUExt.jl index 627446e..e14d804 100644 --- a/ext/StridedAMDGPUExt.jl +++ b/ext/StridedAMDGPUExt.jl @@ -1,7 +1,7 @@ module StridedAMDGPUExt using Strided, StridedViews, AMDGPU, AMDGPU.rocBLAS, LinearAlgebra -import Strided: blas_mul! +import Strided: blas_mul!, _get_op const ROCStridedView{T, N, A <: ROCArray{T}} = StridedViews.StridedView{T, N, A} @@ -16,4 +16,12 @@ function Strided.blas_mul!(C::ROCStridedView{T, 2}, A::ROCStridedView{T, 2}, B:: return C end +_conj(x) = real(x) - imag(x) * im +@static if VERSION < v"1.11.0-rc" + # work around compiler issue on AMD on 1.10 + _get_op(A::ROCStridedView) = A.op == conj ? _conj : A.op +else + _get_op(A::ROCStridedView) = A.op +end + end diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl index 04728e8..b65313d 100644 --- a/ext/StridedGPUArraysExt.jl +++ b/ext/StridedGPUArraysExt.jl @@ -5,7 +5,7 @@ using GPUArrays: Adapt, KernelAbstractions using GPUArrays.KernelAbstractions: @kernel, @index using StridedViews: ParentIndex -import Strided: isblasmatrix +import Strided: isblasmatrix, _get_op ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)} @@ -125,9 +125,8 @@ function Strided._mapreduce_block!( backend = KernelAbstractions.get_backend(parent(out)) kernel! = _mapreduce_gpu_kernel!(backend) - ops = getproperty.(arrays, :op) + ops = _get_op.(arrays) kernel!(f, op, initop, dims_red, strides, offsets, ops, parent.(arrays); ndrange = dims_out) - return nothing end diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 0db5dcf..9182ee8 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -7,6 +7,15 @@ LinearAlgebra.transpose!(C::StridedView, A::StridedView) = copy!(C, transpose(A) Base.permutedims!(dst::StridedView, src::StridedView, p) = copy!(dst, permutedims(src, p)) Base.fill!(A::StridedView, val) = map!(Returns(val), A, A) +# This is a wrapper function intended to allow us to +# intercept "conj" and rewrite it in cases where the +# GPU compiler seemingly isn't compiling bare conj +# correctly (https://github.com/QuantumKitHub/Strided.jl/issues/63). +# It should be removed as soon as the underlying +# compilation problem is resolved. It uses the first +# argument to dispatch so that only AMD arrays are affected. +_get_op(A::StridedView) = A.op + function Base.mapreduce(f, op, A::StridedView; dims = :, kw...) return Base._mapreduce_dim(f, op, values(kw), A, dims) end