diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index b92ca1c1a..fbf64ca8c 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -11,6 +11,6 @@ jobs:
     env:
       SKIP: no-commit-to-branch
     steps:
-    - uses: actions/checkout@v2
-    - uses: actions/setup-python@v2
+    - uses: actions/checkout@v4
+    - uses: actions/setup-python@v5
     - uses: pre-commit/action@v2.0.0
diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml
index 682141a75..20a3e2604 100644
--- a/.github/workflows/pypi.yml
+++ b/.github/workflows/pypi.yml
@@ -12,10 +12,10 @@ jobs:
     name: build source distribution
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Build the sdist and the wheel
@@ -28,7 +28,7 @@ jobs:
          cd test-sdist
          python -m venv venv-sdist
          venv-sdist/bin/python -m pip install numpy
-          venv-sdist/bin/python -m pip install ../dist/pymc-experimental*.tar.gz
+          venv-sdist/bin/python -m pip install ../dist/pymc_experimental*.tar.gz
          echo "Checking import and version number (on release)"
          venv-sdist/bin/python -c "import pymc_experimental as pmx; assert pmx.__version__ == '${{ github.ref_name }}'[1:] if '${{ github.ref_type }}' == 'tag' else True; print(pmx.__version__)"
          cd ..
@@ -61,7 +61,7 @@ jobs:
          user: __token__
          password: ${{ secrets.TEST_PYPI_API_TOKEN }}
          repository_url: https://test.pypi.org/legacy/
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Test pip install from test.pypi
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 1ec9e6697..1379b99c2 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -18,7 +18,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest]
-        python-version: ["3.9"]
+        python-version: ["3.10"]
         test-subset:
           - pymc_experimental/tests
       fail-fast: false
@@ -28,49 +28,23 @@
       PYTENSOR_FLAGS: gcc__cxxflags='-march=native'
     defaults:
       run:
-        shell: bash -l {0}
+        shell: bash -leo pipefail {0}
     steps:
-      - uses: actions/checkout@v2
-      - name: Cache conda
-        uses: actions/cache@v1
-        env:
-          # Increase this value to reset cache if environment-test.yml has not changed
-          CACHE_NUMBER: 0
+      - uses: actions/checkout@v4
+      - uses: mamba-org/setup-micromamba@v1
        with:
-          path: ~/conda_pkgs_dir
-          key: ${{ runner.os }}-py${{matrix.python-version}}-conda-${{ env.CACHE_NUMBER }}-${{
-            hashFiles('conda-envs/environment-test.yml') }}
-      - name: Cache multiple paths
-        uses: actions/cache@v2
-        env:
-          # Increase this value to reset cache if requirements.txt has not changed
-          CACHE_NUMBER: 0
-        with:
-          path: |
-            ~/.cache/pip
-            $RUNNER_TOOL_CACHE/Python/*
-            ~\AppData\Local\pip\Cache
-          key: ${{ runner.os }}-build-${{ matrix.python-version }}-${{
-            hashFiles('requirements.txt') }}
-      - uses: conda-incubator/setup-miniconda@v2
-        with:
-          miniforge-variant: Mambaforge
-          miniforge-version: latest
-          mamba-version: "*"
-          activate-environment: pymc-experimental-test
-          channel-priority: strict
          environment-file: conda-envs/environment-test.yml
-          python-version: ${{matrix.python-version}}
-          use-mamba: true
-          use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
+          create-args: >-
+            python=${{matrix.python-version}}
+          environment-name: pymc-experimental-test
+          init-shell: bash
+          cache-environment: true
      - name: Install pymc-experimental
        run: |
-          conda activate pymc-experimental-test
          pip install -e .
          python --version
      - name: Run tests
        run: |
-          conda activate pymc-experimental-test
          python -m pytest -vv --cov=pymc_experimental --cov-append --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v2
@@ -82,7 +56,7 @@ jobs:
     strategy:
       matrix:
         os: [windows-latest]
-        python-version: ["3.11"]
+        python-version: ["3.12"]
         test-subset:
           - pymc_experimental/tests
       fail-fast: false
@@ -92,51 +66,25 @@
       PYTENSOR_FLAGS: gcc__cxxflags='-march=core2'
     defaults:
       run:
-        shell: cmd
+        shell: cmd /C call {0}
     steps:
-      - uses: actions/checkout@v2
-      - name: Cache conda
-        uses: actions/cache@v1
-        env:
-          # Increase this value to reset cache if conda-envs/windows-environment-test.yml has not changed
-          CACHE_NUMBER: 0
-        with:
-          path: ~/conda_pkgs_dir
-          key: ${{ runner.os }}-py${{matrix.python-version}}-conda-${{ env.CACHE_NUMBER }}-${{
-            hashFiles('conda-envs/windows-environment-test.yml') }}
-      - name: Cache multiple paths
-        uses: actions/cache@v2
-        env:
-          # Increase this value to reset cache if requirements.txt has not changed
-          CACHE_NUMBER: 0
-        with:
-          path: |
-            ~/.cache/pip
-            $RUNNER_TOOL_CACHE/Python/*
-            ~\AppData\Local\pip\Cache
-          key: ${{ runner.os }}-build-${{ matrix.python-version }}-${{
-            hashFiles('requirements.txt') }}
-      - uses: conda-incubator/setup-miniconda@v2
+      - uses: actions/checkout@v4
+      - uses: mamba-org/setup-micromamba@v1
        with:
-          miniforge-variant: Mambaforge
-          miniforge-version: latest
-          mamba-version: "*"
-          activate-environment: pymc-experimental-test
-          channel-priority: strict
          environment-file: conda-envs/windows-environment-test.yml
-          python-version: ${{matrix.python-version}}
-          use-mamba: true
-          use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
+          create-args: >-
+            python=${{matrix.python-version}}
+          environment-name: pymc-experimental-test
+          init-shell: cmd.exe
+          cache-environment: true
      - name: Install pymc-experimental
        run: |
-          conda activate pymc-experimental-test
          pip install -e .
          python --version
      - name: Run tests
        # This job uses a cmd shell, therefore the environment variable syntax is different!
        # The ">-" in the next line replaces newlines with spaces (see https://stackoverflow.com/a/66809682).
        run: >-
-          conda activate pymc-experimental-test &&
          python -m pytest -vv --cov=pymc_experimental --cov-append --cov-report=xml --cov-report term --durations=50 %TEST_SUBSET%
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v2
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 010b0bdd0..4002dca20 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -5,7 +5,7 @@ version: 2
 build:
   os: ubuntu-20.04
   tools:
-    python: "3.9"
+    python: "3.10"

 python:
   install:
diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index 318fb130b..0fb71f52c 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -4,13 +4,12 @@ channels:
   - defaults
 dependencies:
   - pip
-  - pytest-cov>=2.5
   - pytest>=3.0
   - dask
   - xhistogram
   - statsmodels
   - pip:
-      - pymc>=5.11.0 # CI was failing to resolve
+      - pymc>=5.13.0 # CI was failing to resolve
       - blackjax
       - scikit-learn
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index 0563959b8..0fb71f52c 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -10,6 +10,6 @@ dependencies:
   - xhistogram
   - statsmodels
   - pip:
-      - pymc>=5.11.0 # CI was failing to resolve
+      - pymc>=5.13.0 # CI was failing to resolve
       - blackjax
       - scikit-learn
diff --git a/pymc_experimental/distributions/discrete.py b/pymc_experimental/distributions/discrete.py
index 4c874bc73..368142cdf 100644
--- a/pymc_experimental/distributions/discrete.py
+++ b/pymc_experimental/distributions/discrete.py
@@ -20,6 +20,15 @@
 from pytensor.tensor.random.op import RandomVariable


+def log1mexp(x):
+    cond = x < np.log(0.5)
+    return np.piecewise(
+        x,
+        [cond, ~cond],
+        [lambda x: np.log1p(-np.exp(x)), lambda x: np.log(-np.expm1(x))],
+    )
+
+
 class GeneralizedPoissonRV(RandomVariable):
     name = "generalized_poisson"
     ndim_supp = 0
@@ -74,7 +83,7 @@ def _inverse_rng_fn(cls, rng, theta, lam, dist_size, idxs_mask):
             log1p_lam_m_C = np.where(
                 pos_lam,
                 np.log1p(np.exp(abs_log_lam - log_c)),
-                pm.math.log1mexp_numpy(abs_log_lam - log_c, negative_input=True),
+                log1mexp(abs_log_lam - log_c),
             )
             log_p = log_c + log1p_lam_m_C * (x_ - 1) + log_p - np.log(x_) - lam
             log_s = np.logaddexp(log_s, log_p)
diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py
index 1ef0e78d1..f6214a7bd 100644
--- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py
+++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py
@@ -418,9 +418,7 @@ def R2D2M2CP(
     *broadcast_dims, dim = dims
     input_sigma = pt.as_tensor(input_sigma)
     output_sigma = pt.as_tensor(output_sigma)
-    with pm.Model(name) as model:
-        if not all(isinstance(model.dim_lengths[d], pt.TensorConstant) for d in dims):
-            raise ValueError(f"{dims!r} should be constant length immutable dims")
+    with pm.Model(name):
         if r2_std is not None:
             r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
         phi = _phi(
diff --git a/pymc_experimental/distributions/timeseries.py b/pymc_experimental/distributions/timeseries.py
index 574f2abd0..91da141ac 100644
--- a/pymc_experimental/distributions/timeseries.py
+++ b/pymc_experimental/distributions/timeseries.py
@@ -202,13 +202,13 @@ def transition(*args):
     discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)

     discrete_mc_op = DiscreteMarkovChainRV(
-        inputs=[P_, steps_, init_dist_],
+        inputs=[P_, steps_, init_dist_, state_rng],
         outputs=[state_next_rng, discrete_mc_],
         ndim_supp=1,
         n_lags=n_lags,
     )

-    discrete_mc = discrete_mc_op(P, steps, init_dist)
+    discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)

     return discrete_mc
diff --git a/pymc_experimental/linearmodel.py b/pymc_experimental/linearmodel.py
index f0e488762..0c4237dab 100644
--- a/pymc_experimental/linearmodel.py
+++ b/pymc_experimental/linearmodel.py
@@ -69,8 +69,8 @@ def build_model(self, X: pd.DataFrame, y: pd.Series):
         # Data array size can change but number of dimensions must stay the same.
         with pm.Model() as self.model:
-            x = pm.MutableData("x", np.zeros((1,)), dims="observation")
-            y_data = pm.MutableData("y_data", np.zeros((1,)), dims="observation")
+            x = pm.Data("x", np.zeros((1,)), dims="observation")
+            y_data = pm.Data("y_data", np.zeros((1,)), dims="observation")

             # priors
             intercept = pm.Normal(
diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py
index 8a832ef8a..ed9490511 100644
--- a/pymc_experimental/model/marginal_model.py
+++ b/pymc_experimental/model/marginal_model.py
@@ -6,7 +6,7 @@
 import pytensor.tensor as pt
 from arviz import dict_to_dataset
 from pymc import SymbolicRandomVariable
-from pymc.backends.arviz import coords_and_dims_for_inferencedata
+from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
@@ -14,7 +14,7 @@
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
 from pymc.pytensorf import compile_pymc, constant_fold, inputvars
-from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
+from pymc.util import _get_seeds_per_chain, treedict
 from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
 from pytensor.compile.builders import OpFromGraph
@@ -410,7 +410,7 @@ def transform_input(inputs):
                 marginalized_rv.type, dependent_logps
             )

-        rv_shape = constant_fold(tuple(marginalized_rv.shape))
+        rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
         rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
         rv_domain_tensor = pt.moveaxis(
             pt.full(
@@ -579,6 +579,15 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
     return True


+from pytensor.graph.basic import graph_inputs
+
+
+def collect_shared_vars(outputs, blockers):
+    return [
+        inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
+    ]
+
+
 def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
     # TODO: This should eventually be integrated in a more general routine that can
     #  identify other types of supported marginalization, of which finite discrete
@@ -621,14 +630,8 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]

     outputs = rvs_to_marginalize
-    # Clone replace inner RV rng inputs so that we can be sure of the update order
-    # replace_inputs = {rng: rng.type() for rng in updates_rvs_to_marginalize.keys()}
-    # Clone replace outter RV inputs, so that their shared RNGs don't make it into
-    # the inner graph of the marginalized RVs
-    # FIXME: This shouldn't be needed!
-    replace_inputs = {}
-    replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
-    cloned_outputs = clone_replace(outputs, replace=replace_inputs)
+    # We are strict about shared variables in SymbolicRandomVariables
+    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)

     if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
         marginalize_constructor = DiscreteMarginalMarkovChainRV
@@ -636,12 +639,12 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
         marginalize_constructor = FiniteDiscreteMarginalRV

     marginalization_op = marginalize_constructor(
-        inputs=list(replace_inputs.values()),
-        outputs=cloned_outputs,
+        inputs=inputs,
+        outputs=outputs,
         ndim_supp=ndim_supp,
     )

-    marginalized_rvs = marginalization_op(*replace_inputs.keys())
+    marginalized_rvs = marginalization_op(*inputs)
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
diff --git a/pymc_experimental/statespace/core/statespace.py b/pymc_experimental/statespace/core/statespace.py
index 29c91f6b7..6350a1781 100644
--- a/pymc_experimental/statespace/core/statespace.py
+++ b/pymc_experimental/statespace/core/statespace.py
@@ -291,7 +291,7 @@ def _print_data_requirements(self) -> None:
         out = out.rstrip()

         _log.info(
-            "The following MutableData variables should be assigned to the model inside a PyMC "
+            "The following Data variables should be assigned to the model inside a PyMC "
             f"model block: \n"
             f"{out}"
         )
@@ -366,7 +366,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
     @property
     def data_info(self) -> dict[str, dict[str, Any]]:
         """
-        Information about MutableData variables that need to be declared in the PyMC model block.
+        Information about Data variables that need to be declared in the PyMC model block.

         Returns a dictionary of data_name: dictionary of property-name:property description pairs. The return value is
         used by the ``_print_data_requirements`` method, to print a message telling users how to define the necessary
@@ -877,7 +877,7 @@ def build_statespace_graph(
             or a Pytensor tensor variable.

         register_data : bool, optional, default=True
-            If True, the observed data will be registered with PyMC as a pm.MutableData variable. In addition,
+            If True, the observed data will be registered with PyMC as a pm.Data variable. In addition,
             a "time" dim will be created and added to the model's coords.
         mode : Optional[str], optional, default=None
diff --git a/pymc_experimental/statespace/filters/distributions.py b/pymc_experimental/statespace/filters/distributions.py
index 8a75e1032..6298af6a4 100644
--- a/pymc_experimental/statespace/filters/distributions.py
+++ b/pymc_experimental/statespace/filters/distributions.py
@@ -193,12 +193,12 @@ def step_fn(*args):
         (ss_rng,) = tuple(updates.values())

         linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
-            inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps_],
+            inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps_, rng],
             outputs=[ss_rng, statespace_],
             ndim_supp=1,
         )

-        linear_gaussian_ss = linear_gaussian_ss_op(a0, P0, c, d, T, Z, R, H, Q, steps)
+        linear_gaussian_ss = linear_gaussian_ss_op(a0, P0, c, d, T, Z, R, H, Q, steps, rng)
         return linear_gaussian_ss

@@ -354,10 +354,10 @@ def step(mu, cov, rng):
         (seq_mvn_rng,) = tuple(updates.values())

         mvn_seq_op = KalmanFilterRV(
-            inputs=[mus_, covs_, logp_, steps_], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
+            inputs=[mus_, covs_, logp_, steps_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
         )

-        mvn_seq = mvn_seq_op(mus, covs, logp, steps)
+        mvn_seq = mvn_seq_op(mus, covs, logp, steps, rng)

         return mvn_seq
diff --git a/pymc_experimental/statespace/utils/data_tools.py b/pymc_experimental/statespace/utils/data_tools.py
index 0cbd3a859..29c03e69a 100644
--- a/pymc_experimental/statespace/utils/data_tools.py
+++ b/pymc_experimental/statespace/utils/data_tools.py
@@ -112,8 +112,8 @@ def add_data_to_active_model(values, index):
     if OBS_STATE_DIM in pymc_mod.coords:
         data_dims = [TIME_DIM, OBS_STATE_DIM]

-    pymc_mod.add_coord(TIME_DIM, index, mutable=True)
-    data = pm.ConstantData("data", values, dims=data_dims)
+    pymc_mod.add_coord(TIME_DIM, index)
+    data = pm.Data("data", values, dims=data_dims)

     return data
diff --git a/pymc_experimental/tests/distributions/test_discrete.py b/pymc_experimental/tests/distributions/test_discrete.py
index 942802fc6..60885908f 100644
--- a/pymc_experimental/tests/distributions/test_discrete.py
+++ b/pymc_experimental/tests/distributions/test_discrete.py
@@ -197,9 +197,13 @@ def test_logp(self):

 class TestSkellam:
     def test_logp(self):
-        check_logp(
-            Skellam,
-            I,
-            {"mu1": Rplus, "mu2": Rplus},
-            lambda value, mu1, mu2: scipy.stats.skellam.logpmf(value, mu1, mu2),
-        )
+        # Scipy Skellam underflows to -inf earlier than PyMC
+        Rplus_small = Domain([0, 0.01, 0.1, 0.9, 0.99, 1, 1.5, 2, 10, np.inf])
+        # Suppress warnings coming from Scipy logpmf implementation
+        with np.errstate(divide="ignore", invalid="ignore"):
+            check_logp(
+                Skellam,
+                I,
+                {"mu1": Rplus_small, "mu2": Rplus_small},
+                lambda value, mu1, mu2: scipy.stats.skellam.logpmf(value, mu1, mu2),
+            )
diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py
index fcf5cfec1..f0ecfa98e 100644
--- a/pymc_experimental/tests/distributions/test_multivariate.py
+++ b/pymc_experimental/tests/distributions/test_multivariate.py
@@ -302,27 +302,3 @@ def test_zero_length_rvs_not_created(self, model: pm.Model):
             "b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a"
         )
         assert not model.free_RVs, model.free_RVs
-
-    def test_immutable_dims(self, model: pm.Model):
-        model.add_coord("a", range(2), mutable=True)
-        model.add_coord("b", range(2), mutable=False)
-        with pytest.raises(ValueError, match="should be constant length immutable dims"):
-            pmx.distributions.R2D2M2CP(
-                "beta0",
-                1,
-                [1, 1],
-                dims="a",
-                r2=0.8,
-                positive_probs=[0.5, 1],
-                positive_probs_std=[0.3, 0],
-            )
-        with pytest.raises(ValueError, match="should be constant length immutable dims"):
-            pmx.distributions.R2D2M2CP(
-                "beta0",
-                1,
-                [1, 1],
-                dims=("a", "b"),
-                r2=0.8,
-                positive_probs=[0.5, 1],
-                positive_probs_std=[0.3, 0],
-            )
diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py
index f9a0a344b..610f3b47b 100644
--- a/pymc_experimental/tests/model/test_marginal_model.py
+++ b/pymc_experimental/tests/model/test_marginal_model.py
@@ -42,7 +42,9 @@ def disaster_model():
         early_rate = pm.Exponential("early_rate", 1.0, initval=3)
         late_rate = pm.Exponential("late_rate", 1.0, initval=1)
         rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
-        with pytest.warns(ImputationWarning):
+        with pytest.warns(ImputationWarning), pytest.warns(
+            RuntimeWarning, match="invalid value encountered in cast"
+        ):
             disasters = pm.Poisson("disasters", rate, observed=disaster_data)

     return disaster_model, years
@@ -60,9 +62,9 @@ def test_marginalized_bernoulli_logp():
         [idx, y],
         ndim_supp=0,
         n_updates=0,
-    )(
-        mu
-    )[0].owner
+        # Ignore the fact we didn't specify shared RNG input/outputs for idx,y
+        strict=False,
+    )(mu)[0].owner

     y_vv = y.clone()
     (logp,) = _logprob(
@@ -608,8 +610,8 @@ def test_is_conditional_dependent_static_shape():

 def test_data_container():
     """Test that MarginalModel can handle Data containers."""
-    with MarginalModel(coords_mutable={"obs": [0]}) as marginal_m:
-        x = pm.MutableData("x", 2.5)
+    with MarginalModel(coords={"obs": [0]}) as marginal_m:
+        x = pm.Data("x", 2.5)
         idx = pm.Bernoulli("idx", p=0.7, dims="obs")
         y = pm.Normal("y", idx * x, dims="obs")
@@ -617,8 +619,8 @@ def test_data_container():

         logp_fn = marginal_m.compile_logp()

-    with pm.Model(coords_mutable={"obs": [0]}) as m_ref:
-        x = pm.MutableData("x", 2.5)
+    with pm.Model(coords={"obs": [0]}) as m_ref:
+        x = pm.Data("x", 2.5)
         y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[0, x], dims="obs")

         ref_logp_fn = m_ref.compile_logp()
@@ -758,3 +760,19 @@ def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2):
     test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
     test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
     np.testing.assert_allclose(logp_fn(test_point), expected_logp)
+
+
+def test_mutable_indexing_jax_backend():
+    pytest.importorskip("jax")
+    from pymc.sampling.jax import get_jaxified_logp
+
+    with MarginalModel() as model:
+        data = pm.Data("data", np.zeros(10))
+
+        cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
+        cat_effect_idx = pm.Data("cat_effect_idx", np.array([0, 1] * 5))
+
+        is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
+        pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
+    model.marginalize(["is_outlier"])
+    get_jaxified_logp(model)
diff --git a/pymc_experimental/tests/model/transforms/test_autoreparam.py b/pymc_experimental/tests/model/transforms/test_autoreparam.py
index 9749894ed..b2ea245ae 100644
--- a/pymc_experimental/tests/model/transforms/test_autoreparam.py
+++ b/pymc_experimental/tests/model/transforms/test_autoreparam.py
@@ -70,9 +70,9 @@ def test_multilevel():
     # multilevel modelling
     a = pm.Normal("a")
     s = pm.HalfNormal("s")
-    a_g = pm.Normal("a_g", a, s, dims="level")
+    a_g = pm.Normal("a_g", a, s, shape=(2,), dims="level")
     s_g = pm.HalfNormal("s_g")
-    a_ig = pm.Normal("a_ig", a_g, s_g, dims=("county", "level"))
+    a_ig = pm.Normal("a_ig", a_g, s_g, shape=(2, 2), dims=("county", "level"))

     model_r, vip = vip_reparametrize(model, ["a_g", "a_ig"])

     assert "a_g" in vip.get_lambda()
diff --git a/pymc_experimental/tests/statespace/test_distributions.py b/pymc_experimental/tests/statespace/test_distributions.py
index 1da4be60c..deddcb31a 100644
--- a/pymc_experimental/tests/statespace/test_distributions.py
+++ b/pymc_experimental/tests/statespace/test_distributions.py
@@ -46,7 +46,7 @@ def data():
 @pytest.fixture(scope="session")
 def pymc_model(data):
     with pm.Model() as mod:
-        data = pm.ConstantData("data", data.values)
+        data = pm.Data("data", data.values)
         P0_diag = pm.Exponential("P0_diag", 1, shape=(2,))
         P0 = pm.Deterministic("P0", pt.diag(P0_diag))
         initial_trend = pm.Normal("initial_trend", shape=(2,))
@@ -172,7 +172,7 @@ def test_lgss_with_time_varying_inputs(output_name, rng):
     }

     with pm.Model(coords=coords):
-        exog_data = pm.MutableData("data_exog", X)
+        exog_data = pm.Data("data_exog", X)
         P0_diag = pm.Exponential("P0_diag", 1, shape=(mod.k_states,))
         P0 = pm.Deterministic("P0", pt.diag(P0_diag))
         initial_trend = pm.Normal("initial_trend", shape=(2,))
diff --git a/pymc_experimental/tests/statespace/test_statespace.py b/pymc_experimental/tests/statespace/test_statespace.py
index c6f73449f..29a654d35 100644
--- a/pymc_experimental/tests/statespace/test_statespace.py
+++ b/pymc_experimental/tests/statespace/test_statespace.py
@@ -117,7 +117,7 @@ def exog_pymc_mod(exog_ss_mod, rng):
     X = rng.normal(size=(100, 3)).astype(floatX)

     with pm.Model(coords=exog_ss_mod.coords) as m:
-        exog_data = pm.MutableData("data_exog", X)
+        exog_data = pm.Data("data_exog", X)
         initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
         P0_sigma = pm.Exponential("P0_sigma", 1)
         P0 = pm.Deterministic(
diff --git a/pymc_experimental/tests/statespace/test_structural.py b/pymc_experimental/tests/statespace/test_structural.py
index 4af329e64..63457c83d 100644
--- a/pymc_experimental/tests/statespace/test_structural.py
+++ b/pymc_experimental/tests/statespace/test_structural.py
@@ -750,7 +750,7 @@ def test_filter_scans_time_varying_design_matrix(rng):
     mod = reg.build(verbose=False)

     with pm.Model(coords=mod.coords) as m:
-        data_exog = pm.MutableData("data_exog", data.values)
+        data_exog = pm.Data("data_exog", data.values)

         x0 = pm.Normal("x0", dims=["state"])
         P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"])
@@ -781,7 +781,7 @@ def test_extract_components_from_idata(rng):
     mod = (ll + season + reg + me).build(verbose=False)

     with pm.Model(coords=mod.coords) as m:
-        data_exog = pm.MutableData("data_exog", data.values)
+        data_exog = pm.Data("data_exog", data.values)

         x0 = pm.Normal("x0", dims=["state"])
         P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"])
diff --git a/pymc_experimental/tests/test_blackjax_smc.py b/pymc_experimental/tests/test_blackjax_smc.py
index ebb71f132..2cdcf0671 100644
--- a/pymc_experimental/tests/test_blackjax_smc.py
+++ b/pymc_experimental/tests/test_blackjax_smc.py
@@ -133,7 +133,7 @@ def test_blackjax_particles_from_pymc_population_univariate():
     model = fast_model()
     population = {"x": np.array([2, 3, 4])}
     blackjax_particles = blackjax_particles_from_pymc_population(model, population)
-    jax.tree_map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])])
+    jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])])


 def test_blackjax_particles_from_pymc_population_multivariate():
@@ -144,7 +144,7 @@ def test_blackjax_particles_from_pymc_population_multivariate():
     population = {"x": np.array([0.34614613, 1.09163261, -0.44526825]), "z": np.array([1, 2, 3])}
     blackjax_particles = blackjax_particles_from_pymc_population(model, population)
-    jax.tree_map(
+    jax.tree.map(
         np.testing.assert_allclose,
         blackjax_particles,
         [np.array([[0.34614613], [1.09163261], [-0.44526825]]), np.array([[1], [2], [3]])],
@@ -168,7 +168,7 @@ def test_blackjax_particles_from_pymc_population_multivariable():
     population = {"x": np.array([[2, 3], [5, 6], [7, 9]]), "z": np.array([11, 12, 13])}
     blackjax_particles = blackjax_particles_from_pymc_population(model, population)
-    jax.tree_map(
+    jax.tree.map(
         np.testing.assert_allclose,
         blackjax_particles,
         [np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])],
@@ -181,7 +181,7 @@ def test_arviz_from_particles():
     with model:
         inference_data = arviz_from_particles(model, particles)

-    assert inference_data.posterior.dims == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2})
+    assert inference_data.posterior.sizes == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2})
     assert inference_data.posterior.data_vars.dtypes == Frozen(
         {"x": dtype("float64"), "z": dtype("float64")}
     )
@@ -196,7 +196,7 @@ def test_get_jaxified_logprior():
     """
     logprior = get_jaxified_logprior(fast_model())
     for point in [-0.5, 0.0, 0.5]:
-        jax.tree_map(
+        jax.tree.map(
             np.testing.assert_allclose,
             jax.vmap(logprior)([np.array([point])]),
             np.log(scipy.stats.norm(0, 1).pdf(point)),
@@ -212,7 +212,7 @@ def test_get_jaxified_loglikelihood():
     """
     loglikelihood = get_jaxified_loglikelihood(fast_model())
     for point in [-0.5, 0.0, 0.5]:
-        jax.tree_map(
+        jax.tree.map(
             np.testing.assert_allclose,
             jax.vmap(loglikelihood)([np.array([point])]),
             np.log(scipy.stats.norm(point, 1).pdf(0)),
diff --git a/pymc_experimental/tests/test_linearmodel.py b/pymc_experimental/tests/test_linearmodel.py
index 1a169c4ad..d969dbef2 100644
--- a/pymc_experimental/tests/test_linearmodel.py
+++ b/pymc_experimental/tests/test_linearmodel.py
@@ -142,8 +142,8 @@ def test_predict_posterior(fitted_linear_model_instance, combined):
     n_pred = 150
     X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=n_pred)})
     pred = model.predict_posterior(X_pred, combined=combined)
-    chains = model.idata.sample_stats.dims["chain"]
-    draws = model.idata.sample_stats.dims["draw"]
+    chains = model.idata.sample_stats.sizes["chain"]
+    draws = model.idata.sample_stats.sizes["draw"]
     expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
     assert pred.shape == expected_shape
     assert np.issubdtype(pred.dtype, np.floating)
diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py
index 3f769e548..775f27302 100644
--- a/pymc_experimental/tests/test_model_builder.py
+++ b/pymc_experimental/tests/test_model_builder.py
@@ -100,8 +100,8 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
         with pm.Model(coords=coords) as self.model:
             if model_config is None:
                 model_config = self.model_config
-            x = pm.MutableData("x", self.X["input"].values)
-            y_data = pm.MutableData("y_data", self.y)
+            x = pm.Data("x", self.X["input"].values)
+            y_data = pm.Data("y_data", self.y)

             # prior parameters
             a_loc = model_config["a"]["loc"]
@@ -238,8 +238,8 @@ def test_sample_posterior_predictive(fitted_model_instance, combined):
     pred = fitted_model_instance.sample_posterior_predictive(
         prediction_data["input"], combined=combined, extend_idata=True
     )
-    chains = fitted_model_instance.idata.sample_stats.dims["chain"]
-    draws = fitted_model_instance.idata.sample_stats.dims["draw"]
+    chains = fitted_model_instance.idata.sample_stats.sizes["chain"]
+    draws = fitted_model_instance.idata.sample_stats.sizes["draw"]
     expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
     assert pred[fitted_model_instance.output_var].shape == expected_shape
     assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)
diff --git a/pyproject.toml b/pyproject.toml
index 870b53b9a..2a5ee3247 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,9 @@ filterwarnings =[
     # JAX issues an over-eager warning if os.fork() is called when the JAX module is loaded, even if JAX isn't being used
     'ignore:os\.fork\(\) was called\.:RuntimeWarning',
+
+    # Warning coming from blackjax
+    'ignore:jax\.tree_map is deprecated:DeprecationWarning',
 ]

 [tool.black]
diff --git a/requirements.txt b/requirements.txt
index 48828b8ef..cf1d063bd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-pymc>=5.11.0
+pymc>=5.13.0
 scikit-learn
diff --git a/setup.py b/setup.py
index 0aa8206b9..92c0ea397 100644
--- a/setup.py
+++ b/setup.py
@@ -29,10 +29,9 @@
     "Development Status :: 5 - Production/Stable",
     "Programming Language :: Python",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
     "License :: OSI Approved :: Apache Software License",
     "Intended Audience :: Science/Research",
     "Topic :: Scientific/Engineering",
@@ -93,7 +92,7 @@ def read_version():
     # package_data={'docs': ['*']},
     include_package_data=True,
     classifiers=classifiers,
-    python_requires=">=3.8",
+    python_requires=">=3.10",
     install_requires=install_reqs,
     extras_require=extras_require,
 )
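
The change that recurs throughout this diff is the move to PyMC's unified data container: pymc>=5.13 (the new pin in requirements.txt) deprecates pm.MutableData and pm.ConstantData in favor of a single pm.Data, makes coords mutable by default (hence the dropped coords_mutable and add_coord(..., mutable=True) arguments), and keeps pm.set_data as the way to swap values in. A minimal sketch of the pattern the updated files now follow (the model and variable names below are illustrative, not taken from the codebase):

    import numpy as np
    import pymc as pm

    with pm.Model(coords={"observation": range(3)}) as model:
        # pm.Data replaces both pm.MutableData and pm.ConstantData
        x = pm.Data("x", np.zeros(3), dims="observation")
        intercept = pm.Normal("intercept")
        pm.Normal("y", mu=intercept + x, sigma=1.0, observed=np.zeros(3), dims="observation")

    with model:
        # values registered through pm.Data remain swappable after the model is built
        pm.set_data({"x": np.ones(3)})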