Skip to content

Reject formulas as keys of ParamResolvers #5384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class ParamResolver:
Attributes:
param_dict: A dictionary from the ParameterValue key (str) to its
assigned value.

Raises:
TypeError if formulas are passed as keys.
"""

def __new__(cls, param_dict: 'cirq.ParamResolverOrSimilarType' = None):
Expand All @@ -68,6 +71,9 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None

self._param_hash: Optional[int] = None
self.param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
for key in self.param_dict:
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
self._deep_eval_map: ParamDictType = {}

def value_of(
Expand Down
20 changes: 3 additions & 17 deletions cirq-core/cirq/study/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,27 +156,13 @@ def test_param_dict_iter():


def test_formulas_in_param_dict():
"""Test formulas in a `param_dict`.

Param dicts are allowed to have str or sympy.Symbol as keys and
floats or sympy.Symbol as values. This should not be a common use case,
but this tests makes sure something reasonable is returned when
mixing these types and using formulas in ParamResolvers.

Note that sympy orders expressions for deterministic resolution, so
depending on the operands sent to sub(), the expression may not fully
resolve if it needs to take several iterations of resolution.
"""
"""Tests that formula keys are rejected in a `param_dict`."""
a = sympy.Symbol('a')
b = sympy.Symbol('b')
c = sympy.Symbol('c')
e = sympy.Symbol('e')
r = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e})
assert sympy.Eq(r.value_of('a'), 3)
assert sympy.Eq(r.value_of('b'), 2)
assert sympy.Eq(r.value_of(b + c), 101)
assert sympy.Eq(r.value_of('c'), c)
assert sympy.Eq(r.value_of('d'), 2 * e)
with pytest.raises(TypeError, match='formula'):
_ = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e})


def test_recursive_evaluation():
Expand Down