Skip to content

Rename create_act_on_args to create_simulation_state #5299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
self.grouping = grouping
super().__init__(noise=noise, seed=seed)

def _create_partial_act_on_args(
def _create_partial_simulation_state(
self,
initial_state: Union[int, 'MPSState'],
qubits: Sequence['cirq.Qid'],
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_simulation_state():
mps_simulator = ccq.mps_simulator.MPSSimulator()
ref_simulator = cirq.Simulator()
for initial_state in range(4):
args = mps_simulator._create_act_on_args(initial_state=initial_state, qubits=(q0, q1))
args = mps_simulator._create_simulation_state(initial_state=initial_state, qubits=(q0, q1))
actual = mps_simulator.simulate(circuit, qubit_order=qubit_order, initial_state=args)
expected = ref_simulator.simulate(
circuit, qubit_order=qubit_order, initial_state=initial_state
Expand Down
216 changes: 108 additions & 108 deletions cirq-core/cirq/ops/common_gates_test.py

Large diffs are not rendered by default.

52 changes: 33 additions & 19 deletions cirq-core/cirq/protocols/act_on_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing_extensions import Protocol

from cirq import ops
from cirq import _compat, ops
from cirq._doc import doc_private
from cirq.type_workarounds import NotImplementedType

Expand All @@ -32,8 +32,8 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase') -> Union[NotImplemente
"""Applies an action to the given argument, if it is a supported type.

For example, unitary operations can implement an `_act_on_` method that
checks if `isinstance(args, cirq.StateVectorSimulationState)` and, if so,
apply their unitary effect to the state vector.
checks if `isinstance(sim_state, cirq.StateVectorSimulationState)` and,
if so, apply their unitary effect to the state vector.

The global `cirq.act_on` method looks for whether or not the given
argument has this value, before attempting any fallback strategies
Expand Down Expand Up @@ -64,8 +64,8 @@ def _act_on_(
"""Applies an action to the given argument, if it is a supported type.

For example, unitary operations can implement an `_act_on_` method that
checks if `isinstance(args, cirq.StateVectorSimulationState)` and, if so,
apply their unitary effect to the state vector.
checks if `isinstance(sim_state, cirq.StateVectorSimulationState)` and,
if so, apply their unitary effect to the state vector.

The global `cirq.act_on` method looks for whether or not the given
argument has this value, before attempting any fallback strategies
Expand All @@ -86,9 +86,22 @@ def _act_on_(
"""


def _fix_deprecated_args(args, kwargs):
kwargs['sim_state'] = kwargs['args']
del kwargs['args']
return args, kwargs


@_compat.deprecated_parameter(
deadline='v0.16',
fix='Change argument name to `sim_state`',
parameter_desc='args',
match=lambda args, kwargs: 'args' in kwargs,
rewrite=_fix_deprecated_args,
)
def act_on(
action: Any,
args: 'cirq.SimulationStateBase',
sim_state: 'cirq.SimulationStateBase',
qubits: Sequence['cirq.Qid'] = None,
*,
allow_decompose: bool = True,
Expand All @@ -108,24 +121,25 @@ def act_on(

Args:
action: The operation, gate, or other to apply to the state tensor.
args: A mutable state object that should be modified by the action. May
specify an `_act_on_fallback_` method to use in case the action
doesn't recognize it.
sim_state: A mutable state object that should be modified by the
action. May specify an `_act_on_fallback_` method to use in case
the action doesn't recognize it.
qubits: The sequence of qubits to use when applying the action.
allow_decompose: Defaults to True. Forwarded into the
`_act_on_fallback_` method of `args`. Determines if decomposition
should be used or avoided when attempting to act `action` on `args`.
Used by internal methods to avoid redundant decompositions.
`_act_on_fallback_` method of `sim_state`. Determines if
decomposition should be used or avoided when attempting to act
`action` on `sim_state`. Used by internal methods to avoid
redundant decompositions.

Returns:
Nothing. Results are communicated by editing `args`.
Nothing. Results are communicated by editing `sim_state`.

Raises:
ValueError: If called on an operation and supplied qubits, if not called
on an operation and no qubits are supplied, or if `_act_on_` or
`_act_on_fallback_` returned something other than `True` or
`NotImplemented`.
TypeError: Failed to act `action` on `args`.
TypeError: Failed to act `action` on `sim_state`.
"""
is_op = isinstance(action, ops.Operation)

Expand All @@ -137,7 +151,7 @@ def act_on(

action_act_on = getattr(action, '_act_on_', None)
if action_act_on is not None:
result = action_act_on(args) if is_op else action_act_on(args, qubits)
result = action_act_on(sim_state) if is_op else action_act_on(sim_state, qubits)
if result is True:
return
if result is not NotImplemented:
Expand All @@ -146,7 +160,7 @@ def act_on(
f'{result!r} from {action!r}._act_on_'
)

arg_fallback = getattr(args, '_act_on_fallback_', None)
arg_fallback = getattr(sim_state, '_act_on_fallback_', None)
if arg_fallback is not None:
qubits = action.qubits if isinstance(action, ops.Operation) else qubits
result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose)
Expand All @@ -155,14 +169,14 @@ def act_on(
if result is not NotImplemented:
raise ValueError(
f'_act_on_fallback_ must return True or NotImplemented but got '
f'{result!r} from {type(args)}._act_on_fallback_'
f'{result!r} from {type(sim_state)}._act_on_fallback_'
)

raise TypeError(
"Failed to act action on state argument.\n"
"Tried both action._act_on_ and args._act_on_fallback_.\n"
"Tried both action._act_on_ and sim_state._act_on_fallback_.\n"
"\n"
f"State argument type: {type(args)}\n"
f"State argument type: {type(sim_state)}\n"
f"Action type: {type(action)}\n"
f"Action repr: {action!r}\n"
)
30 changes: 18 additions & 12 deletions cirq-core/cirq/protocols/act_on_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,20 @@ def _act_on_fallback_(


def test_act_on_fallback_succeeds():
args = DummySimulationState(fallback_result=True)
cirq.act_on(op, args)
state = DummySimulationState(fallback_result=True)
cirq.act_on(op, state)


def test_act_on_fallback_fails():
args = DummySimulationState(fallback_result=NotImplemented)
state = DummySimulationState(fallback_result=NotImplemented)
with pytest.raises(TypeError, match='Failed to act'):
cirq.act_on(op, args)
cirq.act_on(op, state)


def test_act_on_fallback_errors():
args = DummySimulationState(fallback_result=False)
state = DummySimulationState(fallback_result=False)
with pytest.raises(ValueError, match='_act_on_fallback_ must return True or NotImplemented'):
cirq.act_on(op, args)
cirq.act_on(op, state)


def test_act_on_errors():
Expand All @@ -71,9 +71,9 @@ def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf:
def _act_on_(self, sim_state):
return False

args = DummySimulationState(fallback_result=True)
state = DummySimulationState(fallback_result=True)
with pytest.raises(ValueError, match='_act_on_ must return True or NotImplemented'):
cirq.act_on(Op(), args)
cirq.act_on(Op(), state)


def test_qubits_not_allowed_for_operations():
Expand All @@ -85,14 +85,20 @@ def qubits(self) -> Tuple['cirq.Qid', ...]:
def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf:
pass

args = DummySimulationState()
state = DummySimulationState()
with pytest.raises(
ValueError, match='Calls to act_on should not supply qubits if the action is an Operation'
):
cirq.act_on(Op(), args, qubits=[])
cirq.act_on(Op(), state, qubits=[])


def test_qubits_should_be_defined_for_operations():
args = DummySimulationState()
state = DummySimulationState()
with pytest.raises(ValueError, match='Calls to act_on should'):
cirq.act_on(cirq.KrausChannel([np.array([[1, 0], [0, 0]])]), args, qubits=None)
cirq.act_on(cirq.KrausChannel([np.array([[1, 0], [0, 0]])]), state, qubits=None)


def test_args_deprecated():
args = DummySimulationState(fallback_result=True)
with cirq.testing.assert_deprecated(deadline='v0.16'):
cirq.act_on(action=op, args=args) # pylint: disable=no-value-for-parameter
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def is_supported_operation(op: 'cirq.Operation') -> bool:
# TODO: support more general Pauli measurements
return protocols.has_stabilizer_effect(op)

def _create_partial_act_on_args(
def _create_partial_simulation_state(
self,
initial_state: Union[int, 'cirq.StabilizerChFormSimulationState'],
qubits: Sequence['cirq.Qid'],
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/clifford/clifford_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_simulation_state():
circuit.append(cirq.X(q1))
circuit.append(cirq.measure(q0, q1))

args = simulator._create_act_on_args(initial_state=1, qubits=(q0, q1))
args = simulator._create_simulation_state(initial_state=1, qubits=(q0, q1))
result = simulator.simulate(circuit, initial_state=args)
expected_state = np.zeros(shape=(2, 2))
expected_state[b0][1 - b1] = 1.0
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/density_matrix_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
if dtype not in {np.complex64, np.complex128}:
raise ValueError(f'dtype must be complex64 or complex128, was {dtype}')

def _create_partial_act_on_args(
def _create_partial_simulation_state(
self,
initial_state: Union[
np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.DensityMatrixSimulationState'
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/density_matrix_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def test_simulation_state(dtype: Type[np.number], split: bool):
for b0 in [0, 1]:
for b1 in [0, 1]:
circuit = cirq.Circuit((cirq.X**b0)(q0), (cirq.X**b1)(q1))
args = simulator._create_act_on_args(initial_state=1, qubits=(q0, q1))
args = simulator._create_simulation_state(initial_state=1, qubits=(q0, q1))
result = simulator.simulate(circuit, initial_state=args)
expected_density_matrix = np.zeros(shape=(4, 4))
expected_density_matrix[b0 * 2 + 1 - b1, b0 * 2 + 1 - b1] = 1.0
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/density_matrix_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def test_to_valid_density_matrix_on_simulator_output(seed, dtype, split):


def test_factor_validation():
args = cirq.DensityMatrixSimulator()._create_act_on_args(0, qubits=cirq.LineQubit.range(2))
args = cirq.DensityMatrixSimulator()._create_simulation_state(0, qubits=cirq.LineQubit.range(2))
args.apply_operation(cirq.H(cirq.LineQubit(0)))
t = args.create_merged_state().target_tensor
cirq.linalg.transformations.factor_density_matrix(t, [0])
Expand Down
Loading