Skip to content

Commit aaa554f

Browse files
authored
Added belief space binning (#18)
* Add SparseArrays entry
* Added binning option and fixed max_steps stopping criteria
* Added binning
* Stopped updating the bin manager if use_binning was false
* Added separate test sets for use with and without binning
* Removed unnecessary code
* Added version of `entropy` for `::SparseVector`
* Updated binning method + format updates
* Update tests with new BinManager
1 parent da290e7 commit aaa554f

File tree

5 files changed

+322
-25
lines changed

5 files changed

+322
-25
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ LinearAlgebra = "1"
1515
POMDPTools = "0.1, 1"
1616
POMDPs = "0.9, 1"
1717
Printf = "1"
18+
SparseArrays = "1"
1819
julia = "1.7"

src/sample.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,13 @@ function sample_points(sol::SARSOPSolver, tree::SARSOPTree, b_idx::Int, L, U, t,
1818
V̲, V̄ = tree.V_lower[b_idx], tree.V_upper[b_idx]
1919
γ = discount(tree)
2020

21-
=#TODO: BAD, binning method
22-
if+ sol.kappa*ϵ*γ^(-t) || (V̂ L && max(U, V̲ + ϵ*γ^(-t)))
21+
if sol.use_binning
22+
= get_bin_value(tree, b_idx)
23+
else
24+
=
25+
end
26+
27+
if+ sol.kappa*ϵ*γ^(-t) || (V̂ L && max(U, V̲ + ϵ*γ^(-t)))
2328
return
2429
else
2530
Q̲, Q̄, a′ = max_r_and_q(tree, b_idx)

src/solver.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Base.@kwdef struct SARSOPSolver{LOW,UP} <: Solver
99
init_lower::LOW = BlindLowerBound(bel_res = 1e-2)
1010
init_upper::UP = FastInformedBound(bel_res=1e-2)
1111
prunethresh::Float64= 0.10
12+
use_binning::Bool = true
1213
end
1314

1415
function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP)
@@ -20,7 +21,7 @@ function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP)
2021

2122
t0 = time()
2223
iter = 0
23-
while time()-t0 < solver.max_time && root_diff(tree) > solver.precision
24+
while iter <= solver.max_steps && time()-t0 < solver.max_time && root_diff(tree) > solver.precision
2425
sample!(solver, tree)
2526
backup!(tree)
2627
prune!(solver, tree)

src/tree.jl

Lines changed: 227 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,72 @@ mutable struct PruneData
44
prune_threshold::Float64
55
end
66

7+
"""
    BinData

Per-level bin storage. The three matrices share one shape and are indexed as
`[ub_interval_idx, entropy_interval_idx]`.

# Fields
- `bin_value`: running mean of the lower-bound values of the beliefs filed in each bin.
- `bin_count`: number of beliefs filed in each bin.
- `bin_error`: accumulated squared error of each bin.
"""
struct BinData
    bin_value::Matrix{Float64}
    bin_count::Matrix{Int}
    bin_error::Matrix{Float64}
end
12+
13+
"""
    BinNode

Record of where one belief is filed at one bin level.

# Fields
- `key`: `(ub_interval_idx, entropy_interval_idx)` of the bin holding the belief.
- `prev_error`: squared error this belief last contributed to that bin's `bin_error`.
"""
struct BinNode
    key::Tuple{Int,Int}
    prev_error::Float64
end
17+
18+
"""
    BinManager

Bookkeeping for belief-space binning over several bin levels.

# Fields
- `lowest_ub`: smallest upper-bound value seen at construction; origin of the
  upper-bound axis.
- `num_levels`: number of bin levels.
- `num_bins_per_level`: bins per axis at each level.
- `bin_levels_intervals`: per level, the bin width along the upper-bound axis (`ub`)
  and along the entropy axis (`entropy`).
- `bin_levels_nodes`: per level, `b_idx => BinNode` for every registered belief.
- `bin_levels`: per level, the `BinData` matrices.
- `previous_lowerbound`: `b_idx =>` lower bound recorded at the belief's last update.
"""
struct BinManager
    lowest_ub::Float64
    num_levels::Int
    num_bins_per_level::Vector{Int}
    bin_levels_intervals::Vector{NamedTuple{(:ub, :entropy),Tuple{Float64,Float64}}}
    bin_levels_nodes::Vector{Dict{Int,BinNode}}
    bin_levels::Vector{BinData}
    previous_lowerbound::Dict{Int,Float64}
end
27+
28+
"""
    BinManager(Vs_upper::Vector{Float64}, num_bins_per_level=[5, 10])

Build an empty `BinManager`.

The upper-bound axis spans `minimum(Vs_upper)` to `maximum(Vs_upper)` and the entropy
axis spans `0` to `max_entropy(length(Vs_upper))`. At each level both axes are cut into
`num_bins_per_level[level]` equal-width intervals, and the level's value, count and
error matrices start out as zeros.

# Arguments
- `Vs_upper`: upper-bound values; must be non-empty.
- `num_bins_per_level`: number of bins per axis for each level; every entry must be
  positive.

# Throws
- `ArgumentError` if `Vs_upper` is empty or any bin count is not positive.
"""
function BinManager(Vs_upper::Vector{Float64}, num_bins_per_level=[5, 10])
    isempty(Vs_upper) &&
        throw(ArgumentError("Vs_upper must contain at least one value"))
    all(>(0), num_bins_per_level) || throw(
        ArgumentError(
            "every entry of num_bins_per_level must be positive, got $(num_bins_per_level)",
        ),
    )

    # Own copy, so later mutation of the caller's array cannot disagree with the
    # matrix sizes allocated below.
    bins_per_level = collect(Int, num_bins_per_level)
    num_levels = length(bins_per_level)

    lowest_ub, highest_ub = extrema(Vs_upper)
    max_e = max_entropy(length(Vs_upper))

    # [level][:ub|:entropy] => interval width
    bin_levels_intervals = NamedTuple{(:ub, :entropy),Tuple{Float64,Float64}}[
        (ub=(highest_ub - lowest_ub) / num_bins, entropy=max_e / num_bins) for
        num_bins in bins_per_level
    ]

    # [level][b_idx] => BinNode(key, prev_error)
    bin_levels_nodes = Dict{Int,BinNode}[Dict{Int,BinNode}() for _ in 1:num_levels]

    # [level] => value / count / error matrices indexed by (ub idx, entropy idx)
    bin_levels = BinData[
        BinData(
            zeros(Float64, num_bins, num_bins),  # bin_value
            zeros(Int, num_bins, num_bins),      # bin_count
            zeros(Float64, num_bins, num_bins),  # bin_error
        ) for num_bins in bins_per_level
    ]

    previous_lowerbound = Dict{Int,Float64}()  # b_idx => lower bound

    return BinManager(
        lowest_ub,
        num_levels,
        bins_per_level,
        bin_levels_intervals,
        bin_levels_nodes,
        bin_levels,
        previous_lowerbound,
    )
end
72+
773
struct SARSOPTree
874
pomdp::ModifiedSparseTabular
975

@@ -20,7 +86,7 @@ struct SARSOPTree
2086

2187
_discount::Float64
2288
is_terminal::BitVector
23-
is_terminal_s::SparseVector{Bool, Int}
89+
is_terminal_s::SparseVector{Bool,Int}
2490

2591
#do we need both b_pruned and ba_pruned? b_pruned might be enough
2692
sampled::Vector{Int} # b_idx
@@ -32,20 +98,23 @@ struct SARSOPTree
3298
prune_data::PruneData
3399

34100
Γ::Vector{AlphaVec{Int}}
101+
102+
use_binning::Bool
103+
bm::BinManager
35104
end
36105

37106

38-
function SARSOPTree(solver, pomdp::POMDP)
107+
function SARSOPTree(solver, pomdp::POMDP; num_bins_per_level=[5, 10])
39108
sparse_pomdp = ModifiedSparseTabular(pomdp)
40109
cache = TreeCache(sparse_pomdp)
41110

42111
upper_policy = solve(solver.init_upper, sparse_pomdp)
43112
corner_values = map(maximum, zip(upper_policy.alphas...))
44113

45-
tree = SARSOPTree(
46-
sparse_pomdp,
114+
bin_manager = BinManager(corner_values, num_bins_per_level)
47115

48-
Vector{Float64}[],
116+
tree = SARSOPTree(
117+
sparse_pomdp, Vector{Float64}[],
49118
Vector{Int}[],
50119
corner_values, #upper_policy.util,
51120
Float64[],
@@ -63,8 +132,10 @@ function SARSOPTree(solver, pomdp::POMDP)
63132
Vector{Int}(),
64133
BitVector(),
65134
cache,
66-
PruneData(0,0,solver.prunethresh),
67-
AlphaVec{Int}[]
135+
PruneData(0, 0, solver.prunethresh),
136+
AlphaVec{Int}[],
137+
solver.use_binning,
138+
bin_manager
68139
)
69140
return insert_root!(solver, tree, _initialize_belief(pomdp, initialstate(pomdp)))
70141
end
@@ -93,7 +164,7 @@ function insert_root!(solver, tree::SARSOPTree, b)
93164
pomdp = tree.pomdp
94165

95166
Γ_lower = solve(solver.init_lower, pomdp)
96-
for (α,a) alphapairs(Γ_lower)
167+
for (α, a) alphapairs(Γ_lower)
97168
new_val = dot(α, b)
98169
push!(tree.Γ, AlphaVec(α, a))
99170
end
@@ -118,7 +189,7 @@ function update(tree::SARSOPTree, b_idx::Int, a, o)
118189
ba_idx = tree.b_children[b_idx][a]
119190
bp_idx = tree.ba_children[ba_idx][o]
120191
V̲, V̄ = if tree.is_terminal[bp_idx]
121-
0.,0.
192+
0.0, 0.0
122193
else
123194
lower_value(tree, tree.b[bp_idx]), upper_value(tree, tree.b[bp_idx])
124195
end
@@ -139,7 +210,7 @@ function add_belief!(tree::SARSOPTree, b, ba_idx::Int, o)
139210
push!(tree.is_terminal, terminal)
140211

141212
V̲, V̄ = if terminal
142-
0., 0.
213+
0.0, 0.0
143214
else
144215
lower_value(tree, b), upper_value(tree, b)
145216
end
@@ -163,6 +234,9 @@ function fill_belief!(tree::SARSOPTree, b_idx::Int)
163234
else
164235
fill_populated!(tree, b_idx)
165236
end
237+
if tree.use_binning
238+
update_bin_node!(tree, b_idx)
239+
end
166240
end
167241

168242
"""
@@ -186,8 +260,8 @@ function fill_populated!(tree::SARSOPTree, b_idx::Int)
186260
bp_idx, V̲, V̄ = update(tree, b_idx, a, o)
187261
b′ = tree.b[bp_idx]
188262
po = tree.poba[ba_idx][o]
189-
+= γ*po*
190-
+= γ*po*
263+
+= γ * po *
264+
+= γ * po *
191265
end
192266

193267
Qa_upper[a] =
@@ -219,7 +293,7 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int)
219293
tree.ba_children[ba_idx] = ba_children
220294

221295
n_b += N_OBS
222-
pred = dropzeros!(mul!(tree.cache.pred, pomdp.T[a],b))
296+
pred = dropzeros!(mul!(tree.cache.pred, pomdp.T[a], b))
223297
poba = zeros(Float64, N_OBS)
224298
Rba = belief_reward(tree, b, a)
225299

@@ -230,15 +304,15 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int)
230304
# belief update
231305
bp = corrector(pomdp, pred, a, o)
232306
po = sum(bp)
233-
if po > 0.
307+
if po > 0.0
234308
bp.nzval ./= po
235309
poba[o] = po
236310
end
237311

238312
bp_idx, V̲, V̄ = add_belief!(tree, bp, ba_idx, o)
239313

240-
+= γ*po*
241-
+= γ*po*
314+
+= γ * po *
315+
+= γ * po *
242316
end
243317
Qa_upper[a] =
244318
Qa_lower[a] =
@@ -249,3 +323,140 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int)
249323
tree.V_lower[b_idx] = lower_value(tree, tree.b[b_idx])
250324
tree.V_upper[b_idx] = maximum(tree.Qa_upper[b_idx])
251325
end
326+
327+
"""
    initialize_bin_node!(tree::SARSOPTree, b_idx::Int)

File belief `b_idx` into its bin at every level for the first time.

At each level the bin is chosen from the belief's upper bound and the entropy of
`tree.b[b_idx]`. The belief's lower bound is folded into the bin's running mean, its
squared error is added to the bin's error, the bin count is incremented, and a
`BinNode` is stored for the belief. The lower bound is finally recorded in
`previous_lowerbound`. Returns that lower bound.
"""
function initialize_bin_node!(tree::SARSOPTree, b_idx::Int)
    bm = tree.bm
    lb_val = tree.V_lower[b_idx]
    ub_val = tree.V_upper[b_idx]
    node_entropy = entropy(tree.b[b_idx])

    for level_i in 1:bm.num_levels
        intervals = bm.bin_levels_intervals[level_i]
        num_bins = bm.num_bins_per_level[level_i]
        bins = bm.bin_levels[level_i]

        ub_interval_idx = get_interval_idx(
            ub_val, bm.lowest_ub, intervals[:ub], num_bins
        )
        entropy_interval_idx = get_interval_idx(
            node_entropy, 0.0, intervals[:entropy], num_bins
        )
        key = (ub_interval_idx, entropy_interval_idx)

        bin_count = bins.bin_count[ub_interval_idx, entropy_interval_idx]
        if bin_count > 0
            # Occupied bin: error is measured against the bin's current mean, then the
            # mean is updated to include this belief.
            current = bins.bin_value[ub_interval_idx, entropy_interval_idx]
            err = current - lb_val
            prev_error = err * err
            bins.bin_error[ub_interval_idx, entropy_interval_idx] += prev_error
            value = (current * bin_count + lb_val) / (bin_count + 1)
            bins.bin_count[ub_interval_idx, entropy_interval_idx] += 1
        else
            # Empty bin: error is the belief's own bound gap and the bin value starts
            # at the belief's lower bound.
            err = ub_val - lb_val
            prev_error = err * err
            bins.bin_error[ub_interval_idx, entropy_interval_idx] = prev_error
            value = lb_val
            bins.bin_count[ub_interval_idx, entropy_interval_idx] = 1
        end
        bins.bin_value[ub_interval_idx, entropy_interval_idx] = value
        bm.bin_levels_nodes[level_i][b_idx] = BinNode(key, prev_error)
    end

    bm.previous_lowerbound[b_idx] = lb_val
    return lb_val
end
366+
367+
"""
    update_bin_node!(tree::SARSOPTree, b_idx::Int)

Refresh the bin statistics of belief `b_idx` from its current bounds.

A belief with no node at level 1 is handed to `initialize_bin_node!`. Otherwise, at
every level, the belief's previous squared-error contribution is replaced by a new one
and the bin's mean is shifted by the change in the belief's lower bound since the last
update. The belief stays in the bin it was first filed in. Returns the belief's current
lower bound.
"""
function update_bin_node!(tree::SARSOPTree, b_idx::Int)
    bm = tree.bm
    lb_val = tree.V_lower[b_idx]
    ub_val = tree.V_upper[b_idx]

    if !haskey(bm.bin_levels_nodes[1], b_idx)
        return initialize_bin_node!(tree, b_idx)
    end

    # Only rewritten after the loop, so a single read serves every level.
    prev_lb = bm.previous_lowerbound[b_idx]

    for level_i in 1:bm.num_levels
        nodes = bm.bin_levels_nodes[level_i]
        bins = bm.bin_levels[level_i]
        node = nodes[b_idx]
        ub_interval_idx, entropy_interval_idx = node.key

        bin_count = bins.bin_count[ub_interval_idx, entropy_interval_idx]
        if bin_count == 1
            # The belief is alone in its bin: the error is its own bound gap.
            err = ub_val - lb_val
            prev_error = err * err
            bins.bin_error[ub_interval_idx, entropy_interval_idx] = prev_error
        else
            # Swap the old contribution for one measured against the bin's mean.
            err = bins.bin_value[ub_interval_idx, entropy_interval_idx] - lb_val
            bins.bin_error[ub_interval_idx, entropy_interval_idx] -= node.prev_error
            prev_error = err * err
            bins.bin_error[ub_interval_idx, entropy_interval_idx] += prev_error
        end

        nodes[b_idx] = BinNode(node.key, prev_error)
        bins.bin_value[ub_interval_idx, entropy_interval_idx] =
            (
                bins.bin_value[ub_interval_idx, entropy_interval_idx] * bin_count +
                lb_val - prev_lb
            ) / bin_count
    end

    bm.previous_lowerbound[b_idx] = lb_val
    return lb_val
end
398+
399+
"""
    get_bin_value(tree::SARSOPTree, b_idx::Int)

Return the binned value estimate for belief `b_idx`.

If the belief is the only member of its level-1 bin, its upper bound is returned.
Otherwise the level whose bin has the smallest accumulated error is selected, and that
bin's value is returned, limited to the belief's `[lower, upper]` bounds with a `1e-10`
tolerance.

The belief must already have a bin node at every level; otherwise the dictionary
lookup throws a `KeyError`.
"""
function get_bin_value(tree::SARSOPTree, b_idx::Int)
    bm = tree.bm
    lb_val = tree.V_lower[b_idx]
    ub_val = tree.V_upper[b_idx]

    first_key = bm.bin_levels_nodes[1][b_idx].key
    first_ub_idx, first_entropy_idx = first_key
    if bm.bin_levels[1].bin_count[first_ub_idx, first_entropy_idx] == 1
        return ub_val
    end

    smallest_error = Inf
    best_level = 0
    best_key = first_key
    for level_i in 1:bm.num_levels
        key = bm.bin_levels_nodes[level_i][b_idx].key
        ub_interval_idx, entropy_interval_idx = key
        level_error = bm.bin_levels[level_i].bin_error[ub_interval_idx, entropy_interval_idx]
        # Later levels must beat the incumbent by more than the tolerance.
        if level_error + 1e-10 < smallest_error
            best_level = level_i
            smallest_error = level_error
            best_key = key
        end
    end

    # NOTE(review): if no level has a finite error, `best_level` is still 0 here and the
    # lookup below throws a BoundsError.
    best_ub_interval_idx, best_entropy_interval_idx = best_key
    best_value = bm.bin_levels[best_level].bin_value[
        best_ub_interval_idx, best_entropy_interval_idx
    ]

    if best_value > ub_val + 1e-10
        return ub_val
    elseif best_value + 1e-10 < lb_val
        return lb_val
    else
        return best_value
    end
end
435+
436+
"""
    max_entropy(n::Integer)

Return the entropy (natural log) of the uniform distribution over `n` outcomes,
computed as `-n * (1/n) * log(1/n)`.

# Throws
- `ArgumentError` if `n` is not positive.
"""
function max_entropy(n::Integer)
    n > 0 || throw(ArgumentError("n must be positive, got $n"))
    p = 1.0 / n
    # Kept in this exact form so results stay bit-identical to earlier versions.
    return -1 * (p * log(p)) * n
end
439+
440+
"""
    entropy(b::AbstractVector)
    entropy(b::SparseVector)

Return the entropy `-Σ bᵢ log(bᵢ)` (natural log) of `b`.

Entries that are not strictly positive contribute nothing. For a `SparseVector` only
the stored entries are visited.
"""
function entropy(b::AbstractVector)
    ent = 0.0
    for p in b
        if p > 0
            ent -= p * log(p)
        end
    end
    return ent
end

# A sparse vector may hold explicitly stored zeros. Going through the dense method
# applies the same `> 0` guard to the stored values, so such an entry no longer turns
# the result into NaN via `0 * log(0)`.
entropy(b::SparseVector) = entropy(b.nzval)
455+
456+
"""
    get_interval_idx(value, lower, interval, num_intervals)

Return the 1-based index of the width-`interval` bin, counted from `lower`, that
contains `value`, limited to `1:num_intervals`.

A zero `interval` always maps to bin 1. Values below `lower` map to bin 1 and values
beyond the last bin map to `num_intervals`, including infinite or very large quotients.
A `NaN` quotient still throws an `InexactError`.
"""
function get_interval_idx(value::Real, lower::Real, interval::Real, num_intervals::Integer)
    iszero(interval) && return 1
    pos = floor((value - lower) / interval) + 1
    # Clamp while still in floating point: converting first would throw for quotients
    # that do not fit in an Int, although they clearly belong to the last bin.
    return Int(clamp(pos, 1, num_intervals))
end

0 commit comments

Comments
 (0)