
Commit a7a7143
Merge branch 'master' into wct/pseudo-observation-parametrisations
2 parents 2d91393 + 490ece8
File tree

10 files changed: +180 -89 lines

Project.toml

Lines changed: 5 additions & 1 deletion

@@ -1,7 +1,7 @@
 name = "ApproximateGPs"
 uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
 authors = ["JuliaGaussianProcesses Team"]
-version = "0.3.2"
+version = "0.3.4"

 [deps]
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
@@ -13,11 +13,14 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
 IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [compat]
 AbstractGPs = "0.3, 0.4, 0.5"
@@ -28,6 +31,7 @@ FillArrays = "0.12, 0.13"
 ForwardDiff = "0.10"
 GPLikelihoods = "0.3"
 IrrationalConstants = "0.1"
+LogExpFunctions = "0.3"
 PDMats = "0.11"
 Reexport = "1"
 SpecialFunctions = "1, 2"

docs/literate.jl

Lines changed: 2 additions & 0 deletions

@@ -11,11 +11,13 @@ using Pkg: Pkg

 using InteractiveUtils
 const EXAMPLEPATH = joinpath(@__DIR__, "..", "examples", EXAMPLE)
+const PKGDIR = joinpath(@__DIR__, "..")
 Pkg.activate(EXAMPLEPATH)
 Pkg.instantiate()
 pkg_status = sprint() do io
     Pkg.status(; io=io)
 end
+Pkg.develop(Pkg.PackageSpec(path=PKGDIR))

 using Literate: Literate

examples/c-comparisons/script.jl

Lines changed: 3 additions & 3 deletions

@@ -127,10 +127,10 @@ lf2.f.kernel
 # Finally, we need to construct again the (approximate) posterior given the
 # observations for the latent GP with optimised hyperparameters:

-f_post2 = posterior(LaplaceApproximation(; f_init=objective.f), lf2(X), Y)
+f_post2 = posterior(LaplaceApproximation(; f_init=objective.cache.f), lf2(X), Y)

-# By passing `f_init=objective.f` we let the Laplace approximation "warm-start"
-# at the last point of the inner-loop Newton optimisation; `objective.f` is a
+# By passing `f_init=objective.cache.f` we let the Laplace approximation "warm-start"
+# at the last point of the inner-loop Newton optimisation; `objective.cache` is a
 # field on the `objective` closure.

 # Let's plot samples from the approximate posterior for the optimised hyperparameters:
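
For context, here is how the updated warm-start pattern fits together end to end. This is a minimal sketch: it assumes `X`, `Y`, and `build_latent_gp` are defined earlier in the example, and that the outer loop uses Optim.jl with hypothetical starting hyperparameters `theta0` (neither detail is part of this diff).

using ApproximateGPs, Optim

objective = build_laplace_objective(build_latent_gp, X, Y)
theta0 = randn(2)  # hypothetical starting hyperparameters
training_results = Optim.optimize(objective, theta0, LBFGS())

lf2 = build_latent_gp(training_results.minimizer)
# objective.cache.f holds the last inner-loop Newton iterate, so the final
# Laplace posterior warm-starts there rather than running Newton from scratch:
f_post2 = posterior(LaplaceApproximation(; f_init=objective.cache.f), lf2(X), Y)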

src/ApproximateGPs.jl

Lines changed: 3 additions & 0 deletions

@@ -2,6 +2,7 @@ module ApproximateGPs

 using Reexport

+@reexport using AbstractGPs
 @reexport using GPLikelihoods

 include("API.jl")
@@ -26,4 +27,6 @@ include("LaplaceApproximationModule.jl")

 include("deprecations.jl")

+include("TestUtils.jl")
+
 end

src/LaplaceApproximationModule.jl

Lines changed: 22 additions & 8 deletions

@@ -75,13 +75,25 @@ closure passes its arguments to `build_latent_gp`, which must return the
 - `newton_maxiter=100`: maximum number of Newton steps.
 """
 function build_laplace_objective(build_latent_gp, xs, ys; kwargs...)
-    # TODO assumes type of `xs` will be same as `mean(lfx.fx)`
-    f = similar(xs, length(xs)) # will be mutated in-place to "warm-start" the Newton steps
-    return build_laplace_objective!(f, build_latent_gp, xs, ys; kwargs...)
+    cache = LaplaceObjectiveCache(nothing)
+    # cache.f will be mutated in-place to "warm-start" the Newton steps
+    # f should be similar(mean(lfx.fx)), but to construct lfx we would need the arguments
+    # so we set it to `nothing` initially, and set it to mean(lfx.fx) within the objective
+    return build_laplace_objective!(cache, build_latent_gp, xs, ys; kwargs...)
+end
+
+function build_laplace_objective!(f_init::Vector, build_latent_gp, xs, ys; kwargs...)
+    return build_laplace_objective!(
+        LaplaceObjectiveCache(f_init), build_latent_gp, xs, ys; kwargs...
+    )
+end
+
+mutable struct LaplaceObjectiveCache
+    f::Union{Nothing,Vector}
 end

 function build_laplace_objective!(
-    f,
+    cache::LaplaceObjectiveCache,
     build_latent_gp,
     xs,
     ys;
@@ -98,16 +110,18 @@ function build_laplace_objective!(
         # Zygote does not like the try/catch within @info etc.
         @debug "Objective arguments: $args"
         # Zygote does not like in-place assignments either
-        if initialize_f
-            f .= mean(lfx.fx)
+        if cache.f === nothing
+            cache.f = mean(lfx.fx)
+        elseif initialize_f
+            cache.f .= mean(lfx.fx)
         end
     end
     f_opt, lml = laplace_f_and_lml(
-        lfx, ys; f_init=f, maxiter=newton_maxiter, callback=newton_callback
+        lfx, ys; f_init=cache.f, maxiter=newton_maxiter, callback=newton_callback
     )
     ignore_derivatives() do
         if newton_warmstart
-            f .= f_opt
+            cache.f .= f_opt
             initialize_f = false
         end
     end
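
A small sketch of the resulting behaviour, assuming the TestUtils helpers added in this commit and that `build_laplace_objective` is in scope after `using ApproximateGPs` (as it is in the example script and tests):

using ApproximateGPs

X, Y = ApproximateGPs.TestUtils.generate_data()
build_latent_gp = ApproximateGPs.TestUtils.build_latent_gp

objective = build_laplace_objective(build_latent_gp, X, Y)
objective.cache.f === nothing   # true: the warm-start buffer is no longer pre-allocated
                                # via similar(xs, length(xs))

objective(rand(2))              # first call sets cache.f = mean(lfx.fx), then runs Newton
length(objective.cache.f) == length(X)   # true: the buffer now exists and is warm-started

Deferring the allocation is what fixes the `ColVecs` case exercised in the new "GitHub issue #109" test below: `similar(xs, length(xs))` assumed the element type of `xs` matched that of `mean(lfx.fx)`, whereas `mean(lfx.fx)` itself always does.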

src/SparseVariationalApproximationModule.jl

Lines changed: 2 additions & 1 deletion

@@ -22,6 +22,7 @@ using AbstractGPs:
     FiniteGP,
     LatentFiniteGP,
    ApproxPosteriorGP,
+    elbo,
    posterior,
    marginals,
    At_A,
@@ -292,7 +293,7 @@ function API.approx_lml(
    sva::AbstractSparseVariationalApproximation, l_fx::Union{FiniteGP,LatentFiniteGP}, ys;
    kwargs...
 )
-    return elbo(sva, l_fx, ys; kwargs...)
+    return AbstractGPs.elbo(sva, l_fx, ys; kwargs...)
 end

 _get_prior(approx::SparseVariationalApproximation) = approx.fz.f
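
The behaviour is unchanged: for a sparse variational approximation, `approx_lml` is simply the ELBO, and the call is now qualified so that it unambiguously targets the `AbstractGPs.elbo` method extended by this module. A rough usage sketch follows; the construction of the approximation mirrors the package's usual API but is an assumption, not part of this diff, as is the fully qualified `ApproximateGPs.API.approx_lml` path.

using ApproximateGPs, Distributions, LinearAlgebra

f = GP(SqExponentialKernel())
x = rand(20)
y = rand(20)
fz = f(x[1:5], 1e-6)                        # inducing inputs (an assumption for illustration)
q = MvNormal(zeros(5), Matrix(0.1I, 5, 5))  # initial variational distribution
sva = SparseVariationalApproximation(fz, q)

# approx_lml for an SVA is (still) the evidence lower bound:
ApproximateGPs.API.approx_lml(sva, f(x, 0.1), y) == AbstractGPs.elbo(sva, f(x, 0.1), y)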

src/TestUtils.jl

Lines changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
+module TestUtils
+
+using LinearAlgebra
+using Random
+using Test
+
+using Distributions
+using LogExpFunctions: logistic, softplus
+
+using AbstractGPs
+using ApproximateGPs
+
+function generate_data()
+    X = range(0, 23.5; length=48)
+    # The random number generator changed in 1.6->1.7. The following vector was generated in Julia 1.6.
+    # The generating code below is only kept for illustrative purposes.
+    #! format: off
+    Y = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
+    #! format: on
+    # Random.seed!(1)
+    # fs = @. 3 * sin(10 + 0.6X) + sin(0.1X) - 1
+    # # invlink = normcdf
+    # invlink = logistic
+    # ps = invlink.(fs)
+    # Y = @. rand(Bernoulli(ps))
+    return X, Y
+end
+
+dist_y_given_f(f) = Bernoulli(logistic(f))
+
+function build_latent_gp(theta)
+    variance = softplus(theta[1])
+    lengthscale = softplus(theta[2])
+    kernel = variance * with_lengthscale(SqExponentialKernel(), lengthscale)
+    return LatentGP(GP(kernel), dist_y_given_f, 1e-8)
+end
+
+"""
+    test_approx_lml(approx)
+
+Test whether in the conjugate case `approx_lml(approx, LatentGP(f,
+GaussianLikelihood(), jitter)(x), y)` gives approximately the same answer as
+the log marginal likelihood in exact GP regression.
+
+!!! todo
+    Not yet implemented.
+
+    Will not necessarily work for approximations that rely on optimization such
+    as `SparseVariationalApproximation`.
+
+!!! todo
+    Also test gradients (for hyperparameter optimization).
+"""
+function test_approx_lml end
+
+"""
+    test_approximation_predictions(approx)
+
+Test whether the prediction interface for `approx` works and whether in the
+conjugate case `posterior(approx, LatentGP(f, GaussianLikelihood(), jitter)(x), y)`
+gives approximately the same answer as the exact GP regression posterior.
+
+!!! note
+    Should be satisfied by all approximate inference methods, but note that
+    this does not currently apply for some approximations which rely on
+    optimization such as `SparseVariationalApproximation`.
+
+!!! warning
+    Do not rely on this as the only test of a new approximation!
+
+    See `test_approx_lml`.
+"""
+function test_approximation_predictions(approx)
+    rng = MersenneTwister(123456)
+    N_cond = 5
+    N_a = 6
+    N_b = 7
+
+    # Specify prior.
+    f = GP(Matern32Kernel())
+    # Sample from prior.
+    x = collect(range(-1.0, 1.0; length=N_cond))
+    # TODO: Change to x = ColVecs(rand(2, N_cond)) once #109 is fixed
+    noise_scale = 0.1
+    fx = f(x, noise_scale^2)
+    y = rand(rng, fx)
+
+    jitter = 0.0  # not needed in Gaussian case
+    lf = LatentGP(f, f -> Normal(f, noise_scale), jitter)
+    f_approx_post = posterior(approx, lf(x), y)
+
+    @testset "AbstractGPs API" begin
+        a = collect(range(-1.2, 1.2; length=N_a))
+        b = randn(rng, N_b)
+        AbstractGPs.TestUtils.test_internal_abstractgps_interface(rng, f_approx_post, a, b)
+    end
+
+    @testset "exact GPR equivalence for Gaussian likelihood" begin
+        f_exact_post = posterior(f(x, noise_scale^2), y)
+        xt = vcat(x, randn(rng, 3))  # test at training and new points
+
+        m_approx, c_approx = mean_and_cov(f_approx_post(xt))
+        m_exact, c_exact = mean_and_cov(f_exact_post(xt))
+
+        @test m_approx ≈ m_exact
+        @test c_approx ≈ c_exact
+    end
+end
+
+end
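
Downstream, these shared checks can be reused directly for any approximation. A minimal sketch, mirroring the updated Laplace test further below (the `maxiter=2` choice comes from the old test: in the Gaussian case Laplace converges in one Newton step, and a second step recomputes the cache at the optimum):

using ApproximateGPs

# Runs the AbstractGPs interface checks and the exact-GPR comparison
# for a Gaussian likelihood against the given approximation:
ApproximateGPs.TestUtils.test_approximation_predictions(LaplaceApproximation(; maxiter=2))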

test/LaplaceApproximationModule.jl

Lines changed: 16 additions & 62 deletions

@@ -1,28 +1,7 @@
 @testset "laplace" begin
-    function generate_data()
-        X = range(0, 23.5; length=48)
-        # The random number generator changed in 1.6->1.7. The following vector was generated in Julia 1.6.
-        # The generating code below is only kept for illustrative purposes.
-        #! format: off
-        Y = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
-        #! format: on
-        # Random.seed!(1)
-        # fs = @. 3 * sin(10 + 0.6X) + sin(0.1X) - 1
-        # # invlink = normcdf
-        # invlink = logistic
-        # ps = invlink.(fs)
-        # Y = [rand(Bernoulli(p)) for p in ps]
-        return X, Y
-    end
-
-    dist_y_given_f(f) = Bernoulli(logistic(f))
-
-    function build_latent_gp(theta)
-        variance = softplus(theta[1])
-        lengthscale = softplus(theta[2])
-        kernel = variance * with_lengthscale(SqExponentialKernel(), lengthscale)
-        return LatentGP(GP(kernel), dist_y_given_f, 1e-8)
-    end
+    generate_data = ApproximateGPs.TestUtils.generate_data
+    dist_y_given_f = ApproximateGPs.TestUtils.dist_y_given_f
+    build_latent_gp = ApproximateGPs.TestUtils.build_latent_gp

     function optimize_elbo(
         build_latent_gp,
@@ -44,48 +23,13 @@
        )

        lf = build_latent_gp(training_results.minimizer)
-        f_post = posterior(LaplaceApproximation(; f_init=objective.f), lf(xs), ys)
+        f_post = posterior(LaplaceApproximation(; f_init=objective.cache.f), lf(xs), ys)
        return f_post, training_results
    end

    @testset "predictions" begin
-        rng = MersenneTwister(123456)
-        N_cond = 5
-        N_a = 6
-        N_b = 7
-
-        # Specify prior.
-        f = GP(Matern32Kernel())
-        # Sample from prior.
-        x = collect(range(-1.0, 1.0; length=N_cond))
-        noise_scale = 0.1
-        fx = f(x, noise_scale^2)
-        y = rand(rng, fx)
-
-        jitter = 0.0  # not needed in Gaussian case
-        lf = LatentGP(f, f -> Normal(f, noise_scale), jitter)
-        # in Gaussian case, Laplace converges to f_opt in one step; we need the
-        # second step to compute the cache at f_opt rather than f_init!
-        f_approx_post = posterior(LaplaceApproximation(; maxiter=2), lf(x), y)
-
-        @testset "AbstractGPs API" begin
-            a = collect(range(-1.2, 1.2; length=N_a))
-            b = randn(rng, N_b)
-            AbstractGPs.TestUtils.test_internal_abstractgps_interface(
-                rng, f_approx_post, a, b
-            )
-        end
-
-        @testset "equivalence to exact GPR for Gaussian likelihood" begin
-            f_exact_post = posterior(f(x, noise_scale^2), y)
-            xt = vcat(x, randn(rng, 3))  # test at training and new points
-
-            m_approx, c_approx = mean_and_cov(f_approx_post(xt))
-            m_exact, c_exact = mean_and_cov(f_exact_post(xt))
-
-            @test m_approx ≈ m_exact
-            @test c_approx ≈ c_exact
-        end
+        approx = LaplaceApproximation(; maxiter=2)
+        ApproximateGPs.TestUtils.test_approximation_predictions(approx)
    end

    @testset "gradients" begin
@@ -264,4 +208,14 @@
        res = res_array[end]
        @test res.q isa MvNormal
    end
+
+    @testset "GitHub issue #109" begin
+        build_latent_gp() = LatentGP(GP(SEKernel()), BernoulliLikelihood(), 1e-8)
+
+        x = ColVecs(randn(2, 5))
+        _, y = rand(build_latent_gp()(x))
+
+        objective = build_laplace_objective(build_latent_gp, x, y)
+        _ = objective()  # check that it works
+    end
 end

test/SparseVariationalApproximationModule.jl

Lines changed: 5 additions & 3 deletions

@@ -28,7 +28,9 @@
    b = randn(rng, N_b)

    @testset "AbstractGPs interface - Centered" begin
-        TestUtils.test_internal_abstractgps_interface(rng, f_approx_post_Centered, a, b)
+        AbstractGPs.TestUtils.test_internal_abstractgps_interface(
+            rng, f_approx_post_Centered, a, b
+        )
    end

    @testset "NonCentered" begin
@@ -50,7 +52,7 @@
        f_approx_post_non_Centered = posterior(approx_non_Centered)

        @testset "AbstractGPs interface - NonCentered" begin
-            TestUtils.test_internal_abstractgps_interface(
+            AbstractGPs.TestUtils.test_internal_abstractgps_interface(
                rng, f_approx_post_non_Centered, a, b
            )
        end
@@ -170,7 +172,7 @@

    # Train the SVGP model
    data = [(x, y)]
-    opt = ADAM(0.001)
+    opt = Flux.ADAM(0.001)

    svgp_ps = Flux.params(svgp_model)