Skip to content

Commit de38539

Browse files
committed
feat: add wiener filter
1 parent 7397ac6 commit de38539

File tree

2 files changed

+202
-0
lines changed

2 files changed

+202
-0
lines changed

lib/nx_signal/filters.ex

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ defmodule NxSignal.Filters do
33
Common filter functions.
44
"""
55
import Nx.Defn
6+
import NxSignal.Convolution
67

78
@doc ~S"""
89
Performs a median filter on a tensor.
@@ -52,4 +53,80 @@ defmodule NxSignal.Filters do
5253
end
5354

5455
deftransformp kernel_lengths(kernel_shape), do: Tuple.to_list(kernel_shape)
56+
57+
@doc """
58+
Applies a Wiener filter to the given Nx tensor.
59+
60+
## Options
61+
* `:kernel_size` - filter size (scalar or tuple). Defaults to `3`.
62+
* `:noise` - noise power, auto-estimated if `nil`. Defaults to `nil`.
63+
64+
## Examples
65+
66+
iex> t = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
67+
iex> NxSignal.Filters.wiener(t, kernel_size: {2, 2}, noise: 10)
68+
#Nx.Tensor<
69+
f32[3][3]
70+
[
71+
[0.25, 0.75, 1.25],
72+
[1.25, 3.0, 4.0],
73+
[2.75, 6.0, 7.0]
74+
]
75+
>
76+
"""
77+
@doc type: :filters
78+
deftransform wiener(t, opts \\ []) do
79+
# Validate and extract options
80+
opts = Keyword.validate!(opts, noise: nil, kernel_size: 3)
81+
82+
rank = Nx.rank(t)
83+
kernel_size = Keyword.fetch!(opts, :kernel_size)
84+
noise = Keyword.fetch!(opts, :noise)
85+
86+
# Ensure `kernel_size` is a tuple
87+
kernel_size =
88+
cond do
89+
is_integer(kernel_size) -> Tuple.duplicate(kernel_size, rank)
90+
is_tuple(kernel_size) -> kernel_size
91+
true -> raise ArgumentError, "kernel_size must be an integer or tuple"
92+
end
93+
94+
# Convert `nil` noise to `0.0` so it's always a valid tensor
95+
noise_t = if is_nil(noise), do: Nx.tensor(0.0), else: Nx.tensor(noise)
96+
97+
# Compute filter window size
98+
size = Tuple.to_list(kernel_size) |> Enum.reduce(1, &*/2)
99+
100+
# Ensure the kernel is the same size as the filter window
101+
kernel = Nx.broadcast(1.0, kernel_size)
102+
103+
t
104+
|> Nx.as_type(:f64)
105+
|> wiener_n(kernel, noise_t, calculate_noise: is_nil(noise), size: size)
106+
|> Nx.as_type(Nx.type(t))
107+
end
108+
109+
defnp wiener_n(t, kernel, noise, opts) do
110+
size = opts[:size]
111+
112+
# Compute local mean using "same" mode in correlation
113+
l_mean = correlate(t, kernel, mode: :same) / size
114+
115+
# Compute local variance
116+
l_var =
117+
correlate(t ** 2, kernel, mode: :same)
118+
|> Nx.divide(size)
119+
|> Nx.subtract(l_mean ** 2)
120+
121+
# Ensure `noise` is a tensor to avoid `nil` issues in `defnp`
122+
noise =
123+
case opts[:calculate_noise] do
124+
true -> Nx.mean(l_var)
125+
false -> noise
126+
end
127+
128+
# Apply Wiener filter formula
129+
res = (t - l_mean) * (1 - noise / l_var)
130+
Nx.select(l_var < noise, l_mean, res + l_mean)
131+
end
55132
end

test/nx_signal/filters_test.exs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,129 @@ defmodule NxSignal.FiltersTest do
116116
)
117117
end
118118
end
119+
120+
describe "wiener/2" do
121+
test "performs n-dim wiener filter with calculated noise" do
122+
im =
123+
Nx.tensor(
124+
[
125+
[1.0, 2.0, 3.0, 4.0, 5.0],
126+
[6.0, 7.0, 8.0, 9.0, 10.0],
127+
[11.0, 12.0, 13.0, 14.0, 15.0]
128+
],
129+
type: :f64
130+
)
131+
132+
kernel_size = {3, 3}
133+
134+
expected =
135+
Nx.tensor(
136+
[
137+
[
138+
1.7777777777777777,
139+
3.0,
140+
3.6666666666666665,
141+
4.333333333333333,
142+
3.111111111111111
143+
],
144+
[4.3366520642506305, 7.0, 8.0, 9.0, 7.58637597408283],
145+
[
146+
4.692197051420351,
147+
7.261706150595039,
148+
8.748939779474131,
149+
10.157992415073023,
150+
9.813815742524799
151+
]
152+
],
153+
type: :f64
154+
)
155+
156+
assert NxSignal.Filters.wiener(im, kernel_size: kernel_size) == expected
157+
assert NxSignal.Filters.wiener(im, kernel_size: 3) == expected
158+
159+
assert NxSignal.Filters.wiener(Nx.as_type(im, :f32), kernel_size: kernel_size) ==
160+
Nx.tensor([
161+
[
162+
1.7777777910232544,
163+
3.0,
164+
3.6666667461395264,
165+
4.333333492279053,
166+
3.1111111640930176
167+
],
168+
[4.3366522789001465, 7.0, 8.0, 9.0, 7.586376190185547],
169+
[
170+
4.692196846008301,
171+
7.261706352233887,
172+
8.748939514160156,
173+
10.157992362976074,
174+
9.81381607055664
175+
]
176+
])
177+
end
178+
179+
test "performs n-dim wiener filter with parameterized noise" do
180+
im =
181+
Nx.tensor(
182+
[
183+
[1.0, 2.0, 3.0, 4.0, 5.0],
184+
[6.0, 7.0, 8.0, 9.0, 10.0],
185+
[11.0, 12.0, 13.0, 14.0, 15.0]
186+
],
187+
type: :f64
188+
)
189+
190+
kernel_size = {3, 3}
191+
192+
assert NxSignal.Filters.wiener(im, kernel_size: kernel_size, noise: 10) ==
193+
Nx.tensor(
194+
[
195+
[
196+
1.7777777777777777,
197+
3.0,
198+
3.5882352941176467,
199+
4.238095238095238,
200+
3.7397034596375622
201+
],
202+
[5.193548387096774, 7.0, 8.0, 9.0, 8.829787234042554],
203+
[
204+
7.941747572815534,
205+
9.702702702702702,
206+
10.938931297709924,
207+
12.137254901960784,
208+
12.485549132947977
209+
]
210+
],
211+
type: :f64
212+
)
213+
214+
assert NxSignal.Filters.wiener(Nx.as_type(im, :f32), kernel_size: kernel_size, noise: 10) ==
215+
Nx.tensor([
216+
[
217+
1.7777777910232544,
218+
3.0,
219+
3.588235378265381,
220+
4.238095283508301,
221+
3.739703416824341
222+
],
223+
[5.193548202514648, 7.0, 8.0, 9.0, 8.829787254333496],
224+
[
225+
7.941747665405273,
226+
9.702702522277832,
227+
10.938931465148926,
228+
12.13725471496582,
229+
12.485548973083496
230+
]
231+
])
232+
233+
assert NxSignal.Filters.wiener(im, kernel_size: kernel_size, noise: 0) ==
234+
Nx.tensor(
235+
[
236+
[1.0, 2.0, 3.0, 4.0, 5.0],
237+
[6.0, 7.0, 8.0, 9.0, 10.0],
238+
[11.0, 12.0, 13.0, 14.0, 15.0]
239+
],
240+
type: :f64
241+
)
242+
end
243+
end
119244
end

0 commit comments

Comments
 (0)