Skip to content

Commit d0a1ac6

Browse files
Will Tebbuttgithub-actions[bot]
andauthored
positive(::Array) (#54)
* PositiveArray implementation * Add empty tests file * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Correctness testing * Add performance tests * Check poor performance of naive approach * Document positive more thoroughly * Formatting * Bump patch version * Improve test comments * Include allocation counter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 3ff6a53 commit d0a1ac6

File tree

6 files changed

+98
-2
lines changed

6 files changed

+98
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ParameterHandling"
22
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.4.3"
4+
version = "0.4.4"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,8 @@ In particular, we've seen an example of how ParameterHandling.jl can be used to
276276
gap between the "flat" representation of parameters that `Optim` likes to work with, and the
277277
"structured" representation that it's convenient to write optimisation algorithms with.
278278

279-
# Gotchas
279+
# Gotchas and Performance Tips
280280

281281
1. `Integer`s typically don't take part in the kind of optimisation procedures that this package is designed to handle. Consequently, `flatten(::Integer)` produces an empty vector.
282282
2. `deferred` has some type-stability issues when used in conjunction with abstract types. For example, `flatten(deferred(Normal, 5.0, 4.0))` won't infer properly. A simple work around is to write a function `normal(args...) = Normal(args...)` and work with `deferred(normal, 5.0, 4.0)` instead.
283+
3. Let `x` be an `Array{<:Real}`. If you wish to constrain each of its values to be positive, prefer `positive(x)` over `map(positive, x)` or `positive.(x)`. `positive(x)` has been implemented such that the associated `unflatten` function has good performance, particularly when interacting with `Zygote` (with which `map(positive, x)` is extremely slow).

src/ParameterHandling.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ include("parameters_base.jl")
1515
include("parameters_meta.jl")
1616
include("parameters_scalar.jl")
1717
include("parameters_matrix.jl")
18+
include("parameters_array.jl")
1819

1920
include("test_utils.jl")
2021

src/parameters_array.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""
    PositiveArray(unconstrained_value, transform, ε)

Parameter wrapping an array of real numbers in an unconstrained representation.

# Fields
- `unconstrained_value`: the array of unconstrained reals that gets flattened.
- `transform`: the transformation associated with this parameter.
- `ε`: real-valued offset added when the constrained value is recovered via `value`.

Instances are normally built with `positive(::Array{<:Real})` rather than directly.
"""
struct PositiveArray{Tv<:Array{<:Real},Tf,Tε<:Real} <: AbstractParameter
    unconstrained_value::Tv
    transform::Tf
    ε::Tε
end
6+
7+
# Recover the constrained (positive) array. The stored `transform` is applied rather than
# a hard-coded `exp`, so that a non-default `transform` supplied to `positive` round-trips
# correctly; with the default `transform = exp` the result is unchanged.
value(x::PositiveArray) = map(x.transform, x.unconstrained_value) .+ x.ε
8+
9+
# Flatten a `PositiveArray` by flattening its unconstrained array. The returned closure
# rebuilds a `PositiveArray` from a flat vector, reusing the original `transform` and `ε`.
function flatten(::Type{T}, x::PositiveArray{<:Array{V}}) where {T<:Real,V<:Real}
    # Pull the fields out into locals so the closure captures them rather than `x`.
    transform, ε = x.transform, x.ε
    flat, to_array = flatten(T, x.unconstrained_value)
    function unflatten_PositiveArray(v::AbstractVector{T})
        return PositiveArray(to_array(v), transform, ε)
    end
    return flat, unflatten_PositiveArray
end
18+
"""
    positive(val::Array{<:Real}, transform=exp, ε=sqrt(eps(eltype(val))))

Return a `PositiveArray` representing `val`, whose elements are constrained to be positive.

Roughly equivalent to `map(positive, val)`, but implemented such that unflattening can be
efficiently differentiated through using algorithmic differentiation (Zygote in particular).

# Arguments
- `val`: array of reals; every element must be greater than both zero and `ε`.
- `transform`: transformation whose inverse (obtained via `inverse(transform)`) is applied
  to `val .- ε` to produce the unconstrained representation.
- `ε`: offset subtracted from each element before the inverse transform is applied.

# Throws
- `ArgumentError` if any element of `val` is not positive, or is not greater than `ε`.
"""
function positive(val::Array{<:Real}, transform=exp, ε=sqrt(eps(eltype(val))))
    # Predicate form of `all` avoids allocating a temporary Bool array for each check.
    all(x -> x > 0, val) || throw(ArgumentError("Not all elements of val are positive."))
    all(x -> x > ε, val) ||
        throw(ArgumentError("Not all elements of val greater than ε ($ε)."))

    inverse_transform = inverse(transform)
    unconstrained_value = map(x -> inverse_transform(x - ε), val)
    return PositiveArray(
        unconstrained_value, transform, convert(eltype(unconstrained_value), ε)
    )
end

test/parameters_array.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
@testset "parameters_array" begin
    @testset "postive" begin
        @testset "$val" for val in [[5.0, 4.0], [0.001f0], fill(1e-7, 1, 2)]
            p = positive(val)
            test_parameter_interface(p)
            # Approximate comparison: the value is recovered via a floating-point round trip.
            @test value(p) ≈ val
            @test typeof(value(p)) === typeof(val)
        end

        # Test edge cases around the size of the value relative to the error tol.
        @test_throws ArgumentError positive([-0.1, 0.1])
        @test_throws ArgumentError positive(fill(1e-12, 1, 2, 3))
        @test value(positive(fill(1e-11, 3, 2, 1), exp, 1e-12)) ≈ fill(1e-11, 3, 2, 1)

        # These tests assume that if the number of allocations is roughly constant in the
        # size of `x`, then performance is acceptable. This is demonstrated by requiring
        # that the number of allocations (100) is a lot smaller than the total length of
        # the array in question (1_000_000). The bound (100) is quite loose because there
        # are typically several 10s of allocations made by Zygote for book-keeping
        # purposes etc.
        @testset "zygote performance" begin
            x = rand(1000, 1000) .+ 0.1
            flat_x, unflatten = value_flatten(positive(x))

            # primal evaluation (first call is a warm-up so compilation isn't counted)
            count_allocs(unflatten, flat_x)
            @test count_allocs(unflatten, flat_x) < 100

            # forward evaluation
            count_allocs(Zygote.pullback, unflatten, flat_x)
            @test count_allocs(Zygote.pullback, unflatten, flat_x) < 100

            # pullback
            out, pb = Zygote.pullback(unflatten, flat_x)
            count_allocs(pb, out)
            @test count_allocs(pb, out) < 100
        end

        # Check that this optimisation is actually necessary -- i.e. that the performance
        # of the equivalent operation, `map(positive, x)`, is indeed poor, esp. with AD.
        # Poor performance is demonstrated by showing that there's at least one allocation
        # per element. A smaller array than the previous test set is used because it can
        # be _really_ slow for large arrays (several seconds), which is undesirable in
        # unit tests.
        @testset "zygote performance of scalar equivalent" begin
            x = rand(1000) .+ 0.1
            flat_x, unflatten = value_flatten(map(positive, x))

            # forward evaluation (first call is a warm-up so compilation isn't counted)
            count_allocs(Zygote.pullback, unflatten, flat_x)
            @test count_allocs(Zygote.pullback, unflatten, flat_x) > 1000
        end
    end
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,16 @@ using ParameterHandling.TestUtils: test_flatten_interface, test_parameter_interf
1414

1515
const tuple_infers = VERSION < v"1.5" ? false : true
1616

17+
# Call `f(args...)` under `@timed` and return the allocation count that
# `Base.gc_alloc_count` derives from the recorded GC statistics.
count_allocs(f, args...) = Base.gc_alloc_count((@timed f(args...)).gcstats)
21+
1722
@testset "ParameterHandling.jl" begin
1823
include("flatten.jl")
1924
include("parameters.jl")
2025
include("parameters_meta.jl")
2126
include("parameters_scalar.jl")
2227
include("parameters_matrix.jl")
28+
include("parameters_array.jl")
2329
end

0 commit comments

Comments
 (0)