Skip to content

Commit 8601a92

Browse files
gaurav-aryadevmotionsethaxen
authored
Chain rules for FFT plans via AdjointPlans (#67)
* Implement AdjointPlans * Implement chain rules for FFT plans * Test plan adjoints and AD rules * Apply suggestions from adjoint plan code review Co-authored-by: David Widmann <[email protected]> * Include irrft_dim in RealInverseProjectionStyle Co-authored-by: David Widmann <[email protected]> * update to new fftdims interface * fix broken tests * Explicitly don't support mul! for adjoint plans * Document adjoint plans * remove incorrectly thrown error * Update adjoint plan docs * Update adjoint docs * Fix typos * tweak adjoint doc string * Tweaks to adjoint description * Immutable AdjointPlan * Add rules and tests for ScaledPlan * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * More tweaks to address code review * Restrict to T<:Real for rfft adjoint * Get type T correct for test irfft * Test complex input when appropriate for adjoint tests * Add plan_inv implementation for adjoint plan and test it * Apply suggestions from code review Co-authored-by: Seth Axen <[email protected]> * Apply suggestions from code review * Test in-place plans --------- Co-authored-by: David Widmann <[email protected]> Co-authored-by: David Widmann <[email protected]> Co-authored-by: Seth Axen <[email protected]>
1 parent b5109aa commit 8601a92

File tree

9 files changed

+349
-55
lines changed

9 files changed

+349
-55
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ julia = "^1.0"
1919
[extras]
2020
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2121
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
22+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2223
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2324
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2425
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2526

2627
[targets]
27-
test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"]
28+
test = ["ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"]

README.md

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,5 @@ This allows multiple FFT packages to co-exist with the same underlying `fft(x)`
1616

1717
## Developer information
1818

19-
To define a new FFT implementation in your own module, you should
19+
To define a new FFT implementation in your own module, see [defining a new implementation](https://juliamath.github.io/AbstractFFTs.jl/stable/implementations/#Defining-a-new-implementation).
2020

21-
* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`.
22-
This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the
23-
inverse plan.
24-
25-
* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
26-
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).
27-
28-
* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to
29-
0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`.
30-
31-
* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` (or `A_mul_B!`) method.
32-
This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs.
33-
34-
* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the
35-
inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`.
36-
37-
* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.
38-
39-
The normalization convention for your FFT should be that it computes $y_k = \sum_j \exp\(-2 \pi i \cdot \frac{j k}{n}\) x_j$
40-
for a transform of length $n$, and the "backwards" (unnormalized inverse) transform computes the same thing but with
41-
$\exp\(+2 \pi i \cdot \frac{j k}{n}\)$.

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
23
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
34

45
[compat]

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ AbstractFFTs.plan_rfft
2020
AbstractFFTs.plan_brfft
2121
AbstractFFTs.plan_irfft
2222
AbstractFFTs.fftdims
23+
Base.adjoint
2324
AbstractFFTs.fftshift
2425
AbstractFFTs.fftshift!
2526
AbstractFFTs.ifftshift

docs/src/implementations.md

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,31 @@ The following packages extend the functionality provided by AbstractFFTs:
1111

1212
## Defining a new implementation
1313

14-
Implementations should implement `LinearAlgebra.mul!(Y, plan, X)` (or
15-
`A_mul_B!(y, p::MyPlan, x)` on Julia prior to 0.7.0-DEV.3204) so as to support
16-
pre-allocated output arrays.
17-
We don't define `*` in terms of `mul!` generically here, however, because
18-
of subtleties for in-place and real FFT plans.
19-
20-
To support `inv`, `\`, and `ldiv!(y, plan, x)`, we require `Plan` subtypes
21-
to have a `pinv::Plan` field, which caches the inverse plan, and which should be
22-
initially undefined.
23-
They should also implement `plan_inv(p)` to construct the inverse of a plan `p`.
24-
25-
Implementations only need to provide the unnormalized backwards FFT,
26-
similar to FFTW, and we do the scaling generically to get the inverse FFT.
14+
To define a new FFT implementation in your own module, you should
15+
16+
* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`.
17+
This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the
18+
inverse plan.
19+
20+
* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
21+
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).
22+
23+
* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`.
24+
25+
* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` method.
26+
This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs.
27+
28+
* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the
29+
inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`.
30+
Implementations only need to provide the unnormalized backwards FFT, similar to FFTW, and we do the scaling generically
31+
to get the inverse FFT.
32+
33+
* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.
34+
35+
* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
36+
* `AbstractFFTs.NoProjectionStyle()`,
37+
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
38+
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
39+
40+
The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
41+
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.

ext/AbstractFFTsChainRulesCoreExt.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,54 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
159159
return y, ifftshift_pullback
160160
end
161161

162+
# plans
163+
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
164+
y = P * x
165+
if Base.mightalias(y, x)
166+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
167+
end
168+
Δy = P * Δx
169+
return y, Δy
170+
end
171+
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
172+
y = P * x
173+
if Base.mightalias(y, x)
174+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
175+
end
176+
project_x = ChainRulesCore.ProjectTo(x)
177+
Pt = P'
178+
function mul_plan_pullback(ȳ)
179+
= project_x(Pt * ȳ)
180+
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
181+
end
182+
return y, mul_plan_pullback
183+
end
184+
185+
function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
186+
y = P * x
187+
if Base.mightalias(y, x)
188+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
189+
end
190+
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
191+
return y, Δy
192+
end
193+
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
194+
y = P * x
195+
if Base.mightalias(y, x)
196+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
197+
end
198+
Pt = P'
199+
scale = P.scale
200+
project_x = ChainRulesCore.ProjectTo(x)
201+
project_scale = ChainRulesCore.ProjectTo(scale)
202+
function mul_scaledplan_pullback(ȳ)
203+
= ChainRulesCore.@thunk(project_x(Pt * ȳ))
204+
scale_tangent = ChainRulesCore.@thunk(project_scale(AbstractFFTs.dot(y, ȳ) / conj(scale)))
205+
plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
206+
return ChainRulesCore.NoTangent(), plan_tangent, x̄
207+
end
208+
return y, mul_scaledplan_pullback
209+
end
210+
162211
end # module
212+

src/definitions.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T
1212

1313
# size(p) should return the size of the input array for p
1414
size(p::Plan, d) = size(p)[d]
15+
output_size(p::Plan, d) = output_size(p)[d]
1516
ndims(p::Plan) = length(size(p))
1617
length(p::Plan) = prod(size(p))::Int
1718

@@ -255,6 +256,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
255256
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)
256257

257258
size(p::ScaledPlan) = size(p.p)
259+
output_size(p::ScaledPlan) = output_size(p.p)
258260

259261
fftdims(p::ScaledPlan) = fftdims(p.p)
260262

@@ -578,3 +580,80 @@ Pre-plan an optimized real-input unnormalized transform, similar to
578580
the same as for [`brfft`](@ref).
579581
"""
580582
plan_brfft
583+
584+
##############################################################################
585+
586+
struct NoProjectionStyle end
587+
struct RealProjectionStyle end
588+
struct RealInverseProjectionStyle
589+
dim::Int
590+
end
591+
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}
592+
593+
output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
594+
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
595+
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
596+
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
597+
598+
struct AdjointPlan{T,P<:Plan} <: Plan{T}
599+
p::P
600+
AdjointPlan{T,P}(p) where {T,P} = new(p)
601+
end
602+
603+
"""
604+
(p::Plan)'
605+
adjoint(p::Plan)
606+
607+
Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of
608+
the original plan. Note that this differs from the corresponding backwards plan in the case of real
609+
FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref).
610+
611+
!!! note
612+
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
613+
coverage of `Base.adjoint` in downstream implementations may be limited.
614+
"""
615+
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p)
616+
Base.adjoint(p::AdjointPlan) = p.p
617+
# always have AdjointPlan inside ScaledPlan.
618+
Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
619+
620+
size(p::AdjointPlan) = output_size(p.p)
621+
output_size(p::AdjointPlan) = size(p.p)
622+
623+
Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
624+
625+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
626+
dims = fftdims(p.p)
627+
N = normalization(T, size(p.p), dims)
628+
return (p.p \ x) / N
629+
end
630+
631+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
632+
dims = fftdims(p.p)
633+
N = normalization(T, size(p.p), dims)
634+
halfdim = first(dims)
635+
d = size(p.p, halfdim)
636+
n = output_size(p.p, halfdim)
637+
scale = reshape(
638+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
639+
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
640+
)
641+
return p.p \ (x ./ convert(typeof(x), scale))
642+
end
643+
644+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
645+
dims = fftdims(p.p)
646+
N = normalization(real(T), output_size(p.p), dims)
647+
halfdim = first(dims)
648+
n = size(p.p, halfdim)
649+
d = output_size(p.p, halfdim)
650+
scale = reshape(
651+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
652+
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
653+
)
654+
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
655+
end
656+
657+
# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
658+
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
659+
inv(p::AdjointPlan) = adjoint(inv(p.p))

0 commit comments

Comments
 (0)