@@ -4,6 +4,72 @@ mutable struct PruneData
4
4
prune_threshold:: Float64
5
5
end
6
6
7
"""
    BinData

Storage for one level of the bin grid. Every matrix is indexed as
`[ub_interval_idx, entropy_interval_idx]`.
"""
struct BinData
    # Running mean of the lower-bound values recorded in each bin.
    bin_value::Matrix{Float64}
    # Number of belief nodes recorded in each bin.
    bin_count::Matrix{Int}
    # Squared-error total tracked for each bin.
    bin_error::Matrix{Float64}
end
12
+
13
"""
    BinNode

Bookkeeping record for one belief node at one bin level.
"""
struct BinNode
    # Bin coordinates: (ub_interval_idx, entropy_interval_idx).
    key::Tuple{Int,Int}
    # Squared error this node most recently contributed to its bin.
    prev_error::Float64
end
17
+
18
"""
    BinManager

State for the multi-level binning of belief nodes by upper-bound value and
belief entropy.
"""
struct BinManager
    # Smallest value in the upper-bound vector the manager was built from.
    lowest_ub::Float64
    # Number of bin levels.
    num_levels::Int
    # Bins per axis at each level.
    num_bins_per_level::Vector{Int}
    # [level] => (ub = interval width, entropy = interval width)
    bin_levels_intervals::Vector{NamedTuple{(:ub, :entropy),Tuple{Float64,Float64}}}
    # [level][b_idx] => BinNode
    bin_levels_nodes::Vector{Dict{Int,BinNode}}
    # [level] => BinData
    bin_levels::Vector{BinData}
    # b_idx => lower bound recorded at the previous bin update
    previous_lowerbound::Dict{Int,Float64}
end
27
+
28
"""
    BinManager(Vs_upper::Vector{Float64}, num_bins_per_level=[5, 10])

Build an empty `BinManager` with one bin grid per entry of `num_bins_per_level`.

At a level with `n` bins, the upper-bound axis is split into `n` intervals of
width `(maximum(Vs_upper) - minimum(Vs_upper)) / n`, and the entropy axis into
`n` intervals of width `max_entropy(length(Vs_upper)) / n`. All bin matrices
start zeroed and no nodes are registered.
"""
function BinManager(Vs_upper::Vector{Float64}, num_bins_per_level=[5, 10])
    num_levels = length(num_bins_per_level)
    lowest_ub = minimum(Vs_upper)
    ub_span = maximum(Vs_upper) - lowest_ub
    entropy_span = max_entropy(length(Vs_upper))

    # [level][:ub|:entropy] => interval width
    bin_levels_intervals = NamedTuple{(:ub, :entropy),Tuple{Float64,Float64}}[
        (ub=ub_span / n, entropy=entropy_span / n) for n in num_bins_per_level
    ]

    # [level][b_idx] => BinNode(key, prev_error)
    bin_levels_nodes = Dict{Int,BinNode}[Dict{Int,BinNode}() for _ in 1:num_levels]

    # [level] => BinData holding n-by-n value / count / error matrices
    bin_levels = BinData[
        BinData(zeros(Float64, n, n), zeros(Int, n, n), zeros(Float64, n, n))
        for n in num_bins_per_level
    ]

    # b_idx => lower bound
    previous_lowerbound = Dict{Int,Float64}()

    return BinManager(
        lowest_ub,
        num_levels,
        num_bins_per_level,
        bin_levels_intervals,
        bin_levels_nodes,
        bin_levels,
        previous_lowerbound,
    )
end
72
+
7
73
struct SARSOPTree
8
74
pomdp:: ModifiedSparseTabular
9
75
@@ -20,7 +86,7 @@ struct SARSOPTree
20
86
21
87
_discount:: Float64
22
88
is_terminal:: BitVector
23
- is_terminal_s:: SparseVector{Bool, Int}
89
+ is_terminal_s:: SparseVector{Bool,Int}
24
90
25
91
# do we need both b_pruned and ba_pruned? b_pruned might be enough
26
92
sampled:: Vector{Int} # b_idx
@@ -32,20 +98,23 @@ struct SARSOPTree
32
98
prune_data:: PruneData
33
99
34
100
Γ:: Vector{AlphaVec{Int}}
101
+
102
+ use_binning:: Bool
103
+ bm:: BinManager
35
104
end
36
105
37
106
38
- function SARSOPTree (solver, pomdp:: POMDP )
107
+ function SARSOPTree (solver, pomdp:: POMDP ; num_bins_per_level = [ 5 , 10 ] )
39
108
sparse_pomdp = ModifiedSparseTabular (pomdp)
40
109
cache = TreeCache (sparse_pomdp)
41
110
42
111
upper_policy = solve (solver. init_upper, sparse_pomdp)
43
112
corner_values = map (maximum, zip (upper_policy. alphas... ))
44
113
45
- tree = SARSOPTree (
46
- sparse_pomdp,
114
+ bin_manager = BinManager (corner_values, num_bins_per_level)
47
115
48
- Vector{Float64}[],
116
+ tree = SARSOPTree (
117
+ sparse_pomdp, Vector{Float64}[],
49
118
Vector{Int}[],
50
119
corner_values, # upper_policy.util,
51
120
Float64[],
@@ -63,8 +132,10 @@ function SARSOPTree(solver, pomdp::POMDP)
63
132
Vector {Int} (),
64
133
BitVector (),
65
134
cache,
66
- PruneData (0 ,0 ,solver. prunethresh),
67
- AlphaVec{Int}[]
135
+ PruneData (0 , 0 , solver. prunethresh),
136
+ AlphaVec{Int}[],
137
+ solver. use_binning,
138
+ bin_manager
68
139
)
69
140
return insert_root! (solver, tree, _initialize_belief (pomdp, initialstate (pomdp)))
70
141
end
@@ -93,7 +164,7 @@ function insert_root!(solver, tree::SARSOPTree, b)
93
164
pomdp = tree. pomdp
94
165
95
166
Γ_lower = solve (solver. init_lower, pomdp)
96
- for (α,a) ∈ alphapairs (Γ_lower)
167
+ for (α, a) ∈ alphapairs (Γ_lower)
97
168
new_val = dot (α, b)
98
169
push! (tree. Γ, AlphaVec (α, a))
99
170
end
@@ -118,7 +189,7 @@ function update(tree::SARSOPTree, b_idx::Int, a, o)
118
189
ba_idx = tree. b_children[b_idx][a]
119
190
bp_idx = tree. ba_children[ba_idx][o]
120
191
V̲, V̄ = if tree. is_terminal[bp_idx]
121
- 0. , 0.
192
+ 0.0 , 0.0
122
193
else
123
194
lower_value (tree, tree. b[bp_idx]), upper_value (tree, tree. b[bp_idx])
124
195
end
@@ -139,7 +210,7 @@ function add_belief!(tree::SARSOPTree, b, ba_idx::Int, o)
139
210
push! (tree. is_terminal, terminal)
140
211
141
212
V̲, V̄ = if terminal
142
- 0. , 0.
213
+ 0.0 , 0.0
143
214
else
144
215
lower_value (tree, b), upper_value (tree, b)
145
216
end
@@ -163,6 +234,9 @@ function fill_belief!(tree::SARSOPTree, b_idx::Int)
163
234
else
164
235
fill_populated! (tree, b_idx)
165
236
end
237
+ if tree. use_binning
238
+ update_bin_node! (tree, b_idx)
239
+ end
166
240
end
167
241
168
242
"""
@@ -186,8 +260,8 @@ function fill_populated!(tree::SARSOPTree, b_idx::Int)
186
260
bp_idx, V̲, V̄ = update (tree, b_idx, a, o)
187
261
b′ = tree. b[bp_idx]
188
262
po = tree. poba[ba_idx][o]
189
- Q̄ += γ* po * V̄
190
- Q̲ += γ* po * V̲
263
+ Q̄ += γ * po * V̄
264
+ Q̲ += γ * po * V̲
191
265
end
192
266
193
267
Qa_upper[a] = Q̄
@@ -219,7 +293,7 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int)
219
293
tree. ba_children[ba_idx] = ba_children
220
294
221
295
n_b += N_OBS
222
- pred = dropzeros! (mul! (tree. cache. pred, pomdp. T[a],b))
296
+ pred = dropzeros! (mul! (tree. cache. pred, pomdp. T[a], b))
223
297
poba = zeros (Float64, N_OBS)
224
298
Rba = belief_reward (tree, b, a)
225
299
@@ -230,15 +304,15 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int)
230
304
# belief update
231
305
bp = corrector (pomdp, pred, a, o)
232
306
po = sum (bp)
233
- if po > 0.
307
+ if po > 0.0
234
308
bp. nzval ./= po
235
309
poba[o] = po
236
310
end
237
311
238
312
bp_idx, V̲, V̄ = add_belief! (tree, bp, ba_idx, o)
239
313
240
- Q̄ += γ* po * V̄
241
- Q̲ += γ* po * V̲
314
+ Q̄ += γ * po * V̄
315
+ Q̲ += γ * po * V̲
242
316
end
243
317
Qa_upper[a] = Q̄
244
318
Qa_lower[a] = Q̲
@@ -249,3 +323,140 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int)
249
323
tree. V_lower[b_idx] = lower_value (tree, tree. b[b_idx])
250
324
tree. V_upper[b_idx] = maximum (tree. Qa_upper[b_idx])
251
325
end
326
+
327
"""
    initialize_bin_node!(tree::SARSOPTree, b_idx::Int)

Register belief `b_idx` in the bin grid of every level.

At each level the bin is chosen from the node's upper bound and belief entropy.
If the bin already holds nodes, the node's lower bound is folded into the bin's
running mean and its squared deviation from the previous mean is added to the
bin error. Otherwise the bin is seeded with the lower bound and the squared
upper/lower gap. The lower bound is stored in `previous_lowerbound`.
"""
function initialize_bin_node!(tree::SARSOPTree, b_idx::Int)
    bm = tree.bm
    lower = tree.V_lower[b_idx]
    upper = tree.V_upper[b_idx]
    belief_entropy = entropy(tree.b[b_idx])

    for lvl in 1:bm.num_levels
        widths = bm.bin_levels_intervals[lvl]
        n_bins = bm.num_bins_per_level[lvl]
        row = get_interval_idx(upper, bm.lowest_ub, widths[:ub], n_bins)
        col = get_interval_idx(belief_entropy, 0.0, widths[:entropy], n_bins)

        bins = bm.bin_levels[lvl]
        seen = bins.bin_count[row, col]
        if seen > 0
            # Occupied bin: measure against the current mean, then update it.
            old_mean = bins.bin_value[row, col]
            delta = old_mean - lower
            sq_err = delta * delta
            bins.bin_error[row, col] += sq_err
            new_mean = (old_mean * seen + lower) / (seen + 1)
            bins.bin_count[row, col] += 1
        else
            # Empty bin: seed it with this node's bound gap and lower bound.
            delta = upper - lower
            sq_err = delta * delta
            bins.bin_error[row, col] = sq_err
            new_mean = lower
            bins.bin_count[row, col] = 1
        end
        bins.bin_value[row, col] = new_mean
        bm.bin_levels_nodes[lvl][b_idx] = BinNode((row, col), sq_err)
    end
    bm.previous_lowerbound[b_idx] = lower
end
366
+
367
"""
    update_bin_node!(tree::SARSOPTree, b_idx::Int)

Refresh the bin statistics of belief `b_idx` after its bounds changed.

A node not yet registered at level 1 is handed to `initialize_bin_node!`.
Otherwise, at every level, the node's previous squared-error contribution is
replaced by a new one and the bin mean is shifted by the change in the node's
lower bound. A bin holding a single node has its error reset to the squared
upper/lower gap.
"""
function update_bin_node!(tree::SARSOPTree, b_idx::Int)
    lower = tree.V_lower[b_idx]
    upper = tree.V_upper[b_idx]
    bm = tree.bm

    haskey(bm.bin_levels_nodes[1], b_idx) || return initialize_bin_node!(tree, b_idx)

    for lvl in 1:bm.num_levels
        old_node = bm.bin_levels_nodes[lvl][b_idx]
        row, col = old_node.key
        bins = bm.bin_levels[lvl]
        members = bins.bin_count[row, col]

        if members == 1
            delta = upper - lower
            sq_err = delta * delta
            bins.bin_error[row, col] = sq_err
        else
            delta = bins.bin_value[row, col] - lower
            # Swap this node's old contribution for the new one.
            bins.bin_error[row, col] -= old_node.prev_error
            sq_err = delta * delta
            bins.bin_error[row, col] += sq_err
        end

        bm.bin_levels_nodes[lvl][b_idx] = BinNode(old_node.key, sq_err)
        bins.bin_value[row, col] =
            (bins.bin_value[row, col] * members + lower - bm.previous_lowerbound[b_idx]) /
            members
    end
    bm.previous_lowerbound[b_idx] = lower
end
398
+
399
"""
    get_bin_value(tree::SARSOPTree, b_idx::Int)

Return the bin-based value estimate for belief `b_idx`.

If the node's level-1 bin holds exactly one node, the node's upper bound is
returned. Otherwise the level whose bin has the smallest error (ties within
`1e-10` keep the earlier level) supplies the bin value, which is limited to the
node's `[lower, upper]` range with the same `1e-10` tolerance.
"""
function get_bin_value(tree::SARSOPTree, b_idx::Int)
    lower = tree.V_lower[b_idx]
    upper = tree.V_upper[b_idx]
    bm = tree.bm

    first_key = bm.bin_levels_nodes[1][b_idx].key
    bm.bin_levels[1].bin_count[first_key[1], first_key[2]] == 1 && return upper

    min_error = Inf
    chosen_level = 0
    chosen_key = first_key
    for lvl in 1:bm.num_levels
        key = bm.bin_levels_nodes[lvl][b_idx].key
        level_error = bm.bin_levels[lvl].bin_error[key[1], key[2]]
        if level_error + 1e-10 < min_error
            chosen_level = lvl
            min_error = level_error
            chosen_key = key
        end
    end

    candidate = bm.bin_levels[chosen_level].bin_value[chosen_key[1], chosen_key[2]]
    candidate > upper + 1e-10 && return upper
    candidate + 1e-10 < lower && return lower
    return candidate
end
435
+
436
"""
    max_entropy(n::Int)

Return the entropy (natural log) of a uniform distribution over `n` outcomes,
computed as `-n * p * log(p)` with `p = 1 / n`.
"""
function max_entropy(n::Int)
    p = 1.0 / n
    return -1 * (p * log(p)) * n
end
439
+
440
"""
    entropy(b::AbstractVector)

Return the entropy `-Σ bᵢ log(bᵢ)` (natural log) of `b`, summing only over
strictly positive entries. An empty vector yields `0.0`.
"""
function entropy(b::AbstractVector)
    # Sequential left fold; non-positive entries leave the accumulator untouched.
    return foldl(b; init=0.0) do acc, p
        p > 0 ? acc - p * log(p) : acc
    end
end
447
+
448
"""
    entropy(b::SparseVector)

Return the entropy `-Σ bᵢ log(bᵢ)` (natural log) of the sparse belief `b`,
visiting only its stored entries.

Stored entries that are not strictly positive are skipped, matching the dense
`entropy(::AbstractVector)` method. A sparse vector may hold explicitly stored
zeros, and `0 * log(0)` would otherwise turn the whole sum into `NaN`.
"""
function entropy(b::SparseVector)
    ent = 0.0
    for b_i in b.nzval
        # Guard against explicitly stored zeros (and negatives).
        b_i > 0 && (ent -= b_i * log(b_i))
    end
    return ent
end
455
+
456
"""
    get_interval_idx(value::Float64, lower::Float64, interval::Float64, num_intervals::Int)

Return the 1-based index of the interval of width `interval`, counted from
`lower`, that contains `value`, clamped to `1:num_intervals`.

A zero-width `interval` always maps to index `1`. The clamp is applied in
floating point before the integer conversion, so a quotient too large for an
`Int` (for example with a tiny non-zero `interval`) clamps to the first or last
interval instead of raising `InexactError`. A `NaN` quotient still raises
`InexactError`.
"""
function get_interval_idx(value::Float64, lower::Float64, interval::Float64, num_intervals::Int)
    if interval == 0.0
        return 1
    end
    raw_idx = floor((value - lower) / interval) + 1
    # Clamp first: Int(x) throws for ±Inf and for magnitudes beyond the Int range.
    return Int(clamp(raw_idx, 1.0, Float64(num_intervals)))
end
0 commit comments