Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c167891

Browse files
TomasPegadopolvalente
andauthoredFeb 5, 2025··
feat: add wiener filter (#26)
* feat: add wiener filter * Apply suggestions from code review --------- Co-authored-by: Paulo Valente <[email protected]>
1 parent 7397ac6 commit c167891

File tree

2 files changed

+205
-0
lines changed

2 files changed

+205
-0
lines changed
 

‎lib/nx_signal/filters.ex

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ defmodule NxSignal.Filters do
33
Common filter functions.
44
"""
55
import Nx.Defn
6+
import NxSignal.Convolution
67

78
@doc ~S"""
89
Performs a median filter on a tensor.
@@ -52,4 +53,83 @@ defmodule NxSignal.Filters do
5253
end
5354

5455
deftransformp kernel_lengths(kernel_shape), do: Tuple.to_list(kernel_shape)
56+
57+
@doc """
58+
Applies a Wiener filter to the given Nx tensor.
59+
60+
## Options
61+
62+
* `:kernel_size` - filter size given either a number or a tuple.
63+
If a number is given, a kernel with the given size, and same number of axes
64+
as the input tensor will be used. Defaults to `3`.
65+
* `:noise` - noise power, given as a scalar. This will be estimated based on the input tensor if `nil`. Defaults to `nil`.
66+
67+
## Examples
68+
69+
iex> t = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
70+
iex> NxSignal.Filters.wiener(t, kernel_size: {2, 2}, noise: 10)
71+
#Nx.Tensor<
72+
f32[3][3]
73+
[
74+
[0.25, 0.75, 1.25],
75+
[1.25, 3.0, 4.0],
76+
[2.75, 6.0, 7.0]
77+
]
78+
>
79+
"""
80+
@doc type: :filters
81+
deftransform wiener(t, opts \\ []) do
82+
# Validate and extract options
83+
opts = Keyword.validate!(opts, noise: nil, kernel_size: 3)
84+
85+
rank = Nx.rank(t)
86+
kernel_size = Keyword.fetch!(opts, :kernel_size)
87+
noise = Keyword.fetch!(opts, :noise)
88+
89+
# Ensure `kernel_size` is a tuple
90+
kernel_size =
91+
cond do
92+
is_integer(kernel_size) -> Tuple.duplicate(kernel_size, rank)
93+
is_tuple(kernel_size) -> kernel_size
94+
true -> raise ArgumentError, "kernel_size must be an integer or tuple"
95+
end
96+
97+
# Convert `nil` noise to `0.0` so it's always a valid tensor
98+
noise_t = if is_nil(noise), do: Nx.tensor(0.0), else: Nx.tensor(noise)
99+
100+
# Compute filter window size
101+
size = Tuple.to_list(kernel_size) |> Enum.reduce(1, &*/2)
102+
103+
# Ensure the kernel is the same size as the filter window
104+
kernel = Nx.broadcast(1.0, kernel_size)
105+
106+
t
107+
|> Nx.as_type(:f64)
108+
|> wiener_n(kernel, noise_t, calculate_noise: is_nil(noise), size: size)
109+
|> Nx.as_type(Nx.type(t))
110+
end
111+
112+
defnp wiener_n(t, kernel, noise, opts) do
113+
size = opts[:size]
114+
115+
# Compute local mean using "same" mode in correlation
116+
l_mean = correlate(t, kernel, mode: :same) / size
117+
118+
# Compute local variance
119+
l_var =
120+
correlate(t ** 2, kernel, mode: :same)
121+
|> Nx.divide(size)
122+
|> Nx.subtract(l_mean ** 2)
123+
124+
# Ensure `noise` is a tensor to avoid `nil` issues in `defnp`
125+
noise =
126+
case opts[:calculate_noise] do
127+
true -> Nx.mean(l_var)
128+
false -> noise
129+
end
130+
131+
# Apply Wiener filter formula
132+
res = (t - l_mean) * (1 - noise / l_var)
133+
Nx.select(l_var < noise, l_mean, res + l_mean)
134+
end
55135
end

‎test/nx_signal/filters_test.exs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,129 @@ defmodule NxSignal.FiltersTest do
116116
)
117117
end
118118
end
119+
120+
describe "wiener/2" do
121+
test "performs n-dim wiener filter with calculated noise" do
122+
im =
123+
Nx.tensor(
124+
[
125+
[1.0, 2.0, 3.0, 4.0, 5.0],
126+
[6.0, 7.0, 8.0, 9.0, 10.0],
127+
[11.0, 12.0, 13.0, 14.0, 15.0]
128+
],
129+
type: :f64
130+
)
131+
132+
kernel_size = {3, 3}
133+
134+
expected =
135+
Nx.tensor(
136+
[
137+
[
138+
1.7777777777777777,
139+
3.0,
140+
3.6666666666666665,
141+
4.333333333333333,
142+
3.111111111111111
143+
],
144+
[4.3366520642506305, 7.0, 8.0, 9.0, 7.58637597408283],
145+
[
146+
4.692197051420351,
147+
7.261706150595039,
148+
8.748939779474131,
149+
10.157992415073023,
150+
9.813815742524799
151+
]
152+
],
153+
type: :f64
154+
)
155+
156+
assert NxSignal.Filters.wiener(im, kernel_size: kernel_size) == expected
157+
assert NxSignal.Filters.wiener(im, kernel_size: 3) == expected
158+
159+
assert NxSignal.Filters.wiener(Nx.as_type(im, :f32), kernel_size: kernel_size) ==
160+
Nx.tensor([
161+
[
162+
1.7777777910232544,
163+
3.0,
164+
3.6666667461395264,
165+
4.333333492279053,
166+
3.1111111640930176
167+
],
168+
[4.3366522789001465, 7.0, 8.0, 9.0, 7.586376190185547],
169+
[
170+
4.692196846008301,
171+
7.261706352233887,
172+
8.748939514160156,
173+
10.157992362976074,
174+
9.81381607055664
175+
]
176+
])
177+
end
178+
179+
test "performs n-dim wiener filter with parameterized noise" do
180+
im =
181+
Nx.tensor(
182+
[
183+
[1.0, 2.0, 3.0, 4.0, 5.0],
184+
[6.0, 7.0, 8.0, 9.0, 10.0],
185+
[11.0, 12.0, 13.0, 14.0, 15.0]
186+
],
187+
type: :f64
188+
)
189+
190+
kernel_size = {3, 3}
191+
192+
assert NxSignal.Filters.wiener(im, kernel_size: kernel_size, noise: 10) ==
193+
Nx.tensor(
194+
[
195+
[
196+
1.7777777777777777,
197+
3.0,
198+
3.5882352941176467,
199+
4.238095238095238,
200+
3.7397034596375622
201+
],
202+
[5.193548387096774, 7.0, 8.0, 9.0, 8.829787234042554],
203+
[
204+
7.941747572815534,
205+
9.702702702702702,
206+
10.938931297709924,
207+
12.137254901960784,
208+
12.485549132947977
209+
]
210+
],
211+
type: :f64
212+
)
213+
214+
assert NxSignal.Filters.wiener(Nx.as_type(im, :f32), kernel_size: kernel_size, noise: 10) ==
215+
Nx.tensor([
216+
[
217+
1.7777777910232544,
218+
3.0,
219+
3.588235378265381,
220+
4.238095283508301,
221+
3.739703416824341
222+
],
223+
[5.193548202514648, 7.0, 8.0, 9.0, 8.829787254333496],
224+
[
225+
7.941747665405273,
226+
9.702702522277832,
227+
10.938931465148926,
228+
12.13725471496582,
229+
12.485548973083496
230+
]
231+
])
232+
233+
assert NxSignal.Filters.wiener(im, kernel_size: kernel_size, noise: 0) ==
234+
Nx.tensor(
235+
[
236+
[1.0, 2.0, 3.0, 4.0, 5.0],
237+
[6.0, 7.0, 8.0, 9.0, 10.0],
238+
[11.0, 12.0, 13.0, 14.0, 15.0]
239+
],
240+
type: :f64
241+
)
242+
end
243+
end
119244
end

0 commit comments

Comments
 (0)
Please sign in to comment.