205
205
y = randn (size (x))
206
206
for dims in unique ((1 , 1 : N, N))
207
207
P = plan_fft (x, dims)
208
- @test AbstractFFTs. output_size (P) == size (P * x)
208
+ @test AbstractFFTs. output_size (P) == size (x)
209
+ @test AbstractFFTs. output_size (P' ) == size (x)
209
210
Pinv = plan_ifft (x)
210
- @test AbstractFFTs. output_size (Pinv) == size (Pinv * x)
211
+ @test AbstractFFTs. output_size (Pinv) == size (x)
212
+ @test AbstractFFTs. output_size (Pinv' ) == size (x)
211
213
end
212
214
end
213
215
end
218
220
P = plan_rfft (x, dims)
219
221
Px_sz = size (P * x)
220
222
@test AbstractFFTs. output_size (P) == Px_sz
223
+ @test AbstractFFTs. output_size (P' ) == size (x)
221
224
y = randn (Px_sz) .+ randn (Px_sz) * im
222
225
Pinv = plan_irfft (y, size (x)[first (dims)], dims)
223
226
@test AbstractFFTs. output_size (Pinv) == size (Pinv * y)
227
+ @test AbstractFFTs. output_size (Pinv' ) == size (y)
224
228
end
225
229
end
226
230
end
233
237
y = randn (size (x))
234
238
for dims in unique ((1 , 1 : N, N))
235
239
P = plan_fft (x, dims)
240
+ @test (P' )' * x == P * x # test adjoint of adjoint
236
241
@test dot (y, P * x) ≈ dot (P' * y, x)
237
242
@test_broken dot (y, P \ x) ≈ dot (P' \ y, x)
238
- Pinv = plan_ifft (x)
243
+ Pinv = plan_ifft (y)
244
+ @test (Pinv' )' * y == Pinv * y # test adjoint of adjoint
239
245
@test dot (x, Pinv * y) ≈ dot (Pinv' * x, y)
240
246
@test_broken dot (x, Pinv \ y) ≈ dot (Pinv' \ x, y)
241
247
end
@@ -246,12 +252,14 @@ end
246
252
N = ndims (x)
247
253
for dims in unique ((1 , 1 : N, N))
248
254
P = plan_rfft (x, dims)
255
+ @test (P' )' * x == P * x
249
256
y_real = randn (size (P * x))
250
257
y_imag = randn (size (P * x))
251
258
y = y_real .+ y_imag .* im
252
259
@test dot (y_real, real .(P * x)) + dot (y_imag, imag .(P * x)) ≈ dot (P' * y, x)
253
260
@test_broken dot (y_real, real .(P \ x)) + dot (y_imag, imag .(P \ x)) ≈ dot (P' * y, x)
254
261
Pinv = plan_irfft (y, size (x)[first (dims)], dims)
262
+ @test (Pinv' )' * y == Pinv * y
255
263
@test dot (x, Pinv * y) ≈ dot (y_real, real .(Pinv' * x)) + dot (y_imag, imag .(Pinv' * x))
256
264
@test_broken dot (x, Pinv \ y) ≈ dot (y_real, real .(Pinv' \ x)) + dot (y_imag, imag .(Pinv' \ x))
257
265
end
@@ -284,20 +292,27 @@ end
284
292
N = ndims (x)
285
293
complex_x = complex .(x)
286
294
for dims in unique ((1 , 1 : N, N))
295
+ # fft, ifft, bfft
287
296
for f in (fft, ifft, bfft)
288
297
test_frule (f, x, dims)
289
298
test_rrule (f, x, dims)
290
299
test_frule (f, complex_x, dims)
291
300
test_rrule (f, complex_x, dims)
292
301
end
293
-
294
302
for pf in (plan_fft, plan_ifft, plan_bfft)
295
303
test_frule (* , pf (x, dims) ⊢ NoTangent (), x)
296
304
test_rrule (* , pf (x, dims) ⊢ NoTangent (), x)
297
305
test_frule (* , pf (complex_x, dims) ⊢ NoTangent (), complex_x)
298
306
test_rrule (* , pf (complex_x, dims) ⊢ NoTangent (), complex_x)
299
307
end
300
308
309
+ # rfft
310
+ test_frule (rfft, x, dims)
311
+ test_rrule (rfft, x, dims)
312
+ test_frule (* , plan_rfft (x, dims) ⊢ NoTangent (), x)
313
+ test_rrule (* , plan_rfft (x, dims) ⊢ NoTangent (), x)
314
+
315
+ # irfft, brfft
301
316
for f in (irfft, brfft)
302
317
for d in (2 * size (x, first (dims)) - 1 , 2 * size (x, first (dims)) - 2 )
303
318
test_frule (f, x, d, dims)
@@ -306,14 +321,12 @@ end
306
321
test_rrule (f, complex_x, d, dims)
307
322
end
308
323
end
309
-
310
324
for pf in (plan_irfft, plan_brfft)
311
325
for d in (2 * size (x, first (dims)) - 1 , 2 * size (x, first (dims)) - 2 )
312
326
test_frule (* , pf (complex_x, d, dims) ⊢ NoTangent (), complex_x)
313
327
test_rrule (* , pf (complex_x, d, dims) ⊢ NoTangent (), complex_x)
314
328
end
315
329
end
316
-
317
330
end
318
331
end
319
332
end
0 commit comments