Skip to content

Commit 7da339a

Browse files
committed
Add minor changes and resolve conflict
2 parents 1954192 + 231a977 commit 7da339a

File tree

153 files changed

+3251
-1583
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

153 files changed

+3251
-1583
lines changed

.github/workflows/mypy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
shell: bash -leo pipefail {0}
1414
steps:
1515
- uses: actions/checkout@v4
16-
- uses: mamba-org/setup-micromamba@v1
16+
- uses: mamba-org/setup-micromamba@v2
1717
with:
1818
micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
1919
environment-file: environment.yml

.github/workflows/pypi.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
fetch-depth: 0
6363

6464
- name: Build wheels
65-
uses: pypa/cibuildwheel@v2.21.2
65+
uses: pypa/cibuildwheel@v2.22.0
6666

6767
- uses: actions/upload-artifact@v4
6868
with:
@@ -78,7 +78,7 @@ jobs:
7878
fetch-depth: 0
7979

8080
- name: Set up Python
81-
uses: actions/setup-python@v4
81+
uses: actions/setup-python@v5
8282
with:
8383
python-version: '3.11'
8484

@@ -145,7 +145,7 @@ jobs:
145145
name: universal_wheel
146146
path: dist
147147

148-
- uses: pypa/gh-action-pypi-publish@v1.10.3
148+
- uses: pypa/gh-action-pypi-publish@v1.12.2
149149
with:
150150
user: __token__
151151
password: ${{ secrets.pypi_password }}

.github/workflows/test.yml

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,17 @@ jobs:
6161
python-version: ${{ matrix.python-version }}
6262
- uses: pre-commit/[email protected]
6363

64-
test_ubuntu:
65-
name: "Test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
64+
test:
65+
name: "${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
6666
needs:
6767
- changes
6868
- style
69-
runs-on: ubuntu-latest
69+
runs-on: ${{ matrix.os }}
7070
if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
7171
strategy:
7272
fail-fast: false
7373
matrix:
74+
os: ["ubuntu-latest"]
7475
python-version: ["3.10", "3.12"]
7576
fast-compile: [0, 1]
7677
float32: [0, 1]
@@ -103,36 +104,50 @@ jobs:
103104
fast-compile: 1
104105
include:
105106
- install-numba: 1
107+
os: "ubuntu-latest"
106108
python-version: "3.10"
107109
fast-compile: 0
108110
float32: 0
109111
part: "tests/link/numba"
110112
- install-numba: 1
113+
os: "ubuntu-latest"
111114
python-version: "3.12"
112115
fast-compile: 0
113116
float32: 0
114117
part: "tests/link/numba"
115118
- install-jax: 1
119+
os: "ubuntu-latest"
116120
python-version: "3.10"
117121
fast-compile: 0
118122
float32: 0
119123
part: "tests/link/jax"
120124
- install-jax: 1
125+
os: "ubuntu-latest"
121126
python-version: "3.12"
122127
fast-compile: 0
123128
float32: 0
124129
part: "tests/link/jax"
125130
- install-torch: 1
131+
os: "ubuntu-latest"
126132
python-version: "3.10"
127133
fast-compile: 0
128134
float32: 0
129135
part: "tests/link/pytorch"
136+
- os: macos-15
137+
python-version: "3.12"
138+
fast-compile: 0
139+
float32: 0
140+
install-numba: 0
141+
install-jax: 0
142+
install-torch: 0
143+
part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
144+
130145
steps:
131146
- uses: actions/checkout@v4
132147
with:
133148
fetch-depth: 0
134149
- name: Set up Python ${{ matrix.python-version }}
135-
uses: mamba-org/setup-micromamba@v1
150+
uses: mamba-org/setup-micromamba@v2
136151
with:
137152
environment-name: pytensor-test
138153
micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
@@ -146,15 +161,19 @@ jobs:
146161
MATRIX_CONTEXT: ${{ toJson(matrix) }}
147162
run: |
148163
echo $MATRIX_CONTEXT
149-
export MATRIX_ID=`echo $MATRIX_CONTEXT | md5sum | cut -c 1-32`
164+
export MATRIX_ID=`echo $MATRIX_CONTEXT | sha256sum | cut -c 1-32`
150165
echo $MATRIX_ID
151166
echo "id=$MATRIX_ID" >> $GITHUB_OUTPUT
152167
153168
- name: Install dependencies
154169
shell: micromamba-shell {0}
155170
run: |
156171
157-
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock
172+
if [[ $OS == "macos-15" ]]; then
173+
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" numpy scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate;
174+
else
175+
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
176+
fi
158177
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
159178
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
160179
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
@@ -163,12 +182,17 @@ jobs:
163182
pip install -e ./
164183
micromamba list && pip freeze
165184
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
166-
python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
185+
if [[ $OS == "macos-15" ]]; then
186+
python -c 'import pytensor; assert pytensor.config.blas__ldflags.startswith("-framework Accelerate"), "Blas flags are not set to MacOS Accelerate"';
187+
else
188+
python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"';
189+
fi
167190
env:
168191
PYTHON_VERSION: ${{ matrix.python-version }}
169192
INSTALL_NUMBA: ${{ matrix.install-numba }}
170193
INSTALL_JAX: ${{ matrix.install-jax }}
171194
INSTALL_TORCH: ${{ matrix.install-torch}}
195+
OS: ${{ matrix.os}}
172196

173197
- name: Run tests
174198
shell: micromamba-shell {0}
@@ -206,7 +230,7 @@ jobs:
206230
with:
207231
fetch-depth: 0
208232
- name: Set up Python 3.10
209-
uses: mamba-org/setup-micromamba@v1
233+
uses: mamba-org/setup-micromamba@v2
210234
with:
211235
environment-name: pytensor-test
212236
micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
@@ -249,10 +273,10 @@ jobs:
249273
if: ${{ always() }}
250274
runs-on: ubuntu-latest
251275
name: "All tests"
252-
needs: [changes, style, test_ubuntu]
276+
needs: [changes, style, test]
253277
steps:
254278
- name: Check build matrix status
255-
if: ${{ needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test_ubuntu.result != 'success') }}
279+
if: ${{ needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test.result != 'success') }}
256280
run: exit 1
257281

258282
upload-coverage:
@@ -280,7 +304,7 @@ jobs:
280304
merge-multiple: true
281305

282306
- name: Upload coverage to Codecov
283-
uses: codecov/codecov-action@v4
307+
uses: codecov/codecov-action@v5
284308
with:
285309
directory: ./coverage/
286310
fail_ci_if_error: true

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ repos:
2727
- id: sphinx-lint
2828
args: ["."]
2929
- repo: https://github.com/astral-sh/ruff-pre-commit
30-
rev: v0.7.1
30+
rev: v0.7.3
3131
hooks:
3232
- id: ruff
3333
args: ["--fix", "--output-format=full"]

doc/library/tensor/random/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ sophisticated `Op`\s like `Scan`, which makes it a user-friendly random variable
3030
interface in PyTensor.
3131

3232
For an example of how to use random numbers, see :ref:`Using Random Numbers <using_random_numbers>`.
33+
For a technical explanation of how PyTensor implements random variables see :ref:`prng`.
3334

3435

3536
.. class:: RandomStream()

doc/tutorial/examples.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,9 @@ hold here as well. PyTensor's random objects are defined and implemented in
357357
:ref:`RandomStream<libdoc_tensor_random_utils>` and, at a lower level,
358358
in :ref:`RandomVariable<libdoc_tensor_random_basic>`.
359359

360+
For a more technical explanation of how PyTensor implements random variables see :ref:`prng`.
361+
362+
360363
Brief Example
361364
-------------
362365

doc/tutorial/prng.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ Pseudo random number generation in PyTensor
55
===========================================
66

77
PyTensor has native support for `pseudo random number generation (PRNG) <https://en.wikipedia.org/wiki/Pseudorandom_number_generator>`_.
8-
This document describes how PRNGs are implemented in PyTensor, via the RandomVariable Operator.
8+
9+
This document describes the details of how PRNGs are implemented in PyTensor, via the RandomVariable Operator.
10+
For a more applied example see :ref:`using_random_numbers`
911

1012
We also discuss how initial seeding and seeding updates are implemented, and some harder cases such as using RandomVariables inside Scan, or with other backends like JAX.
1113

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,12 @@ exclude = ["doc/", "pytensor/_version.py"]
129129
docstring-code-format = true
130130

131131
[tool.ruff.lint]
132-
select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC"]
132+
select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC"]
133133
ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
134+
unfixable = [
135+
# zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
136+
"B905",
137+
]
134138

135139

136140
[tool.ruff.lint.isort]

pytensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def get_underlying_scalar_constant(v):
164164
from pytensor.scan import checkpoints
165165
from pytensor.scan.basic import scan
166166
from pytensor.scan.views import foldl, foldr, map, reduce
167+
from pytensor.compile.builders import OpFromGraph
167168

168169
# isort: on
169170

pytensor/compile/builders.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ def infer_shape(outs, inputs, input_shapes):
4343
# TODO: ShapeFeature should live elsewhere
4444
from pytensor.tensor.rewriting.shape import ShapeFeature
4545

46-
for inp, inp_shp in zip(inputs, input_shapes):
46+
for inp, inp_shp in zip(inputs, input_shapes, strict=True):
4747
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
4848
assert len(inp_shp) == inp.type.ndim
4949

5050
shape_feature = ShapeFeature()
5151
shape_feature.on_attach(FunctionGraph([], []))
5252

5353
# Initialize shape_of with the input shapes
54-
for inp, inp_shp in zip(inputs, input_shapes):
54+
for inp, inp_shp in zip(inputs, input_shapes, strict=True):
5555
shape_feature.set_shape(inp, inp_shp)
5656

5757
def local_traverse(out):
@@ -108,7 +108,9 @@ def construct_nominal_fgraph(
108108

109109
replacements = dict(
110110
zip(
111-
inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
111+
inputs + implicit_shared_inputs,
112+
dummy_inputs + dummy_implicit_shared_inputs,
113+
strict=True,
112114
)
113115
)
114116

@@ -138,7 +140,7 @@ def construct_nominal_fgraph(
138140
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
139141
)
140142

141-
fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
143+
fgraph.replace_all(zip(local_inputs, nominal_local_inputs, strict=True))
142144

143145
for i, inp in enumerate(fgraph.inputs):
144146
nom_inp = nominal_local_inputs[i]
@@ -562,7 +564,9 @@ def lop_overrides(inps, grads):
562564
# compute non-overriding downsteam grads from upstreams grads
563565
# it's normal some input may be disconnected, thus the 'ignore'
564566
wrt = [
565-
lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None
567+
lin
568+
for lin, gov in zip(inner_inputs, custom_input_grads, strict=True)
569+
if gov is None
566570
]
567571
default_input_grads = fn_grad(wrt=wrt) if wrt else []
568572
input_grads = self._combine_list_overrides(
@@ -653,7 +657,7 @@ def _build_and_cache_rop_op(self):
653657
f = [
654658
output
655659
for output, custom_output_grad in zip(
656-
inner_outputs, custom_output_grads
660+
inner_outputs, custom_output_grads, strict=True
657661
)
658662
if custom_output_grad is None
659663
]
@@ -733,18 +737,24 @@ def make_node(self, *inputs):
733737

734738
non_shared_inputs = [
735739
inp_t.filter_variable(inp)
736-
for inp, inp_t in zip(non_shared_inputs, self.input_types)
740+
for inp, inp_t in zip(non_shared_inputs, self.input_types, strict=True)
737741
]
738742

739743
new_shared_inputs = inputs[num_expected_inps:]
740-
inner_and_input_shareds = list(zip(self.shared_inputs, new_shared_inputs))
744+
inner_and_input_shareds = list(
745+
zip(self.shared_inputs, new_shared_inputs, strict=True)
746+
)
741747

742748
if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
743749
# The shared variables are not equal to the original shared
744750
# variables, so we construct a new `Op` that uses the new shared
745751
# variables instead.
746752
replace = dict(
747-
zip(self.inner_inputs[num_expected_inps:], new_shared_inputs)
753+
zip(
754+
self.inner_inputs[num_expected_inps:],
755+
new_shared_inputs,
756+
strict=True,
757+
)
748758
)
749759

750760
# If the new shared variables are inconsistent with the inner-graph,
@@ -811,7 +821,7 @@ def infer_shape(self, fgraph, node, shapes):
811821
# each shape call. PyTensor optimizer will clean this up later, but this
812822
# will make extra work for the optimizer.
813823

814-
repl = dict(zip(self.inner_inputs, node.inputs))
824+
repl = dict(zip(self.inner_inputs, node.inputs, strict=True))
815825
clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
816826
cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
817827
ret = []
@@ -853,5 +863,6 @@ def clone(self):
853863
def perform(self, node, inputs, outputs):
854864
variables = self.fn(*inputs)
855865
assert len(variables) == len(outputs)
856-
for output, variable in zip(outputs, variables):
866+
# strict=False because asserted above
867+
for output, variable in zip(outputs, variables, strict=False):
857868
output[0] = variable

0 commit comments

Comments
 (0)