Skip to content

Commit c399c19

Browse files
WIP: Split off Nonlinear Solvers and Differentiation
These subpackages will be dependencies of some of the other subpackages. The reason to excise this from the core is to down dependencies required for explicit RK methods. Todo: - [ ] DAE initialization needs a good error message for if nonlinear solvers haven't been loaded but you need to run the initialization (because of a initializeprob) - [ ] Dependencies need to be setup in the Project.tomls of the sublibraries
1 parent a386684 commit c399c19

22 files changed

+883
-823
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name = "OrdinaryDiffEqDifferentiation"
2+
uuid = "4302a76b-040a-498a-8c04-15b101fed76b"
3+
authors = ["Chris Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
4+
version = "1.0.0"
5+
6+
[deps]
7+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8+
9+
[compat]
10+
julia = "1.10"
11+
12+
[extras]
13+
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
14+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
16+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
18+
[targets]
19+
test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test"]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
module OrdinaryDiffEqDifferentiation
2+
3+
import ADTypes: AutoFiniteDiff, AutoForwardDiff
4+
5+
import SparseDiffTools: SparseDiffTools, matrix_colors, forwarddiff_color_jacobian!,
6+
forwarddiff_color_jacobian, ForwardColorJacCache,
7+
default_chunk_size, getsize, JacVec
8+
9+
import ForwardDiff, FiniteDiff
10+
import ForwardDiff.Dual
11+
import LinearSolve
12+
13+
using DiffEqBase: TimeGradientWrapper,
14+
UJacobianWrapper, TimeDerivativeWrapper,
15+
UDerivativeWrapper
16+
17+
@static if isdefined(DiffEqBase, :OrdinaryDiffEqTag)
18+
import DiffEqBase: OrdinaryDiffEqTag
19+
else
20+
struct OrdinaryDiffEqTag end
21+
end
22+
23+
include("alg_utils.jl")
24+
include("ilnsolve_utils.jl")
25+
include("derivative_utils.jl")
26+
include("derivative_wrappers.jl")
27+
28+
end
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options.
2+
function _alg_autodiff(alg::OrdinaryDiffEqAlgorithm)
3+
error("This algorithm does not have an autodifferentiation option defined.")
4+
end
5+
_alg_autodiff(::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
6+
_alg_autodiff(::DAEAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
7+
_alg_autodiff(::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
8+
_alg_autodiff(alg::CompositeAlgorithm) = _alg_autodiff(alg.algs[end])
9+
function _alg_autodiff(::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
10+
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}
11+
}) where {
12+
CS, AD
13+
}
14+
Val{AD}()
15+
end
16+
17+
function alg_autodiff(alg)
18+
autodiff = _alg_autodiff(alg)
19+
if autodiff == Val(false)
20+
return AutoFiniteDiff()
21+
elseif autodiff == Val(true)
22+
return AutoForwardDiff()
23+
else
24+
return _unwrap_val(autodiff)
25+
end
26+
end
27+
28+
Base.@pure function determine_chunksize(u, alg::DiffEqBase.DEAlgorithm)
29+
determine_chunksize(u, get_chunksize(alg))
30+
end
31+
Base.@pure function determine_chunksize(u, CS)
32+
if CS != 0
33+
return CS
34+
else
35+
return ForwardDiff.pickchunksize(length(u))
36+
end
37+
end
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
issuccess_W(W::LinearAlgebra.Factorization) = LinearAlgebra.issuccess(W)
2+
issuccess_W(W::Number) = !iszero(W)
3+
issuccess_W(::Any) = true
4+
5+
function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
6+
du = nothing, u = nothing, p = nothing, t = nothing,
7+
weight = nothing, solverdata = nothing,
8+
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
9+
A !== nothing && (linsolve.A = A)
10+
b !== nothing && (linsolve.b = b)
11+
linu !== nothing && (linsolve.u = linu)
12+
13+
Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
14+
linsolve.Pl
15+
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
16+
linsolve.Pr
17+
18+
_alg = unwrap_alg(integrator, true)
19+
20+
_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
21+
solverdata)
22+
if (_Pl !== nothing || _Pr !== nothing)
23+
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
24+
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
25+
linsolve.Pl = __Pl
26+
linsolve.Pr = __Pr
27+
end
28+
29+
linres = solve!(linsolve; reltol)
30+
31+
# TODO: this ignores the add of the `f` count for add_steps!
32+
if integrator isa SciMLBase.DEIntegrator && _alg.linsolve !== nothing &&
33+
!LinearSolve.needs_concrete_A(_alg.linsolve) &&
34+
linsolve.A isa WOperator && linsolve.A.J isa AbstractSciMLOperator
35+
if alg_autodiff(_alg) isa AutoForwardDiff
36+
integrator.stats.nf += linres.iters
37+
elseif alg_autodiff(_alg) isa AutoFiniteDiff
38+
integrator.stats.nf += 2 * linres.iters
39+
else
40+
error("$alg_autodiff not yet supported in dolinsolve function")
41+
end
42+
end
43+
44+
return linres
45+
end
46+
47+
function wrapprecs(_Pl::Nothing, _Pr::Nothing, weight, u)
48+
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
49+
Pr = Diagonal(_vec(weight))
50+
Pl, Pr
51+
end
52+
53+
function wrapprecs(_Pl, _Pr, weight, u)
54+
Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pl
55+
Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pr
56+
Pl, Pr
57+
end
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
name = "OrdinaryDiffEqNonlinearSolve"
2+
uuid = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
3+
authors = ["Chris Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
4+
version = "1.0.0"
5+
6+
[deps]
7+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8+
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
9+
10+
[compat]
11+
julia = "1.10"
12+
13+
[extras]
14+
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
15+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
16+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
17+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
18+
19+
[targets]
20+
test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test"]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module OrdinaryDiffEqNonlinearSolve
2+
3+
import ADTypes: AutoFiniteDiff, AutoForwardDiff
4+
5+
import SciMLBase
6+
import DiffEqBase
7+
import PreallocationTools
8+
using SimpleNonlinearSolve: SimpleTrustRegion, SimpleGaussNewton
9+
using NonlinearSolve: FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg
10+
11+
import OrdinaryDiffEq: resize_nlsolver!, _initialize_dae!
12+
import OrdinaryDiffEqDifferentiation: update_W!, isnewton
13+
14+
include("type.jl")
15+
include("utils.jl")
16+
include("nlsolve.jl")
17+
include("functional.jl")
18+
include("newton.jl")
19+
include("initialize_dae.jl")
20+
21+
end

0 commit comments

Comments
 (0)