Skip to content

Commit 1a9925d

Browse files
MasonProttertkf
andauthored
Special case empty covec-diagonal-vec product (#35557)
Co-Authored-By: Takafumi Arakaki <[email protected]>
1 parent ecc0c43 commit 1a9925d

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -659,14 +659,19 @@ end
659659
# disambiguation methods: * of Diagonal and Adj/Trans AbsVec
660660
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
661661
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
662-
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
663-
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
664-
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
665-
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
666-
function dot(x::AbstractVector, D::Diagonal, y::AbstractVector)
667-
mapreduce(t -> dot(t[1], t[2], t[3]), +, zip(x, D.diag, y))
662+
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
663+
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
664+
dot(x::AbstractVector, D::Diagonal, y::AbstractVector) = _mapreduce_prod(dot, x, D, y)
665+
666+
function _mapreduce_prod(f, x, D::Diagonal, y)
667+
if isempty(x) && isempty(D) && isempty(y)
668+
return zero(Base.promote_op(f, eltype(x), eltype(D), eltype(y)))
669+
else
670+
return mapreduce(t -> f(t[1], t[2], t[3]), +, zip(x, D.diag, y))
671+
end
668672
end
669673

674+
670675
function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true)
671676
info = 0
672677
for (i, di) in enumerate(A.diag)

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,4 +711,10 @@ end
711711
@test s1 == prod(sign, d)
712712
end
713713

714+
@testset "Empty (#35424)" begin
715+
@test zeros(0)'*Diagonal(zeros(0))*zeros(0) === 0.0
716+
@test transpose(zeros(0))*Diagonal(zeros(Complex{Int}, 0))*zeros(0) === 0.0 + 0.0im
717+
@test dot(zeros(Int32, 0), Diagonal(zeros(Int, 0)), zeros(Int16, 0)) === 0
718+
end
719+
714720
end # module TestDiagonal

0 commit comments

Comments
 (0)