Skip to content

Bounds-checking in triangular indexing branches #1305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
120 changes: 83 additions & 37 deletions src/triangular.jl
Original file line number Diff line number Diff line change
@@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) =
_shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false

@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} =
_shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) =
_shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j)
@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T}
if _shouldforwardindex(A, i, j)
A.data[i,j]
else
@boundscheck checkbounds(A, i, j)
ifelse(i == j, oneunit(T), zero(T))
end
end
@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int)
if _shouldforwardindex(A, i, j)
A.data[i,j]
else
@boundscheck checkbounds(A, i, j)
@inbounds diagzero(A,i,j)
end
end

_shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0
_shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0
@@ -250,63 +262,97 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0

# these specialized getindex methods enable constant-propagation of the band
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T}
_shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
if _shouldforwardindex(A, b)
A.data[b]
else
@boundscheck checkbounds(A, b)
ifelse(b.band == 0, oneunit(T), zero(T))
end
end
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex)
_shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b)
if _shouldforwardindex(A, b)
A.data[b]
else
@boundscheck checkbounds(A, b)
@inbounds diagzero(A, b)
end
end

_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower"
_zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"

@noinline function throw_nonzeroerror(T, @nospecialize(x), i, j)
Ts = _zero_triangular_half_str(T)
Tn = nameof(T)
@noinline function throw_nonzeroerror(Tn::Symbol, @nospecialize(x), i, j)
zero_half = Tn in (:UpperTriangular, :UnitUpperTriangular) ? "lower" : "upper"
nstr = Tn === :UpperTriangular ? "n" : ""
throw(ArgumentError(
lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)"))
LazyString(
lazy"cannot set index ($i, $j) in the $zero_half triangular part ",
lazy"of a$nstr $Tn matrix to a nonzero value ($x)")
)
)
end
@noinline function throw_nononeerror(T, @nospecialize(x), i, j)
Tn = nameof(T)
@noinline function throw_nonuniterror(Tn::Symbol, @nospecialize(x), i, j)
throw(ArgumentError(
lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)"))
lazy"cannot set index ($i, $j) on the diagonal of a $Tn matrix to a non-unit value ($x)"))
end

@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
if i > j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
return A
end

@propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer)
if i > j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
elseif i == j
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
if i == j # diagonal
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
else
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
end
return A
end

@propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer)
if i < j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
return A
end

@propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer)
if i < j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
elseif i == j
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
if i == j # diagonal
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
else
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
end
return A
end
@@ -560,7 +606,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
@eval @inline function _copy!(A::$UT, B::$T)
for dind in diagind(A, IndexStyle(A))
if A[dind] != B[dind]
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
throw_nonuniterror(nameof(typeof(A)), B[dind], Tuple(dind)...)
end
end
_copy!($T(parent(A)), B)
@@ -741,7 +787,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
for i in firstindex(B.data,1):(j - 1)
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
end
@@ -752,7 +798,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
for i in firstindex(B.data,1):(j - 1)
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
end
@@ -783,7 +829,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
for i in (j + 1):lastindex(B.data,1)
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
end
@@ -794,7 +840,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
for i in (j + 1):lastindex(B.data,1)
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
end
79 changes: 77 additions & 2 deletions test/triangular.jl
Original file line number Diff line number Diff line change
@@ -641,11 +641,11 @@ end
@testset "error message" begin
A = UpperTriangular(Ap)
B = UpperTriangular(Bp)
@test_throws "cannot set index in the lower triangular part" copyto!(A, B)
@test_throws "cannot set index (3, 1) in the lower triangular part" copyto!(A, B)

A = LowerTriangular(Ap)
B = LowerTriangular(Bp)
@test_throws "cannot set index in the upper triangular part" copyto!(A, B)
@test_throws "cannot set index (1, 2) in the upper triangular part" copyto!(A, B)
end
end

@@ -950,6 +950,10 @@ end
@test 2\U == 2\M
@test U*2 == M*2
@test 2*U == 2*M

U2 = copy(U)
@test rmul!(U, 1) == U2
@test lmul!(1, U) == U2
end

@testset "scaling partly initialized unit triangular" begin
@@ -966,4 +970,75 @@ end
end
end

@testset "indexing checks" begin
P = [1 2; 3 4]
@testset "getindex" begin
U = UnitUpperTriangular(P)
@test_throws BoundsError U[0,0]
@test_throws BoundsError U[1,0]
@test_throws BoundsError U[BandIndex(0,0)]
@test_throws BoundsError U[BandIndex(-1,0)]

U = UpperTriangular(P)
@test_throws BoundsError U[1,0]
@test_throws BoundsError U[BandIndex(-1,0)]

L = UnitLowerTriangular(P)
@test_throws BoundsError L[0,0]
@test_throws BoundsError L[0,1]
@test_throws BoundsError U[BandIndex(0,0)]
@test_throws BoundsError U[BandIndex(1,0)]

L = LowerTriangular(P)
@test_throws BoundsError L[0,1]
@test_throws BoundsError L[BandIndex(1,0)]
end
@testset "setindex!" begin
A = SizedArrays.SizedArray{(2,2)}(P)
M = fill(A, 2, 2)
U = UnitUpperTriangular(M)
@test_throws "Cannot `convert` an object of type $Int" U[1,1] = 1
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitUpperTriangular matrix to a non-unit value"
@test_throws non_unit_msg U[1,1] = A
L = UnitLowerTriangular(M)
@test_throws "Cannot `convert` an object of type $Int" L[1,1] = 1
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitLowerTriangular matrix to a non-unit value"
@test_throws non_unit_msg L[1,1] = A

for UT in (UnitUpperTriangular, UpperTriangular)
U = UT(M)
@test_throws "Cannot `convert` an object of type $Int" U[2,1] = 0
end
for LT in (UnitLowerTriangular, LowerTriangular)
L = LT(M)
@test_throws "Cannot `convert` an object of type $Int" L[1,2] = 0
end

U = UnitUpperTriangular(P)
@test_throws BoundsError U[0,0] = 1
@test_throws BoundsError U[1,0] = 0

U = UpperTriangular(P)
@test_throws BoundsError U[1,0] = 0

L = UnitLowerTriangular(P)
@test_throws BoundsError L[0,0] = 1
@test_throws BoundsError L[0,1] = 0

L = LowerTriangular(P)
@test_throws BoundsError L[0,1] = 0
end
end

@testset "unit triangular l/rdiv!" begin
A = rand(3,3)
@testset for (UT,T) in ((UnitUpperTriangular, UpperTriangular),
(UnitLowerTriangular, LowerTriangular))
UnitTri = UT(A)
Tri = T(LinearAlgebra.full(UnitTri))
@test 2 \ UnitTri ≈ 2 \ Tri
@test UnitTri / 2 ≈ Tri / 2
end
end

end # module TestTriangular