Skip to content

Minibatch warning fix #7749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 4, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from pytensor.compile import DeepCopyOp, Function, get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.variable import TensorConstant, TensorVariable
Expand Down Expand Up @@ -1241,15 +1241,13 @@ def register_rv(
self.add_named_variable(rv_var, dims)
self.set_initval(rv_var, initval)
else:
if (
isinstance(observed, TensorVariable)
and observed.owner is not None
and isinstance(observed.owner.op, MinibatchOp)
and total_size is None
):
warnings.warn(
f"total_size not provided for observed variable `{name}` that uses pm.Minibatch"
)
if total_size is None and isinstance(observed, TensorVariable):
for node in ancestors([observed]):
if node.owner is not None and isinstance(node.owner.op, MinibatchOp):
warnings.warn(
f"total_size not provided for observed variable `{name}` that uses pm.Minibatch"
)
break
if not is_valid_observed(observed):
raise TypeError(
"Variables that depend on other nodes cannot be used for observed data."
Expand Down
Loading