|
5 | 5 |
|
6 | 6 | function prune!(solver::SARSOPSolver, tree::SARSOPTree)
|
7 | 7 | prune!(tree)
|
| 8 | + prune_strictly_dominated!(tree::SARSOPTree) |
8 | 9 | if should_prune_alphas(tree)
|
9 | 10 | prune_alpha!(tree, solver.delta)
|
10 | 11 | end
|
@@ -48,60 +49,120 @@ function prune!(tree::SARSOPTree)
|
48 | 49 | end
|
49 | 50 | end
|
50 | 51 |
|
51 |
| -function belief_space_domination(α1, α2, B, δ) |
52 |
| - a1_dominant = true |
53 |
| - a2_dominant = true |
54 |
| - for b ∈ B |
55 |
| - !a1_dominant && !a2_dominant && return (false, false) |
56 |
| - δV = intersection_distance(α1, α2, b) |
57 |
| - δV ≤ δ && (a1_dominant = false) |
58 |
| - δV ≥ -δ && (a2_dominant = false) |
59 |
| - end |
60 |
| - return a1_dominant, a2_dominant |
61 |
| -end |
62 |
| - |
63 |
| -@inline function intersection_distance(α1, α2, b) |
64 |
| - s = 0.0 |
| 52 | +function intersection_distance(α1, α2, b) |
65 | 53 | dot_sum = 0.0
|
66 |
| - I,B = b.nzind, b.nzval |
67 |
| - @inbounds for _i ∈ eachindex(I) |
| 54 | + I, B = b.nzind, b.nzval |
| 55 | + for _i ∈ eachindex(I) |
68 | 56 | i = I[_i]
|
69 |
| - diff = α1[i] - α2[i] |
70 |
| - s += abs2(diff) |
71 |
| - dot_sum += diff*B[_i] |
| 57 | + dot_sum += (α1[i] - α2[i]) * B[_i] |
| 58 | + end |
| 59 | + s = 0.0 |
| 60 | + for i ∈ eachindex(α1, α2) |
| 61 | + s += (α1[i] - α2[i])^2 |
72 | 62 | end
|
73 | 63 | return dot_sum / sqrt(s)
|
74 | 64 | end
|
75 | 65 |
|
76 |
| -function prune_alpha!(tree::SARSOPTree, δ) |
| 66 | +function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) |
77 | 67 | Γ = tree.Γ
|
78 |
| - B_valid = tree.b[map(!,tree.b_pruned)] |
79 |
| - pruned = falses(length(Γ)) |
80 |
| - |
81 |
| - # checking if α_i dominates α_j |
82 |
| - for (i,α_i) ∈ enumerate(Γ) |
83 |
| - pruned[i] && continue |
84 |
| - for (j,α_j) ∈ enumerate(Γ) |
85 |
| - (j ≤ i || pruned[j]) && continue |
86 |
| - a1_dominant,a2_dominant = belief_space_domination(α_i, α_j, B_valid, δ) |
87 |
| - #= |
88 |
| - NOTE: α1 and α2 shouldn't technically be able to mutually dominate |
89 |
| - i.e. a1_dominant and a2_dominant should never both be true. |
90 |
| - But this does happen when α1 == α2 because intersection_distance returns NaN. |
91 |
| - Current impl prunes α2 without doing an equality check, removing |
92 |
| - the duplicate α. Could do equality check to short-circuit |
93 |
| - belief_space_domination which would speed things up if we have |
94 |
| - a lot of duplicates, but the equality check can slow things down |
95 |
| - if α's are sufficiently diverse. |
96 |
| - =# |
97 |
| - if a1_dominant |
98 |
| - pruned[j] = true |
99 |
| - elseif a2_dominant |
100 |
| - pruned[i] = true |
101 |
| - break |
| 68 | + B_valid = tree.b[map(!, tree.b_pruned)] |
| 69 | + |
| 70 | + n_Γ = length(Γ) |
| 71 | + n_B = length(B_valid) |
| 72 | + |
| 73 | + dominant_indices_bools = falses(n_Γ) |
| 74 | + dominant_vector_indices = Vector{Int}(undef, n_B) |
| 75 | + |
| 76 | + # First, identify dominant alpha vectors |
| 77 | + for b_idx in 1:n_B |
| 78 | + max_value = -Inf |
| 79 | + max_index = -1 |
| 80 | + for i in 1:n_Γ |
| 81 | + value = dot(Γ[i], B_valid[b_idx]) |
| 82 | + if value > max_value |
| 83 | + max_value = value |
| 84 | + max_index = i |
102 | 85 | end
|
103 | 86 | end
|
| 87 | + dominant_indices_bools[max_index] = true |
| 88 | + dominant_vector_indices[b_idx] = max_index |
104 | 89 | end
|
105 |
| - deleteat!(Γ, pruned) |
| 90 | + |
| 91 | + non_dominant_indices = findall(!, dominant_indices_bools) |
| 92 | + n_non_dom = length(non_dominant_indices) |
| 93 | + keep_non_dom = falses(n_non_dom) |
| 94 | + |
| 95 | + for b_idx in 1:n_B |
| 96 | + dom_vec_idx = dominant_vector_indices[b_idx] |
| 97 | + for j in 1:n_non_dom |
| 98 | + non_dom_idx = non_dominant_indices[j] |
| 99 | + if keep_non_dom[j] |
| 100 | + continue |
| 101 | + end |
| 102 | + intx_dist = intersection_distance(Γ[dom_vec_idx], Γ[non_dom_idx], B_valid[b_idx]) |
| 103 | + if !isnan(intx_dist) && (intx_dist + eps ≤ δ) |
| 104 | + keep_non_dom[j] = true |
| 105 | + end |
| 106 | + end |
| 107 | + end |
| 108 | + |
| 109 | + non_dominant_indices = non_dominant_indices[.!keep_non_dom] |
| 110 | + deleteat!(Γ, non_dominant_indices) |
106 | 111 | tree.prune_data.last_Γ_size = length(Γ)
|
107 | 112 | end
|
| 113 | + |
| 114 | +function strictly_dominates(α1, α2, eps) |
| 115 | + for ii in 1:length(α1) |
| 116 | + if α1[ii] < α2[ii] - eps |
| 117 | + return false |
| 118 | + end |
| 119 | + end |
| 120 | + return true |
| 121 | +end |
| 122 | + |
| 123 | +function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10) |
| 124 | + Γ = tree.Γ |
| 125 | + Γ_new_idxs = Vector{Int}(undef, length(Γ)) |
| 126 | + keep = trues(length(Γ)) |
| 127 | + |
| 128 | + idx_count = 0 |
| 129 | + for (α_try_idx, α_try) in enumerate(Γ) |
| 130 | + dominated = false |
| 131 | + for jj in 1:idx_count |
| 132 | + α_in_idx = Γ_new_idxs[jj] |
| 133 | + α_in = Γ[α_in_idx] |
| 134 | + if strictly_dominates(α_try, α_in, eps) |
| 135 | + keep[jj] = false |
| 136 | + elseif strictly_dominates(α_in, α_try, eps) |
| 137 | + dominated = true |
| 138 | + break |
| 139 | + end |
| 140 | + end |
| 141 | + if !dominated |
| 142 | + new_idx_count = 0 |
| 143 | + for jj in 1:idx_count |
| 144 | + if keep[jj] |
| 145 | + new_idx_count += 1 |
| 146 | + Γ_new_idxs[new_idx_count] = Γ_new_idxs[jj] |
| 147 | + end |
| 148 | + end |
| 149 | + new_idx_count += 1 |
| 150 | + Γ_new_idxs[new_idx_count] = α_try_idx |
| 151 | + idx_count = new_idx_count |
| 152 | + fill!(keep, true) |
| 153 | + end |
| 154 | + end |
| 155 | + |
| 156 | + resize!(Γ_new_idxs, idx_count) |
| 157 | + |
| 158 | + to_delete = trues(length(Γ)) |
| 159 | + for idx in Γ_new_idxs |
| 160 | + to_delete[idx] = false |
| 161 | + end |
| 162 | + |
| 163 | + for ii in length(Γ):-1:1 |
| 164 | + if to_delete[ii] |
| 165 | + deleteat!(Γ, ii) |
| 166 | + end |
| 167 | + end |
| 168 | +end |
0 commit comments