Skip to content

Commit a411b06

Browse files
author
Gaurav Arya
committed
Polish output_size
1 parent c7efe8d commit a411b06

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

src/definitions.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
255255
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)
256256

257257
size(p::ScaledPlan) = size(p.p)
258-
output_size(p::ScaledPlan) = size(p)
258+
output_size(p::ScaledPlan) = output_size(p.p)
259259

260260
region(p::ScaledPlan) = region(p.p)
261261

@@ -587,7 +587,6 @@ const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInvers
587587

588588
function irfft_dim end
589589

590-
ProjectionStyle(p::Plan) = error("No projection style defined for plan")
591590
output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
592591
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
593592
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p))

test/runtests.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,34 @@ end
198198
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
199199
end
200200

201+
@testset "output size" begin
202+
@testset "complex fft output size" begin
203+
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
204+
N = ndims(x)
205+
y = randn(size(x))
206+
for dims in unique((1, 1:N, N))
207+
P = plan_fft(x, dims)
208+
@test AbstractFFTs.output_size(P) == size(P * x)
209+
Pinv = plan_ifft(x)
210+
@test AbstractFFTs.output_size(Pinv) == size(Pinv * x)
211+
end
212+
end
213+
end
214+
@testset "real fft output size" begin
215+
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
216+
N = ndims(x)
217+
for dims in unique((1, 1:N, N))
218+
P = plan_rfft(x, dims)
219+
Px_sz = size(P * x)
220+
@test AbstractFFTs.output_size(P) == Px_sz
221+
y = randn(Px_sz) .+ randn(Px_sz) * im
222+
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
223+
@test AbstractFFTs.output_size(Pinv) == size(Pinv * y)
224+
end
225+
end
226+
end
227+
end
228+
201229
@testset "adjoint" begin
202230
@testset "complex fft adjoint" begin
203231
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
@@ -217,13 +245,13 @@ end
217245
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
218246
N = ndims(x)
219247
for dims in unique((1, 1:N, N))
220-
P = plan_rfft(similar(x), dims)
248+
P = plan_rfft(x, dims)
221249
y_real = randn(size(P * x))
222250
y_imag = randn(size(P * x))
223251
y = y_real .+ y_imag .* im
224252
@test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) dot(P' * y, x)
225253
@test_broken dot(y_real, real.(P \ x)) + dot(y_imag, imag.(P \ x)) dot(P' * y, x)
226-
Pinv = plan_irfft(similar(y), size(x)[first(dims)], dims)
254+
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
227255
@test dot(x, Pinv * y) dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x))
228256
@test_broken dot(x, Pinv \ y) dot(y_real, real.(Pinv' \ x)) + dot(y_imag, imag.(Pinv' \ x))
229257
end

0 commit comments

Comments
 (0)