Skip to content

Commit 3e7d412

Browse files
authored
Merge pull request #65 from JuliaMath/dw/region
Add `region(::Plan)` for accessing transformed region
2 parents 4733cd1 + 8c4dcd9 commit 3e7d412

File tree

6 files changed

+26
-5
lines changed

6 files changed

+26
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AbstractFFTs"
22
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3-
version = "1.1.0"
3+
version = "1.2.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ To define a new FFT implementation in your own module, you should
2323
inverse plan.
2424

2525
* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
26-
`x` and some set of dimensions `region`.
26+
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).
2727

2828
* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to
2929
0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`.

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ AbstractFFTs.brfft
1919
AbstractFFTs.plan_rfft
2020
AbstractFFTs.plan_brfft
2121
AbstractFFTs.plan_irfft
22+
AbstractFFTs.fftdims
2223
AbstractFFTs.fftshift
2324
AbstractFFTs.fftshift!
2425
AbstractFFTs.ifftshift

src/AbstractFFTs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import ChainRulesCore
55
export fft, ifft, bfft, fft!, ifft!, bfft!,
66
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
77
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
8-
fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq
8+
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq
99

1010
include("definitions.jl")
1111
include("chainrules.jl")

src/definitions.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ size(p::Plan, d) = size(p)[d]
1515
ndims(p::Plan) = length(size(p))
1616
length(p::Plan) = prod(size(p))::Int
1717

18+
"""
19+
fftdims(p::Plan)
20+
21+
Return an iterable of the dimensions that are transformed by the FFT plan `p`.
22+
23+
# Implementation
24+
25+
For legacy reasons, the default definition of `fftdims` returns `p.region`.
26+
Hence this method should be implemented only for `Plan` subtypes that do not store the transformed dimensions in a field named `region`.
27+
"""
28+
fftdims(p::Plan) = p.region
29+
1830
fftfloat(x) = _fftfloat(float(x))
1931
_fftfloat(::Type{T}) where {T<:BlasReal} = T
2032
_fftfloat(::Type{Float16}) = Float32
@@ -243,6 +255,8 @@ ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)
243255

244256
size(p::ScaledPlan) = size(p.p)
245257

258+
fftdims(p::ScaledPlan) = fftdims(p.p)
259+
246260
show(io::IO, p::ScaledPlan) = print(io, p.scale, " * ", p.p)
247261
summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))
248262

test/runtests.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,21 @@ end
6060
@test eltype(P) === ComplexF64
6161
@test P * x fftw_fft
6262
@test P \ (P * x) x
63+
@test fftdims(P) == dims
6364

6465
fftw_bfft = complex.(size(x, dims) .* x)
6566
@test AbstractFFTs.bfft(y, dims) fftw_bfft
6667
P = plan_bfft(x, dims)
6768
@test P * y fftw_bfft
6869
@test P \ (P * y) y
70+
@test fftdims(P) == dims
6971

7072
fftw_ifft = complex.(x)
7173
@test AbstractFFTs.ifft(y, dims) fftw_ifft
7274
P = plan_ifft(x, dims)
7375
@test P * y fftw_ifft
7476
@test P \ (P * y) y
77+
@test fftdims(P) == dims
7578

7679
# real FFT
7780
fftw_rfft = fftw_fft[
@@ -84,18 +87,21 @@ end
8487
@test eltype(P) === Int
8588
@test P * x fftw_rfft
8689
@test P \ (P * x) x
90+
@test fftdims(P) == dims
8791

8892
fftw_brfft = complex.(size(x, dims) .* x)
8993
@test AbstractFFTs.brfft(ry, size(x, dims), dims) fftw_brfft
9094
P = plan_brfft(ry, size(x, dims), dims)
9195
@test P * ry fftw_brfft
9296
@test P \ (P * ry) ry
93-
97+
@test fftdims(P) == dims
98+
9499
fftw_irfft = complex.(x)
95100
@test AbstractFFTs.irfft(ry, size(x, dims), dims) fftw_irfft
96101
P = plan_irfft(ry, size(x, dims), dims)
97102
@test P * ry fftw_irfft
98103
@test P \ (P * ry) ry
104+
@test fftdims(P) == dims
99105
end
100106
end
101107

@@ -187,7 +193,7 @@ end
187193
# normalization should be inferable even if region is only inferred as ::Any,
188194
# need to wrap in another function to test this (note that p.region::Any for
189195
# p::TestPlan)
190-
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, p.region)
196+
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p))
191197
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
192198
end
193199

0 commit comments

Comments
 (0)