Closed
Description
When differentiating `logpdf` or other scalar functions of a GP that has a parameterized mean function and multidimensional inputs, Zygote throws errors. With 1-D inputs the same model differentiates fine:
```julia
using AbstractGPs
using Zygote

pars = [1., 0.]

# GP whose mean depends on the parameters a and b
function build_model(pars)
    a, b = pars
    return GP(x -> a * first(x) + b, SEKernel())
end

rand_data(n::Integer) = rand(n), randn(n)                     # 1-D inputs
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)      # 2-D inputs, one row per point

function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf(pars)
Zygote.gradient(test_logpdf, pars) # works

function test_logpdf2(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf2(pars)
Zygote.gradient(test_logpdf2, pars)
# ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}}, ::Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})

function test_mean(pars)
    f = build_model(pars)
    x, _ = rand_data_2d(10)
    return sum(mean(f(x, 1e-3)))
end

test_mean(pars)
Zygote.gradient(test_mean, pars) # ERROR: Pullback on AbstractVector{<:AbstractVector}.

function test_post_mean(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)
    fp = posterior(f(x, 1e-3), y)
    return sum(mean(fp(x, 1e-3)))
end

test_post_mean(pars)
Zygote.gradient(test_post_mean, pars)
# ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}}, ::Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})
```

Is there a simple fix? The error for `test_mean` suggests overloading a `kernelmatrix` method, but that does not seem to be the issue, since only the mean is involved here. Why does the existing rrule for `RowVecs` not suffice?
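Both `MethodError`s show Zygote trying to add two different cotangent representations of the same `RowVecs` input: a structural `NamedTuple` with a field `X`, and a `Vector` with one `OneElement` cotangent per row. A plausible reading, which is an assumption and not confirmed in this report, is that the first comes from the `RowVecs` rrule on the kernel side and the second from the mean function indexing each row with `first(x)`. The sketch below reproduces only that failing addition using Base and LinearAlgebra; `structural`, `per_row`, and the sizes are hypothetical stand-ins for the types in the error message, and plain `zeros` vectors replace `Zygote.OneElement`.

```julia
# Minimal sketch (hypothetical stand-ins, not the library's internals):
# adding a structural NamedTuple cotangent to a per-row vector of cotangents.
using LinearAlgebra

n = 10
# Structural-style cotangent for a RowVecs input: a NamedTuple whose field X
# mirrors the (here transposed) n×2 data matrix.
structural = (X = transpose(zeros(2, n)),)
# Per-row cotangent: one length-2 vector per data point.
per_row = [zeros(2) for _ in 1:n]

# Base defines no `+` for a NamedTuple and a Vector, so this addition throws.
caught = try
    structural + per_row
    nothing
catch err
    err
end
println(typeof(caught))  # MethodError
```

If this reading is right, the problem is not a missing rrule as such, but that the two code paths return cotangents for `x` in incompatible forms that cannot be accumulated.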