Skip to content

Commit c7c7829

Browse files
authored
Merge pull request #85 from dynamicslab/expanded-derivatives
Expanded derivatives
2 parents 0800089 + 622f5c3 commit c7c7829

File tree

15 files changed

+1220
-179
lines changed

15 files changed

+1220
-179
lines changed

docs/conf.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import datetime
21
import importlib
32
import pathlib
43

@@ -8,7 +7,7 @@
87

98
# no need to edit below this line
109

11-
copyright = f"{datetime.datetime.now().year}, {author}"
10+
copyright = f"2020, {author}"
1211

1312
module = importlib.import_module(project)
1413
version = release = getattr(module, "__version__")
@@ -24,6 +23,7 @@
2423
"sphinx.ext.napoleon",
2524
"sphinx.ext.mathjax",
2625
"sphinx_nbexamples",
26+
"sphinx.ext.intersphinx",
2727
]
2828

2929
apidoc_module_dir = f"../{project}"
@@ -68,6 +68,10 @@ def setup(app):
6868
pattern=".+.ipynb",
6969
)
7070

71+
intersphinx_mapping = {
72+
"derivative": ("https://derivative.readthedocs.io/en/latest/", None)
73+
}
74+
7175
# -- Extensions to the Napoleon GoogleDocstring class ---------------------
7276
# michaelgoerz.net/notes/extending-sphinx-napoleon-docstring-sections.html
7377
from sphinx.ext.napoleon.docstring import GoogleDocstring # noqa: E402

examples/1_feature_overview.ipynb

Lines changed: 158 additions & 113 deletions
Large diffs are not rendered by default.

examples/4_scikit_learn_compatibility.ipynb

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
"execution_count": 1,
2121
"metadata": {
2222
"ExecuteTime": {
23-
"end_time": "2020-07-14T21:07:35.177225Z",
24-
"start_time": "2020-07-14T21:07:33.573951Z"
23+
"end_time": "2020-08-22T22:47:55.509529Z",
24+
"start_time": "2020-08-22T22:47:54.323781Z"
2525
}
2626
},
2727
"outputs": [],
@@ -44,8 +44,8 @@
4444
"execution_count": 2,
4545
"metadata": {
4646
"ExecuteTime": {
47-
"end_time": "2020-07-14T21:07:35.285747Z",
48-
"start_time": "2020-07-14T21:07:35.184412Z"
47+
"end_time": "2020-08-22T22:47:56.708081Z",
48+
"start_time": "2020-08-22T22:47:56.618948Z"
4949
}
5050
},
5151
"outputs": [],
@@ -95,8 +95,8 @@
9595
"execution_count": 3,
9696
"metadata": {
9797
"ExecuteTime": {
98-
"end_time": "2020-07-14T21:07:44.101945Z",
99-
"start_time": "2020-07-14T21:07:35.293352Z"
98+
"end_time": "2020-08-22T22:50:47.967603Z",
99+
"start_time": "2020-08-22T22:50:40.520544Z"
100100
}
101101
},
102102
"outputs": [
@@ -135,6 +135,62 @@
135135
"search.best_estimator_.print()"
136136
]
137137
},
138+
{
139+
"cell_type": "markdown",
140+
"metadata": {},
141+
"source": [
142+
"Some extra care must be taken when working with differentiation methods from the `derivative` package (i.e. those accessed via the `SINDyDerivative` class). See the example below."
143+
]
144+
},
145+
{
146+
"cell_type": "code",
147+
"execution_count": 4,
148+
"metadata": {
149+
"ExecuteTime": {
150+
"end_time": "2020-08-22T22:59:18.558877Z",
151+
"start_time": "2020-08-22T22:58:57.908732Z"
152+
}
153+
},
154+
"outputs": [
155+
{
156+
"name": "stdout",
157+
"output_type": "stream",
158+
"text": [
159+
"Best parameters: {'differentiation_method__kwargs': {'kind': 'spline', 's': 0.01}, 'optimizer__threshold': 0.1}\n",
160+
"x0' = -10.000 x0 + 10.000 x1\n",
161+
"x1' = 28.003 x0 + -1.001 x1 + -1.000 x0 x2\n",
162+
"x2' = -2.667 x2 + 1.000 x0 x1\n"
163+
]
164+
}
165+
],
166+
"source": [
167+
"model = ps.SINDy(\n",
168+
" t_default=dt,\n",
169+
" differentiation_method=ps.SINDyDerivative(kind='spline', s=1e-2)\n",
170+
")\n",
171+
"\n",
172+
"param_grid = {\n",
173+
" \"optimizer__threshold\": [0.001, 0.01, 0.1],\n",
174+
" \"differentiation_method__kwargs\": [\n",
175+
" {'kind': 'spline', 's': 1e-2},\n",
176+
" {'kind': 'spline', 's': 1e-1},\n",
177+
" {'kind': 'finite_difference', 'k': 1},\n",
178+
" {'kind': 'finite_difference', 'k': 2},\n",
179+
" ]\n",
180+
"}\n",
181+
"\n",
182+
"# This part is identical to what we did before\n",
183+
"search = GridSearchCV(\n",
184+
" model,\n",
185+
" param_grid,\n",
186+
" cv=TimeSeriesSplit(n_splits=5)\n",
187+
")\n",
188+
"search.fit(x_train)\n",
189+
"\n",
190+
"print(\"Best parameters:\", search.best_params_)\n",
191+
"search.best_estimator_.print()"
192+
]
193+
},
138194
{
139195
"cell_type": "markdown",
140196
"metadata": {},
@@ -145,11 +201,11 @@
145201
},
146202
{
147203
"cell_type": "code",
148-
"execution_count": 4,
204+
"execution_count": 5,
149205
"metadata": {
150206
"ExecuteTime": {
151-
"end_time": "2020-07-14T21:07:44.133790Z",
152-
"start_time": "2020-07-14T21:07:44.116536Z"
207+
"end_time": "2020-08-22T22:59:50.283609Z",
208+
"start_time": "2020-08-22T22:59:50.261397Z"
153209
}
154210
},
155211
"outputs": [],
@@ -203,11 +259,11 @@
203259
},
204260
{
205261
"cell_type": "code",
206-
"execution_count": 5,
262+
"execution_count": 6,
207263
"metadata": {
208264
"ExecuteTime": {
209-
"end_time": "2020-07-14T21:07:50.098150Z",
210-
"start_time": "2020-07-14T21:07:44.137586Z"
265+
"end_time": "2020-08-22T22:59:56.316952Z",
266+
"start_time": "2020-08-22T22:59:52.131651Z"
211267
}
212268
},
213269
"outputs": [
@@ -256,11 +312,11 @@
256312
},
257313
{
258314
"cell_type": "code",
259-
"execution_count": 6,
315+
"execution_count": 7,
260316
"metadata": {
261317
"ExecuteTime": {
262-
"end_time": "2020-07-14T21:07:50.552501Z",
263-
"start_time": "2020-07-14T21:07:50.107361Z"
318+
"end_time": "2020-08-22T22:59:59.313750Z",
319+
"start_time": "2020-08-22T22:59:58.992166Z"
264320
}
265321
},
266322
"outputs": [
@@ -284,11 +340,11 @@
284340
},
285341
{
286342
"cell_type": "code",
287-
"execution_count": 7,
343+
"execution_count": 8,
288344
"metadata": {
289345
"ExecuteTime": {
290-
"end_time": "2020-07-14T21:07:50.660764Z",
291-
"start_time": "2020-07-14T21:07:50.577603Z"
346+
"end_time": "2020-08-22T23:00:00.375038Z",
347+
"start_time": "2020-08-22T23:00:00.333557Z"
292348
}
293349
},
294350
"outputs": [

examples/5_differentation.ipynb

Lines changed: 702 additions & 0 deletions
Large diffs are not rendered by default.

pysindy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .pysindy import SINDy
1515
from .differentiation import BaseDifferentiation
1616
from .differentiation import FiniteDifference
17+
from .differentiation import SINDyDerivative
1718
from .differentiation import SmoothedFiniteDifference
1819
from .feature_library import ConcatLibrary
1920
from .feature_library import CustomLibrary

pysindy/differentiation/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
from .base import BaseDifferentiation
22
from .finite_difference import FiniteDifference
3+
from .sindy_derivative import SINDyDerivative
34
from .smoothed_finite_difference import SmoothedFiniteDifference
45

5-
__all__ = ["BaseDifferentiation", "FiniteDifference", "SmoothedFiniteDifference"]
6+
7+
__all__ = [
8+
"BaseDifferentiation",
9+
"FiniteDifference",
10+
"SINDyDerivative",
11+
"SmoothedFiniteDifference",
12+
]

pysindy/differentiation/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class BaseDifferentiation(BaseEstimator):
1313
Base class for differentiation methods.
1414
1515
Simply forces differentiation methods to implement a
16-
_differentiate function.
16+
``_differentiate`` function.
1717
"""
1818

1919
def __init__(self):
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""
2+
Wrapper classes for differentiation methods from the :doc:`derivative:index` package.
3+
4+
Some default values used here may differ from those used in :doc:`derivative:index`.
5+
"""
6+
from derivative import dxdt
7+
from numpy import arange
8+
from sklearn.base import BaseEstimator
9+
10+
from pysindy.utils.base import validate_input
11+
12+
13+
class SINDyDerivative(BaseEstimator):
14+
"""
15+
Wrapper class for differentiation classes from the :doc:`derivative:index` package.
16+
This class is meant to provide all the same functionality as the
17+
`dxdt <https://derivative.readthedocs.io/en/latest/api.html\
18+
#derivative.differentiation.dxdt>`_ method.
19+
20+
This class also has ``_differentiate`` and ``__call__`` methods which are
21+
used by PySINDy.
22+
23+
Parameters
24+
----------
25+
derivative_kws: dictionary, optional
26+
Keyword arguments to be passed to the
27+
`dxdt <https://derivative.readthedocs.io/en/latest/api.html\
28+
#derivative.differentiation.dxdt>`_
29+
method.
30+
31+
Notes
32+
-----
33+
See the `derivative documentation <https://derivative.readthedocs.io/en/latest/>`_
34+
for acceptable keywords.
35+
"""
36+
37+
def __init__(self, **kwargs):
38+
self.kwargs = kwargs
39+
40+
def set_params(self, **params):
41+
"""
42+
Set the parameters of this estimator.
43+
Modification of the pysindy method to allow unknown kwargs. This allows using
44+
the full range of derivative parameters that are not defined as member variables
45+
in sklearn grid search.
46+
47+
Returns
48+
-------
49+
self
50+
"""
51+
if not params:
52+
# Simple optimization to gain speed (inspect is slow)
53+
return self
54+
else:
55+
self.kwargs.update(params)
56+
57+
return self
58+
59+
def get_params(self, deep=True):
60+
"""Get parameters."""
61+
params = super().get_params(deep)
62+
63+
if isinstance(self.kwargs, dict):
64+
params.update(self.kwargs)
65+
66+
return params
67+
68+
def _differentiate(self, x, t=1):
69+
if isinstance(t, (int, float)):
70+
if t < 0:
71+
raise ValueError("t must be a positive constant or an array")
72+
t = arange(x.shape[0]) * t
73+
74+
return dxdt(x, t, axis=0, **self.kwargs)
75+
76+
def __call__(self, x, t=1):
77+
x = validate_input(x, t=t)
78+
return self._differentiate(x, t)

pysindy/feature_library/polynomial_library.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,12 @@ def transform(self, X):
215215
to_stack.append(X)
216216
for deg in range(2, self.degree + 1):
217217
Xp_next = _csr_polynomial_expansion(
218-
X.data, X.indices, X.indptr, X.shape[1], self.interaction_only, deg,
218+
X.data,
219+
X.indices,
220+
X.indptr,
221+
X.shape[1],
222+
self.interaction_only,
223+
deg,
219224
)
220225
if Xp_next is None:
221226
break

pysindy/optimizers/sr3.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,14 @@ def disable_trimming(self):
158158
self.trimming_fraction = None
159159

160160
def _update_full_coef(self, cho, x_transpose_y, coef_sparse):
161-
"""Update the unregularized weight vector
162-
"""
161+
"""Update the unregularized weight vector"""
163162
b = x_transpose_y + coef_sparse / self.nu
164163
coef_full = cho_solve(cho, b)
165164
self.iters += 1
166165
return coef_full
167166

168167
def _update_sparse_coef(self, coef_full):
169-
"""Update the regularized weight vector
170-
"""
168+
"""Update the regularized weight vector"""
171169
coef_sparse = self.prox(coef_full, self.threshold)
172170
self.history_.append(coef_sparse.T)
173171
return coef_sparse
@@ -184,8 +182,7 @@ def _trimming_grad(self, x, y, coef_full, trimming_array):
184182
return 0.5 * np.sum(R2, axis=1)
185183

186184
def _convergence_criterion(self):
187-
"""Calculate the convergence criterion for the optimization
188-
"""
185+
"""Calculate the convergence criterion for the optimization"""
189186
this_coef = self.history_[-1]
190187
if len(self.history_) > 1:
191188
last_coef = self.history_[-2]

pysindy/optimizers/stlsq.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,25 +104,22 @@ def __init__(
104104
self.ridge_kw = ridge_kw
105105

106106
def _sparse_coefficients(self, dim, ind, coef, threshold):
107-
"""Perform thresholding of the weight vector(s)
108-
"""
107+
"""Perform thresholding of the weight vector(s)"""
109108
c = np.zeros(dim)
110109
c[ind] = coef
111110
big_ind = np.abs(c) >= threshold
112111
c[~big_ind] = 0
113112
return c, big_ind
114113

115114
def _regress(self, x, y):
116-
"""Perform the ridge regression
117-
"""
115+
"""Perform the ridge regression"""
118116
kw = self.ridge_kw or {}
119117
coef = ridge_regression(x, y, self.alpha, **kw)
120118
self.iters += 1
121119
return coef
122120

123121
def _no_change(self):
124-
"""Check if the coefficient mask has changed after thresholding
125-
"""
122+
"""Check if the coefficient mask has changed after thresholding"""
126123
this_coef = self.history_[-1].flatten()
127124
if len(self.history_) > 1:
128125
last_coef = self.history_[-2].flatten()

0 commit comments

Comments
 (0)