Skip to content

Commit 25bb86b

Browse files
committed
Test complex input when appropriate for adjoint tests
1 parent eedba14 commit 25bb86b

File tree

1 file changed

+35
-28
lines changed

1 file changed

+35
-28
lines changed

test/runtests.jl

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -201,16 +201,19 @@ end
201201

202202
@testset "output size" begin
203203
@testset "complex fft output size" begin
204-
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
205-
N = ndims(x)
206-
y = randn(size(x))
207-
for dims in unique((1, 1:N, N))
208-
P = plan_fft(x, dims)
209-
@test @inferred(AbstractFFTs.output_size(P)) == size(x)
210-
@test AbstractFFTs.output_size(P') == size(x)
211-
Pinv = plan_ifft(x)
212-
@test AbstractFFTs.output_size(Pinv) == size(x)
213-
@test AbstractFFTs.output_size(Pinv') == size(x)
204+
for x_shape in ((3,), (3, 4), (3, 4, 5))
205+
N = length(x_shape)
206+
real_x = randn(x_shape)
207+
complex_x = randn(ComplexF64, x_shape)
208+
for x in (real_x, complex_x)
209+
for dims in unique((1, 1:N, N))
210+
P = plan_fft(x, dims)
211+
@test @inferred(AbstractFFTs.output_size(P)) == size(x)
212+
@test AbstractFFTs.output_size(P') == size(x)
213+
Pinv = plan_ifft(x)
214+
@test AbstractFFTs.output_size(Pinv) == size(x)
215+
@test AbstractFFTs.output_size(Pinv') == size(x)
216+
end
214217
end
215218
end
216219
end
@@ -222,7 +225,7 @@ end
222225
Px_sz = size(P * x)
223226
@test AbstractFFTs.output_size(P) == Px_sz
224227
@test AbstractFFTs.output_size(P') == size(x)
225-
y = randn(Complex{Float64}, Px_sz)
228+
y = randn(ComplexF64, Px_sz)
226229
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
227230
@test AbstractFFTs.output_size(Pinv) == size(Pinv * y)
228231
@test AbstractFFTs.output_size(Pinv') == size(y)
@@ -233,21 +236,25 @@ end
233236

234237
@testset "adjoint" begin
235238
@testset "complex fft adjoint" begin
236-
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
237-
N = ndims(x)
238-
y = randn(size(x))
239-
for dims in unique((1, 1:N, N))
240-
P = plan_fft(x, dims)
241-
@test (P')' === P # test adjoint of adjoint
242-
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
243-
@test dot(y, P * x) dot(P' * y, x) # test validity of adjoint
244-
@test dot(y, P \ x) dot(P' \ y, x)
245-
Pinv = plan_ifft(y)
246-
@test (Pinv')' * y == Pinv * y
247-
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
248-
@test dot(x, Pinv * y) dot(Pinv' * x, y)
249-
@test dot(x, Pinv \ y) dot(Pinv' \ x, y)
250-
@test_throws MethodError mul!(x, P', y)
239+
for x_shape in ((3,), (3, 4), (3, 4, 5))
240+
N = length(x_shape)
241+
real_x = randn(x_shape)
242+
complex_x = randn(ComplexF64, x_shape)
243+
y = randn(ComplexF64, x_shape)
244+
for x in (real_x, complex_x)
245+
for dims in unique((1, 1:N, N))
246+
P = plan_fft(x, dims)
247+
@test (P')' === P # test adjoint of adjoint
248+
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
249+
@test dot(y, P * x) dot(P' * y, x) # test validity of adjoint
250+
@test dot(y, P \ x) dot(P' \ y, x)
251+
Pinv = plan_ifft(y)
252+
@test (Pinv')' * y == Pinv * y
253+
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
254+
@test dot(x, Pinv * y) dot(Pinv' * x, y)
255+
@test dot(x, Pinv \ y) dot(Pinv' \ x, y)
256+
@test_throws MethodError mul!(x, P', y)
257+
end
251258
end
252259
end
253260
end
@@ -256,7 +263,7 @@ end
256263
N = ndims(x)
257264
for dims in unique((1, 1:N, N))
258265
P = plan_rfft(x, dims)
259-
y = randn(Complex{Float64}, size(P * x))
266+
y = randn(ComplexF64, size(P * x))
260267
@test (P')' * x == P * x
261268
@test size(P') == AbstractFFTs.output_size(P)
262269
@test dot(real.(y), real.(P * x)) + dot(imag.(y), imag.(P * x)) dot(P' * y, x)
@@ -306,7 +313,7 @@ end
306313
for x_shape in ((2,), (2, 3), (3, 4, 5))
307314
N = length(x_shape)
308315
x = randn(x_shape)
309-
complex_x = x + randn(x_shape) * im
316+
complex_x = randn(ComplexF64, x_shape)
310317
for dims in unique((1, 1:N, N))
311318
# fft, ifft, bfft
312319
for f in (fft, ifft, bfft)

0 commit comments

Comments
 (0)