Skip to content

Commit da290e7

Browse files
Prune Changes (#19)
* Added pruning of alpha vectors for strictly dominated vectors at each call to `prune` * Change to not update Gamma size * Update to `prune_alpha!` * function name update * formatting update * Added tests for prune functions * Update src/prune.jl Co-authored-by: Tyler Becker <[email protected]> * Update src/prune.jl Co-authored-by: Tyler Becker <[email protected]> * Updated `prune_strictly_dominated!` to reduce allocations --------- Co-authored-by: Tyler Becker <[email protected]>
1 parent a7178b7 commit da290e7

File tree

3 files changed

+127
-45
lines changed

3 files changed

+127
-45
lines changed

src/prune.jl

Lines changed: 106 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ end
55

66
function prune!(solver::SARSOPSolver, tree::SARSOPTree)
77
prune!(tree)
8+
prune_strictly_dominated!(tree::SARSOPTree)
89
if should_prune_alphas(tree)
910
prune_alpha!(tree, solver.delta)
1011
end
@@ -48,60 +49,120 @@ function prune!(tree::SARSOPTree)
4849
end
4950
end
5051

51-
function belief_space_domination(α1, α2, B, δ)
52-
a1_dominant = true
53-
a2_dominant = true
54-
for b ∈ B
55-
!a1_dominant && !a2_dominant && return (false, false)
56-
δV = intersection_distance(α1, α2, b)
57-
δV ≤ δ && (a1_dominant = false)
58-
δV ≥ -δ && (a2_dominant = false)
59-
end
60-
return a1_dominant, a2_dominant
61-
end
62-
63-
@inline function intersection_distance(α1, α2, b)
64-
s = 0.0
52+
function intersection_distance(α1, α2, b)
6553
dot_sum = 0.0
66-
I,B = b.nzind, b.nzval
67-
@inbounds for _i ∈ eachindex(I)
54+
I, B = b.nzind, b.nzval
55+
for _i ∈ eachindex(I)
6856
i = I[_i]
69-
diff = α1[i] - α2[i]
70-
s += abs2(diff)
71-
dot_sum += diff*B[_i]
57+
dot_sum += (α1[i] - α2[i]) * B[_i]
58+
end
59+
s = 0.0
60+
for i ∈ eachindex(α1, α2)
61+
s += (α1[i] - α2[i])^2
7262
end
7363
return dot_sum / sqrt(s)
7464
end
7565

76-
function prune_alpha!(tree::SARSOPTree, δ)
66+
function prune_alpha!(tree::SARSOPTree, δ, eps=0.0)
7767
Γ = tree.Γ
78-
B_valid = tree.b[map(!,tree.b_pruned)]
79-
pruned = falses(length(Γ))
80-
81-
# checking if α_i dominates α_j
82-
for (i,α_i) ∈ enumerate(Γ)
83-
pruned[i] && continue
84-
for (j,α_j) ∈ enumerate(Γ)
85-
(j ≤ i || pruned[j]) && continue
86-
a1_dominant,a2_dominant = belief_space_domination(α_i, α_j, B_valid, δ)
87-
#=
88-
NOTE: α1 and α2 shouldn't technically be able to mutually dominate
89-
i.e. a1_dominant and a2_dominant should never both be true.
90-
But this does happen when α1 == α2 because intersection_distance returns NaN.
91-
Current impl prunes α2 without doing an equality check, removing
92-
the duplicate α. Could do equality check to short-circuit
93-
belief_space_domination which would speed things up if we have
94-
a lot of duplicates, but the equality check can slow things down
95-
if α's are sufficiently diverse.
96-
=#
97-
if a1_dominant
98-
pruned[j] = true
99-
elseif a2_dominant
100-
pruned[i] = true
101-
break
68+
B_valid = tree.b[map(!, tree.b_pruned)]
69+
70+
n_Γ = length(Γ)
71+
n_B = length(B_valid)
72+
73+
dominant_indices_bools = falses(n_Γ)
74+
dominant_vector_indices = Vector{Int}(undef, n_B)
75+
76+
# First, identify dominant alpha vectors
77+
for b_idx in 1:n_B
78+
max_value = -Inf
79+
max_index = -1
80+
for i in 1:n_Γ
81+
value = dot(Γ[i], B_valid[b_idx])
82+
if value > max_value
83+
max_value = value
84+
max_index = i
10285
end
10386
end
87+
dominant_indices_bools[max_index] = true
88+
dominant_vector_indices[b_idx] = max_index
10489
end
105-
deleteat!(Γ, pruned)
90+
91+
non_dominant_indices = findall(!, dominant_indices_bools)
92+
n_non_dom = length(non_dominant_indices)
93+
keep_non_dom = falses(n_non_dom)
94+
95+
for b_idx in 1:n_B
96+
dom_vec_idx = dominant_vector_indices[b_idx]
97+
for j in 1:n_non_dom
98+
non_dom_idx = non_dominant_indices[j]
99+
if keep_non_dom[j]
100+
continue
101+
end
102+
intx_dist = intersection_distance(Γ[dom_vec_idx], Γ[non_dom_idx], B_valid[b_idx])
103+
if !isnan(intx_dist) && (intx_dist + eps ≤ δ)
104+
keep_non_dom[j] = true
105+
end
106+
end
107+
end
108+
109+
non_dominant_indices = non_dominant_indices[.!keep_non_dom]
110+
deleteat!(Γ, non_dominant_indices)
106111
tree.prune_data.last_Γ_size = length(Γ)
107112
end
113+
114+
function strictly_dominates(α1, α2, eps)
115+
for ii in 1:length(α1)
116+
if α1[ii] < α2[ii] - eps
117+
return false
118+
end
119+
end
120+
return true
121+
end
122+
123+
function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10)
124+
Γ = tree.Γ
125+
Γ_new_idxs = Vector{Int}(undef, length(Γ))
126+
keep = trues(length(Γ))
127+
128+
idx_count = 0
129+
for (α_try_idx, α_try) in enumerate(Γ)
130+
dominated = false
131+
for jj in 1:idx_count
132+
α_in_idx = Γ_new_idxs[jj]
133+
α_in = Γ[α_in_idx]
134+
if strictly_dominates(α_try, α_in, eps)
135+
keep[jj] = false
136+
elseif strictly_dominates(α_in, α_try, eps)
137+
dominated = true
138+
break
139+
end
140+
end
141+
if !dominated
142+
new_idx_count = 0
143+
for jj in 1:idx_count
144+
if keep[jj]
145+
new_idx_count += 1
146+
Γ_new_idxs[new_idx_count] = Γ_new_idxs[jj]
147+
end
148+
end
149+
new_idx_count += 1
150+
Γ_new_idxs[new_idx_count] = α_try_idx
151+
idx_count = new_idx_count
152+
fill!(keep, true)
153+
end
154+
end
155+
156+
resize!(Γ_new_idxs, idx_count)
157+
158+
to_delete = trues(length(Γ))
159+
for idx in Γ_new_idxs
160+
to_delete[idx] = false
161+
end
162+
163+
for ii in length(Γ):-1:1
164+
if to_delete[ii]
165+
deleteat!(Γ, ii)
166+
end
167+
end
168+
end

test/prune.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
@testset "prune" begin
2+
# NativeSARSOP.strictly_dominates
3+
a1 = [1.0, 2.0, 3.0]
4+
a2 = [1.0, 2.1, 2.9]
5+
a3 = [0.9, 1.9, 2.9]
6+
@test !NativeSARSOP.strictly_dominates(a1, a2, 1e-10)
7+
@test NativeSARSOP.strictly_dominates(a1, a1, 1e-10)
8+
@test NativeSARSOP.strictly_dominates(a1, a3, 1e-10)
9+
10+
# NativeSARSOP.intersection_distance
11+
b = SparseVector([1.0, 0.0])
12+
a1 = [1.0, 0.0]
13+
a2 = [0.0, 1.0]
14+
@test isapprox(NativeSARSOP.intersection_distance(a1, a2, b),
15+
sqrt(0.5^2 + 0.5^2), atol=1e-10)
16+
17+
b = SparseVector([0.5, 0.5])
18+
@test isapprox(NativeSARSOP.intersection_distance(a1, a2, b), 0.0, atol=1e-10)
19+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ include("sample.jl")
2727

2828
include("updater.jl")
2929

30+
include("prune.jl")
31+
3032
include("tree.jl")
3133

3234
@testset "Tiger POMDP" begin

0 commit comments

Comments
 (0)