diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 3335bcb7794..6dc3f88a224 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -83,7 +83,7 @@ def __init__( self.grouping = grouping super().__init__(noise=noise, seed=seed) - def _create_partial_act_on_args( + def _create_partial_simulation_state( self, initial_state: Union[int, 'MPSState'], qubits: Sequence['cirq.Qid'], diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index 0d20cf9047e..cd7d3649e58 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -189,7 +189,7 @@ def test_simulation_state(): mps_simulator = ccq.mps_simulator.MPSSimulator() ref_simulator = cirq.Simulator() for initial_state in range(4): - args = mps_simulator._create_act_on_args(initial_state=initial_state, qubits=(q0, q1)) + args = mps_simulator._create_simulation_state(initial_state=initial_state, qubits=(q0, q1)) actual = mps_simulator.simulate(circuit, qubit_order=qubit_order, initial_state=args) expected = ref_simulator.simulate( circuit, qubit_order=qubit_order, initial_state=initial_state diff --git a/cirq-core/cirq/ops/common_gates_test.py b/cirq-core/cirq/ops/common_gates_test.py index 4aeb7207034..bead8592348 100644 --- a/cirq-core/cirq/ops/common_gates_test.py +++ b/cirq-core/cirq/ops/common_gates_test.py @@ -293,33 +293,33 @@ def test_x_act_on_tableau(): original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) flipped_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=23) - args = cirq.CliffordTableauSimulationState( + state = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), ) - cirq.act_on(cirq.X**0.5, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.X**0.5, args, [cirq.LineQubit(1)], allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == flipped_tableau + cirq.act_on(cirq.X**0.5, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.X**0.5, state, [cirq.LineQubit(1)], allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == flipped_tableau - cirq.act_on(cirq.X, args, [cirq.LineQubit(1)], allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == original_tableau + cirq.act_on(cirq.X, state, [cirq.LineQubit(1)], allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == original_tableau - cirq.act_on(cirq.X**3.5, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.X**3.5, args, [cirq.LineQubit(1)], allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == flipped_tableau + cirq.act_on(cirq.X**3.5, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.X**3.5, state, [cirq.LineQubit(1)], allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == flipped_tableau - cirq.act_on(cirq.X**2, args, [cirq.LineQubit(1)], allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == flipped_tableau + cirq.act_on(cirq.X**2, state, [cirq.LineQubit(1)], allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == flipped_tableau foo = sympy.Symbol('foo') with pytest.raises(TypeError, match="Failed to act action on state"): - cirq.act_on(cirq.X**foo, args, [cirq.LineQubit(1)]) + cirq.act_on(cirq.X**foo, state, [cirq.LineQubit(1)]) class iZGate(cirq.testing.SingleQubitGate): @@ -342,36 +342,36 @@ def test_y_act_on_tableau(): original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) flipped_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=23) - args = cirq.CliffordTableauSimulationState( + state = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), ) - cirq.act_on(cirq.Y**0.5, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.Y**0.5, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(iZGate(), args, [cirq.LineQubit(1)]) - assert args.log_of_measurement_results == {} - assert args.tableau == flipped_tableau + cirq.act_on(cirq.Y**0.5, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.Y**0.5, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(iZGate(), state, [cirq.LineQubit(1)]) + assert state.log_of_measurement_results == {} + assert state.tableau == flipped_tableau - cirq.act_on(cirq.Y, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(iZGate(), args, [cirq.LineQubit(1)], allow_decompose=True) - assert args.log_of_measurement_results == {} - assert args.tableau == original_tableau + cirq.act_on(cirq.Y, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(iZGate(), state, [cirq.LineQubit(1)], allow_decompose=True) + assert state.log_of_measurement_results == {} + assert state.tableau == original_tableau - cirq.act_on(cirq.Y**3.5, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.Y**3.5, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(iZGate(), args, [cirq.LineQubit(1)]) - assert args.log_of_measurement_results == {} - assert args.tableau == flipped_tableau + cirq.act_on(cirq.Y**3.5, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.Y**3.5, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(iZGate(), state, [cirq.LineQubit(1)]) + assert state.log_of_measurement_results == {} + assert state.tableau == flipped_tableau - cirq.act_on(cirq.Y**2, args, [cirq.LineQubit(1)], allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == flipped_tableau + cirq.act_on(cirq.Y**2, state, [cirq.LineQubit(1)], allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == flipped_tableau foo = sympy.Symbol('foo') with pytest.raises(TypeError, match="Failed to act action on state"): - cirq.act_on(cirq.Y**foo, args, [cirq.LineQubit(1)]) + cirq.act_on(cirq.Y**foo, state, [cirq.LineQubit(1)]) def test_z_h_act_on_tableau(): @@ -382,49 +382,49 @@ def test_z_h_act_on_tableau(): original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) flipped_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=23) - args = cirq.CliffordTableauSimulationState( + state = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), ) - cirq.act_on(cirq.H, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.Z**0.5, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.Z**0.5, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.H, args, [cirq.LineQubit(1)], allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == flipped_tableau - - cirq.act_on(cirq.H, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.Z, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.H, args, [cirq.LineQubit(1)], allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == original_tableau - - cirq.act_on(cirq.H, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.Z**3.5, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.Z**3.5, args, [cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.H, args, [cirq.LineQubit(1)], allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == flipped_tableau - - cirq.act_on(cirq.Z**2, args, [cirq.LineQubit(1)], allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == flipped_tableau - - cirq.act_on(cirq.H**2, args, [cirq.LineQubit(1)], allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == flipped_tableau + cirq.act_on(cirq.H, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.Z**0.5, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.Z**0.5, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.H, state, [cirq.LineQubit(1)], allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == flipped_tableau + + cirq.act_on(cirq.H, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.Z, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.H, state, [cirq.LineQubit(1)], allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == original_tableau + + cirq.act_on(cirq.H, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.Z**3.5, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.Z**3.5, state, [cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.H, state, [cirq.LineQubit(1)], allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == flipped_tableau + + cirq.act_on(cirq.Z**2, state, [cirq.LineQubit(1)], allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == flipped_tableau + + cirq.act_on(cirq.H**2, state, [cirq.LineQubit(1)], allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == flipped_tableau foo = sympy.Symbol('foo') with pytest.raises(TypeError, match="Failed to act action on state"): - cirq.act_on(cirq.Z**foo, args, [cirq.LineQubit(1)]) + cirq.act_on(cirq.Z**foo, state, [cirq.LineQubit(1)]) with pytest.raises(TypeError, match="Failed to act action on state"): - cirq.act_on(cirq.H**foo, args, [cirq.LineQubit(1)]) + cirq.act_on(cirq.H**foo, state, [cirq.LineQubit(1)]) with pytest.raises(TypeError, match="Failed to act action on state"): - cirq.act_on(cirq.H**1.5, args, [cirq.LineQubit(1)]) + cirq.act_on(cirq.H**1.5, state, [cirq.LineQubit(1)]) def test_cx_act_on_tableau(): @@ -432,22 +432,22 @@ def test_cx_act_on_tableau(): cirq.act_on(cirq.CX, DummySimulationState(), qubits=()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) - args = cirq.CliffordTableauSimulationState( + state = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), ) - cirq.act_on(cirq.CX, args, cirq.LineQubit.range(2), allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau.stabilizers() == [ + cirq.act_on(cirq.CX, state, cirq.LineQubit.range(2), allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau.stabilizers() == [ cirq.DensePauliString('ZIIII', coefficient=-1), cirq.DensePauliString('ZZIII', coefficient=-1), cirq.DensePauliString('IIZII', coefficient=-1), cirq.DensePauliString('IIIZI', coefficient=-1), cirq.DensePauliString('IIIIZ', coefficient=-1), ] - assert args.tableau.destabilizers() == [ + assert state.tableau.destabilizers() == [ cirq.DensePauliString('XXIII', coefficient=1), cirq.DensePauliString('IXIII', coefficient=1), cirq.DensePauliString('IIXII', coefficient=1), @@ -455,20 +455,20 @@ def test_cx_act_on_tableau(): cirq.DensePauliString('IIIIX', coefficient=1), ] - cirq.act_on(cirq.CX, args, cirq.LineQubit.range(2), allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == original_tableau + cirq.act_on(cirq.CX, state, cirq.LineQubit.range(2), allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == original_tableau - cirq.act_on(cirq.CX**4, args, cirq.LineQubit.range(2), allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == original_tableau + cirq.act_on(cirq.CX**4, state, cirq.LineQubit.range(2), allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == original_tableau foo = sympy.Symbol('foo') with pytest.raises(TypeError, match="Failed to act action on state"): - cirq.act_on(cirq.CX**foo, args, cirq.LineQubit.range(2)) + cirq.act_on(cirq.CX**foo, state, cirq.LineQubit.range(2)) with pytest.raises(TypeError, match="Failed to act action on state"): - cirq.act_on(cirq.CX**1.5, args, cirq.LineQubit.range(2)) + cirq.act_on(cirq.CX**1.5, state, cirq.LineQubit.range(2)) def test_cz_act_on_tableau(): @@ -476,22 +476,22 @@ def test_cz_act_on_tableau(): cirq.act_on(cirq.CZ, DummySimulationState(), qubits=()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) - args = cirq.CliffordTableauSimulationState( + state = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), ) - cirq.act_on(cirq.CZ, args, cirq.LineQubit.range(2), allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau.stabilizers() == [ + cirq.act_on(cirq.CZ, state, cirq.LineQubit.range(2), allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau.stabilizers() == [ cirq.DensePauliString('ZIIII', coefficient=-1), cirq.DensePauliString('IZIII', coefficient=-1), cirq.DensePauliString('IIZII', coefficient=-1), cirq.DensePauliString('IIIZI', coefficient=-1), cirq.DensePauliString('IIIIZ', coefficient=-1), ] - assert args.tableau.destabilizers() == [ + assert state.tableau.destabilizers() == [ cirq.DensePauliString('XZIII', coefficient=1), cirq.DensePauliString('ZXIII', coefficient=1), cirq.DensePauliString('IIXII', coefficient=1), @@ -499,44 +499,44 @@ def test_cz_act_on_tableau(): cirq.DensePauliString('IIIIX', coefficient=1), ] - cirq.act_on(cirq.CZ, args, cirq.LineQubit.range(2), allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == original_tableau + cirq.act_on(cirq.CZ, state, cirq.LineQubit.range(2), allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == original_tableau - cirq.act_on(cirq.CZ**4, args, cirq.LineQubit.range(2), allow_decompose=False) - assert args.log_of_measurement_results == {} - assert args.tableau == original_tableau + cirq.act_on(cirq.CZ**4, state, cirq.LineQubit.range(2), allow_decompose=False) + assert state.log_of_measurement_results == {} + assert state.tableau == original_tableau foo = sympy.Symbol('foo') with pytest.raises(TypeError, match="Failed to act action on state"): - cirq.act_on(cirq.CZ**foo, args, cirq.LineQubit.range(2)) + cirq.act_on(cirq.CZ**foo, state, cirq.LineQubit.range(2)) with pytest.raises(TypeError, match="Failed to act action on state"): - cirq.act_on(cirq.CZ**1.5, args, cirq.LineQubit.range(2)) + cirq.act_on(cirq.CZ**1.5, state, cirq.LineQubit.range(2)) def test_cz_act_on_equivalent_to_h_cx_h_tableau(): - args1 = cirq.CliffordTableauSimulationState( + state1 = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=2), qubits=cirq.LineQubit.range(2), prng=np.random.RandomState(), ) - args2 = cirq.CliffordTableauSimulationState( + state2 = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=2), qubits=cirq.LineQubit.range(2), prng=np.random.RandomState(), ) - cirq.act_on(cirq.S, args=args1, qubits=[cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.S, args=args2, qubits=[cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.S, sim_state=state1, qubits=[cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.S, sim_state=state2, qubits=[cirq.LineQubit(1)], allow_decompose=False) - # Args1 uses H*CNOT*H - cirq.act_on(cirq.H, args=args1, qubits=[cirq.LineQubit(1)], allow_decompose=False) - cirq.act_on(cirq.CNOT, args=args1, qubits=cirq.LineQubit.range(2), allow_decompose=False) - cirq.act_on(cirq.H, args=args1, qubits=[cirq.LineQubit(1)], allow_decompose=False) - # Args2 uses CZ - cirq.act_on(cirq.CZ, args=args2, qubits=cirq.LineQubit.range(2), allow_decompose=False) + # state1 uses H*CNOT*H + cirq.act_on(cirq.H, sim_state=state1, qubits=[cirq.LineQubit(1)], allow_decompose=False) + cirq.act_on(cirq.CNOT, sim_state=state1, qubits=cirq.LineQubit.range(2), allow_decompose=False) + cirq.act_on(cirq.H, sim_state=state1, qubits=[cirq.LineQubit(1)], allow_decompose=False) + # state2 uses CZ + cirq.act_on(cirq.CZ, sim_state=state2, qubits=cirq.LineQubit.range(2), allow_decompose=False) - assert args1.tableau == args2.tableau + assert state1.tableau == state2.tableau foo = sympy.Symbol('foo') @@ -583,7 +583,7 @@ def test_act_on_ch_form(input_gate_sequence, outcome): else: assert num_qubits == 2 qubits = cirq.LineQubit.range(2) - args = cirq.StabilizerChFormSimulationState( + state = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(2), prng=np.random.RandomState(), initial_state=original_state.copy(), @@ -594,17 +594,17 @@ def test_act_on_ch_form(input_gate_sequence, outcome): if outcome == 'Error': with pytest.raises(TypeError, match="Failed to act action on state"): for input_gate in input_gate_sequence: - cirq.act_on(input_gate, args, qubits) + cirq.act_on(input_gate, state, qubits) return for input_gate in input_gate_sequence: - cirq.act_on(input_gate, args, qubits) + cirq.act_on(input_gate, state, qubits) if outcome == 'Original': - np.testing.assert_allclose(args.state.state_vector(), original_state.state_vector()) + np.testing.assert_allclose(state.state.state_vector(), original_state.state_vector()) if outcome == 'Flipped': - np.testing.assert_allclose(args.state.state_vector(), flipped_state.state_vector()) + np.testing.assert_allclose(state.state.state_vector(), flipped_state.state_vector()) @pytest.mark.parametrize( diff --git a/cirq-core/cirq/protocols/act_on_protocol.py b/cirq-core/cirq/protocols/act_on_protocol.py index 5080509ad66..c250a2351dc 100644 --- a/cirq-core/cirq/protocols/act_on_protocol.py +++ b/cirq-core/cirq/protocols/act_on_protocol.py @@ -16,7 +16,7 @@ from typing_extensions import Protocol -from cirq import ops +from cirq import _compat, ops from cirq._doc import doc_private from cirq.type_workarounds import NotImplementedType @@ -32,8 +32,8 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase') -> Union[NotImplemente """Applies an action to the given argument, if it is a supported type. For example, unitary operations can implement an `_act_on_` method that - checks if `isinstance(args, cirq.StateVectorSimulationState)` and, if so, - apply their unitary effect to the state vector. + checks if `isinstance(sim_state, cirq.StateVectorSimulationState)` and, + if so, apply their unitary effect to the state vector. The global `cirq.act_on` method looks for whether or not the given argument has this value, before attempting any fallback strategies @@ -64,8 +64,8 @@ def _act_on_( """Applies an action to the given argument, if it is a supported type. For example, unitary operations can implement an `_act_on_` method that - checks if `isinstance(args, cirq.StateVectorSimulationState)` and, if so, - apply their unitary effect to the state vector. + checks if `isinstance(sim_state, cirq.StateVectorSimulationState)` and, + if so, apply their unitary effect to the state vector. The global `cirq.act_on` method looks for whether or not the given argument has this value, before attempting any fallback strategies @@ -86,9 +86,22 @@ def _act_on_( """ +def _fix_deprecated_args(args, kwargs): + kwargs['sim_state'] = kwargs['args'] + del kwargs['args'] + return args, kwargs + + +@_compat.deprecated_parameter( + deadline='v0.16', + fix='Change argument name to `sim_state`', + parameter_desc='args', + match=lambda args, kwargs: 'args' in kwargs, + rewrite=_fix_deprecated_args, +) def act_on( action: Any, - args: 'cirq.SimulationStateBase', + sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq.Qid'] = None, *, allow_decompose: bool = True, @@ -108,24 +121,25 @@ def act_on( Args: action: The operation, gate, or other to apply to the state tensor. - args: A mutable state object that should be modified by the action. May - specify an `_act_on_fallback_` method to use in case the action - doesn't recognize it. + sim_state: A mutable state object that should be modified by the + action. May specify an `_act_on_fallback_` method to use in case + the action doesn't recognize it. qubits: The sequence of qubits to use when applying the action. allow_decompose: Defaults to True. Forwarded into the - `_act_on_fallback_` method of `args`. Determines if decomposition - should be used or avoided when attempting to act `action` on `args`. - Used by internal methods to avoid redundant decompositions. + `_act_on_fallback_` method of `sim_state`. Determines if + decomposition should be used or avoided when attempting to act + `action` on `sim_state`. Used by internal methods to avoid + redundant decompositions. Returns: - Nothing. Results are communicated by editing `args`. + Nothing. Results are communicated by editing `sim_state`. Raises: ValueError: If called on an operation and supplied qubits, if not called on an operation and no qubits are supplied, or if `_act_on_` or `_act_on_fallback_` returned something other than `True` or `NotImplemented`. - TypeError: Failed to act `action` on `args`. + TypeError: Failed to act `action` on `sim_state`. """ is_op = isinstance(action, ops.Operation) @@ -137,7 +151,7 @@ def act_on( action_act_on = getattr(action, '_act_on_', None) if action_act_on is not None: - result = action_act_on(args) if is_op else action_act_on(args, qubits) + result = action_act_on(sim_state) if is_op else action_act_on(sim_state, qubits) if result is True: return if result is not NotImplemented: @@ -146,7 +160,7 @@ def act_on( f'{result!r} from {action!r}._act_on_' ) - arg_fallback = getattr(args, '_act_on_fallback_', None) + arg_fallback = getattr(sim_state, '_act_on_fallback_', None) if arg_fallback is not None: qubits = action.qubits if isinstance(action, ops.Operation) else qubits result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose) @@ -155,14 +169,14 @@ def act_on( if result is not NotImplemented: raise ValueError( f'_act_on_fallback_ must return True or NotImplemented but got ' - f'{result!r} from {type(args)}._act_on_fallback_' + f'{result!r} from {type(sim_state)}._act_on_fallback_' ) raise TypeError( "Failed to act action on state argument.\n" - "Tried both action._act_on_ and args._act_on_fallback_.\n" + "Tried both action._act_on_ and sim_state._act_on_fallback_.\n" "\n" - f"State argument type: {type(args)}\n" + f"State argument type: {type(sim_state)}\n" f"Action type: {type(action)}\n" f"Action repr: {action!r}\n" ) diff --git a/cirq-core/cirq/protocols/act_on_protocol_test.py b/cirq-core/cirq/protocols/act_on_protocol_test.py index bc0cd4c6a4a..5ffac035b33 100644 --- a/cirq-core/cirq/protocols/act_on_protocol_test.py +++ b/cirq-core/cirq/protocols/act_on_protocol_test.py @@ -43,20 +43,20 @@ def _act_on_fallback_( def test_act_on_fallback_succeeds(): - args = DummySimulationState(fallback_result=True) - cirq.act_on(op, args) + state = DummySimulationState(fallback_result=True) + cirq.act_on(op, state) def test_act_on_fallback_fails(): - args = DummySimulationState(fallback_result=NotImplemented) + state = DummySimulationState(fallback_result=NotImplemented) with pytest.raises(TypeError, match='Failed to act'): - cirq.act_on(op, args) + cirq.act_on(op, state) def test_act_on_fallback_errors(): - args = DummySimulationState(fallback_result=False) + state = DummySimulationState(fallback_result=False) with pytest.raises(ValueError, match='_act_on_fallback_ must return True or NotImplemented'): - cirq.act_on(op, args) + cirq.act_on(op, state) def test_act_on_errors(): @@ -71,9 +71,9 @@ def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf: def _act_on_(self, sim_state): return False - args = DummySimulationState(fallback_result=True) + state = DummySimulationState(fallback_result=True) with pytest.raises(ValueError, match='_act_on_ must return True or NotImplemented'): - cirq.act_on(Op(), args) + cirq.act_on(Op(), state) def test_qubits_not_allowed_for_operations(): @@ -85,14 +85,20 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf: pass - args = DummySimulationState() + state = DummySimulationState() with pytest.raises( ValueError, match='Calls to act_on should not supply qubits if the action is an Operation' ): - cirq.act_on(Op(), args, qubits=[]) + cirq.act_on(Op(), state, qubits=[]) def test_qubits_should_be_defined_for_operations(): - args = DummySimulationState() + state = DummySimulationState() with pytest.raises(ValueError, match='Calls to act_on should'): - cirq.act_on(cirq.KrausChannel([np.array([[1, 0], [0, 0]])]), args, qubits=None) + cirq.act_on(cirq.KrausChannel([np.array([[1, 0], [0, 0]])]), state, qubits=None) + + +def test_args_deprecated(): + args = DummySimulationState(fallback_result=True) + with cirq.testing.assert_deprecated(deadline='v0.16'): + cirq.act_on(action=op, args=args) # pylint: disable=no-value-for-parameter diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index a2ab97d0f4e..adb64421e0e 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -67,7 +67,7 @@ def is_supported_operation(op: 'cirq.Operation') -> bool: # TODO: support more general Pauli measurements return protocols.has_stabilizer_effect(op) - def _create_partial_act_on_args( + def _create_partial_simulation_state( self, initial_state: Union[int, 'cirq.StabilizerChFormSimulationState'], qubits: Sequence['cirq.Qid'], diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py index 3abb33bf279..d4129ab5da7 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py @@ -111,7 +111,7 @@ def test_simulation_state(): circuit.append(cirq.X(q1)) circuit.append(cirq.measure(q0, q1)) - args = simulator._create_act_on_args(initial_state=1, qubits=(q0, q1)) + args = simulator._create_simulation_state(initial_state=1, qubits=(q0, q1)) result = simulator.simulate(circuit, initial_state=args) expected_state = np.zeros(shape=(2, 2)) expected_state[b0][1 - b1] = 1.0 diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index 09ff6f98781..fe1d96dff0b 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -145,7 +145,7 @@ def __init__( if dtype not in {np.complex64, np.complex128}: raise ValueError(f'dtype must be complex64 or complex128, was {dtype}') - def _create_partial_act_on_args( + def _create_partial_simulation_state( self, initial_state: Union[ np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.DensityMatrixSimulationState' diff --git a/cirq-core/cirq/sim/density_matrix_simulator_test.py b/cirq-core/cirq/sim/density_matrix_simulator_test.py index 589d2379cb9..674f8dc9922 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulator_test.py @@ -597,7 +597,7 @@ def test_simulation_state(dtype: Type[np.number], split: bool): for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit((cirq.X**b0)(q0), (cirq.X**b1)(q1)) - args = simulator._create_act_on_args(initial_state=1, qubits=(q0, q1)) + args = simulator._create_simulation_state(initial_state=1, qubits=(q0, q1)) result = simulator.simulate(circuit, initial_state=args) expected_density_matrix = np.zeros(shape=(4, 4)) expected_density_matrix[b0 * 2 + 1 - b1, b0 * 2 + 1 - b1] = 1.0 diff --git a/cirq-core/cirq/sim/density_matrix_utils_test.py b/cirq-core/cirq/sim/density_matrix_utils_test.py index 80e72ef07ee..7d8dd87c13c 100644 --- a/cirq-core/cirq/sim/density_matrix_utils_test.py +++ b/cirq-core/cirq/sim/density_matrix_utils_test.py @@ -358,7 +358,7 @@ def test_to_valid_density_matrix_on_simulator_output(seed, dtype, split): def test_factor_validation(): - args = cirq.DensityMatrixSimulator()._create_act_on_args(0, qubits=cirq.LineQubit.range(2)) + args = cirq.DensityMatrixSimulator()._create_simulation_state(0, qubits=cirq.LineQubit.range(2)) args.apply_operation(cirq.H(cirq.LineQubit(0))) t = args.create_merged_state().target_tensor cirq.linalg.transformations.factor_density_matrix(t, [0]) diff --git a/cirq-core/cirq/sim/simulation_product_state.py b/cirq-core/cirq/sim/simulation_product_state.py index c42a07103d9..0e2c8e1f5ec 100644 --- a/cirq-core/cirq/sim/simulation_product_state.py +++ b/cirq-core/cirq/sim/simulation_product_state.py @@ -17,7 +17,7 @@ import numpy as np -from cirq import ops, protocols, value +from cirq import _compat, ops, protocols, value from cirq.sim.simulation_state import TSimulationState from cirq.sim.simulation_state_base import SimulationStateBase @@ -25,14 +25,27 @@ import cirq +def _fix_deprecated_args(args, kwargs): + kwargs['sim_states'] = kwargs['args'] + del kwargs['args'] + return args, kwargs + + class SimulationProductState( Generic[TSimulationState], SimulationStateBase[TSimulationState], abc.Mapping ): """A container for a `Qid`-to-`SimulationState` dictionary.""" + @_compat.deprecated_parameter( + deadline='v0.16', + fix='Change argument name to `sim_states`', + parameter_desc='args', + match=lambda args, kwargs: 'args' in kwargs, + rewrite=_fix_deprecated_args, + ) def __init__( self, - args: Dict[Optional['cirq.Qid'], TSimulationState], + sim_states: Dict[Optional['cirq.Qid'], TSimulationState], qubits: Sequence['cirq.Qid'], split_untangled_states: bool, classical_data: Optional['cirq.ClassicalDataStore'] = None, @@ -40,8 +53,8 @@ def __init__( """Initializes the class. Args: - args: The `SimulationState` dictionary. This will not be copied; the - original reference will be kept here. + sim_states: The `SimulationState` dictionary. This will not be + copied; the original reference will be kept here. qubits: The canonical ordering of qubits. split_untangled_states: If True, optimizes operations by running unentangled qubit sets independently and merging those states @@ -51,12 +64,17 @@ def __init__( """ classical_data = classical_data or value.ClassicalDataDictionaryStore() super().__init__(qubits=qubits, classical_data=classical_data) - self._args = args + self._sim_states = sim_states self._split_untangled_states = split_untangled_states @property + def sim_states(self) -> Mapping[Optional['cirq.Qid'], TSimulationState]: + return self._sim_states + + @property # type: ignore + @_compat.deprecated(deadline='v0.16', fix='Use `sim_states` instead.') def args(self) -> Mapping[Optional['cirq.Qid'], TSimulationState]: - return self._args + return self._sim_states @property def split_untangled_states(self) -> bool: @@ -64,9 +82,9 @@ def split_untangled_states(self) -> bool: def create_merged_state(self) -> TSimulationState: if not self.split_untangled_states: - return self.args[None] - final_args = self.args[None] - for args in set([self.args[k] for k in self.args.keys() if k is not None]): + return self.sim_states[None] + final_args = self.sim_states[None] + for args in set([self.sim_states[k] for k in self.sim_states.keys() if k is not None]): final_args = final_args.kronecker_product(args) return final_args.transpose_to_qubit_order(self.qubits) @@ -90,13 +108,13 @@ def _act_on_fallback_( and gate_opt.global_shift == 0 ): q0, q1 = qubits - args0 = self.args[q0] - args1 = self.args[q1] + args0 = self.sim_states[q0] + args1 = self.sim_states[q1] if args0 is args1: args0.swap(q0, q1, inplace=True) else: - self._args[q0] = args1.rename(q1, q0, inplace=True) - self._args[q1] = args0.rename(q0, q1, inplace=True) + self._sim_states[q0] = args1.rename(q1, q0, inplace=True) + self._sim_states[q1] = args0.rename(q0, q1, inplace=True) return True # Go through the op's qubits and join any disparate SimulationState states @@ -104,14 +122,14 @@ def _act_on_fallback_( op_args_opt: Optional[TSimulationState] = None for q in qubits: if op_args_opt is None: - op_args_opt = self.args[q] + op_args_opt = self.sim_states[q] elif q not in op_args_opt.qubits: - op_args_opt = op_args_opt.kronecker_product(self.args[q]) - op_args = op_args_opt or self.args[None] + op_args_opt = op_args_opt.kronecker_product(self.sim_states[q]) + op_args = op_args_opt or self.sim_states[None] # (Backfill the args map with the new value) for q in op_args.qubits: - self._args[q] = op_args + self._sim_states[q] = op_args # Act on the args with the operation act_on_qubits = qubits if isinstance(action, ops.Gate) else None @@ -124,11 +142,11 @@ def _act_on_fallback_( for q in qubits: if op_args.allows_factoring: q_args, op_args = op_args.factor((q,), validate=False) - self._args[q] = q_args + self._sim_states[q] = q_args # (Backfill the args map with the new value) for q in op_args.qubits: - self._args[q] = op_args + self._sim_states[q] = op_args return True def copy( @@ -136,11 +154,11 @@ def copy( ) -> 'cirq.SimulationProductState[TSimulationState]': classical_data = self._classical_data.copy() copies = {} - for sim_state in set(self.args.values()): + for sim_state in set(self.sim_states.values()): copies[sim_state] = sim_state.copy(deep_copy_buffers) for copy in copies.values(): copy._classical_data = classical_data - args = {q: copies[a] for q, a in self.args.items()} + args = {q: copies[a] for q, a in self.sim_states.items()} return SimulationProductState( args, self.qubits, self.split_untangled_states, classical_data=classical_data ) @@ -154,7 +172,7 @@ def sample( columns = [] selected_order: List[ops.Qid] = [] q_set = set(qubits) - for v in dict.fromkeys(self.args.values()): + for v in dict.fromkeys(self.sim_states.values()): qs = [q for q in v.qubits if q in q_set] if any(qs): column = v.sample(qs, repetitions, seed) @@ -166,10 +184,10 @@ def sample( return stacked[:, index_order] def __getitem__(self, item: Optional['cirq.Qid']) -> TSimulationState: - return self.args[item] + return self.sim_states[item] def __len__(self) -> int: - return len(self.args) + return len(self.sim_states) def __iter__(self) -> Iterator[Optional['cirq.Qid']]: - return iter(self.args) + return iter(self.sim_states) diff --git a/cirq-core/cirq/sim/simulation_product_state_test.py b/cirq-core/cirq/sim/simulation_product_state_test.py index ab8c5082950..7108f875058 100644 --- a/cirq-core/cirq/sim/simulation_product_state_test.py +++ b/cirq-core/cirq/sim/simulation_product_state_test.py @@ -54,216 +54,226 @@ def _act_on_fallback_( def create_container( qubits: Sequence['cirq.Qid'], split_untangled_states=True ) -> cirq.SimulationProductState[EmptySimulationState]: - args_map: Dict[Optional['cirq.Qid'], EmptySimulationState] = {} + state_map: Dict[Optional['cirq.Qid'], EmptySimulationState] = {} log = cirq.ClassicalDataDictionaryStore() if split_untangled_states: for q in reversed(qubits): - args_map[q] = EmptySimulationState([q], log) - args_map[None] = EmptySimulationState((), log) + state_map[q] = EmptySimulationState([q], log) + state_map[None] = EmptySimulationState((), log) else: - args = EmptySimulationState(qubits, log) + state = EmptySimulationState(qubits, log) for q in qubits: - args_map[q] = args - args_map[None] = args if not split_untangled_states else EmptySimulationState((), log) - return cirq.SimulationProductState(args_map, qubits, split_untangled_states, classical_data=log) + state_map[q] = state + state_map[None] = state if not split_untangled_states else EmptySimulationState((), log) + return cirq.SimulationProductState( + state_map, qubits, split_untangled_states, classical_data=log + ) def test_entanglement_causes_join(): - args = create_container(qs2) - assert len(set(args.values())) == 3 - args.apply_operation(cirq.CNOT(q0, q1)) - assert len(set(args.values())) == 2 - assert args[q0] is args[q1] - assert args[None] is not args[q0] + state = create_container(qs2) + assert len(set(state.values())) == 3 + state.apply_operation(cirq.CNOT(q0, q1)) + assert len(set(state.values())) == 2 + assert state[q0] is state[q1] + assert state[None] is not state[q0] def test_subcircuit_entanglement_causes_join(): - args = create_container(qs2) - assert len(set(args.values())) == 3 - args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q1)))) - assert len(set(args.values())) == 2 - assert args[q0] is args[q1] + state = create_container(qs2) + assert len(set(state.values())) == 3 + state.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q1)))) + assert len(set(state.values())) == 2 + assert state[q0] is state[q1] def test_subcircuit_entanglement_causes_join_in_subset(): - args = create_container(qs3) - assert len(set(args.values())) == 4 - args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q1)))) - assert len(set(args.values())) == 3 - assert args[q0] is args[q1] - args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q2)))) - assert len(set(args.values())) == 2 - assert args[q0] is args[q1] is args[q2] + state = create_container(qs3) + assert len(set(state.values())) == 4 + state.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q1)))) + assert len(set(state.values())) == 3 + assert state[q0] is state[q1] + state.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q2)))) + assert len(set(state.values())) == 2 + assert state[q0] is state[q1] is state[q2] def test_identity_does_not_join(): - args = create_container(qs2) - assert len(set(args.values())) == 3 - args.apply_operation(cirq.IdentityGate(2)(q0, q1)) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] - assert args[q0] is not args[None] + state = create_container(qs2) + assert len(set(state.values())) == 3 + state.apply_operation(cirq.IdentityGate(2)(q0, q1)) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] + assert state[q0] is not state[None] def test_identity_fallback_does_not_join(): - args = create_container(qs2) - assert len(set(args.values())) == 3 - args._act_on_fallback_(cirq.I, (q0, q1)) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] - assert args[q0] is not args[None] + state = create_container(qs2) + assert len(set(state.values())) == 3 + state._act_on_fallback_(cirq.I, (q0, q1)) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] + assert state[q0] is not state[None] def test_subcircuit_identity_does_not_join(): - args = create_container(qs2) - assert len(set(args.values())) == 3 - args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.IdentityGate(2)(q0, q1)))) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] + state = create_container(qs2) + assert len(set(state.values())) == 3 + state.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.IdentityGate(2)(q0, q1)))) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] def test_measurement_causes_split(): - args = create_container(qs2) - args.apply_operation(cirq.CNOT(q0, q1)) - assert len(set(args.values())) == 2 - args.apply_operation(cirq.measure(q0)) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] - assert args[q0] is not args[None] + state = create_container(qs2) + state.apply_operation(cirq.CNOT(q0, q1)) + assert len(set(state.values())) == 2 + state.apply_operation(cirq.measure(q0)) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] + assert state[q0] is not state[None] def test_subcircuit_measurement_causes_split(): - args = create_container(qs2) - args.apply_operation(cirq.CNOT(q0, q1)) - assert len(set(args.values())) == 2 - args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0)))) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] + state = create_container(qs2) + state.apply_operation(cirq.CNOT(q0, q1)) + assert len(set(state.values())) == 2 + state.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0)))) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] def test_subcircuit_measurement_causes_split_in_subset(): - args = create_container(qs3) - args.apply_operation(cirq.CNOT(q0, q1)) - args.apply_operation(cirq.CNOT(q0, q2)) - assert len(set(args.values())) == 2 - args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0)))) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] - args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q1)))) - assert len(set(args.values())) == 4 - assert args[q0] is not args[q1] - assert args[q0] is not args[q2] - assert args[q1] is not args[q2] + state = create_container(qs3) + state.apply_operation(cirq.CNOT(q0, q1)) + state.apply_operation(cirq.CNOT(q0, q2)) + assert len(set(state.values())) == 2 + state.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0)))) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] + state.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q1)))) + assert len(set(state.values())) == 4 + assert state[q0] is not state[q1] + assert state[q0] is not state[q2] + assert state[q1] is not state[q2] def test_reset_causes_split(): - args = create_container(qs2) - args.apply_operation(cirq.CNOT(q0, q1)) - assert len(set(args.values())) == 2 - args.apply_operation(cirq.reset(q0)) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] - assert args[q0] is not args[None] + state = create_container(qs2) + state.apply_operation(cirq.CNOT(q0, q1)) + assert len(set(state.values())) == 2 + state.apply_operation(cirq.reset(q0)) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] + assert state[q0] is not state[None] def test_measurement_does_not_split_if_disabled(): - args = create_container(qs2, False) - args.apply_operation(cirq.CNOT(q0, q1)) - assert len(set(args.values())) == 1 - args.apply_operation(cirq.measure(q0)) - assert len(set(args.values())) == 1 - assert args[q1] is args[q0] - assert args[None] is args[q0] + state = create_container(qs2, False) + state.apply_operation(cirq.CNOT(q0, q1)) + assert len(set(state.values())) == 1 + state.apply_operation(cirq.measure(q0)) + assert len(set(state.values())) == 1 + assert state[q1] is state[q0] + assert state[None] is state[q0] def test_reset_does_not_split_if_disabled(): - args = create_container(qs2, False) - args.apply_operation(cirq.CNOT(q0, q1)) - assert len(set(args.values())) == 1 - args.apply_operation(cirq.reset(q0)) - assert len(set(args.values())) == 1 - assert args[q1] is args[q0] - assert args[None] is args[q0] + state = create_container(qs2, False) + state.apply_operation(cirq.CNOT(q0, q1)) + assert len(set(state.values())) == 1 + state.apply_operation(cirq.reset(q0)) + assert len(set(state.values())) == 1 + assert state[q1] is state[q0] + assert state[None] is state[q0] def test_measurement_of_all_qubits_causes_split(): - args = create_container(qs2) - args.apply_operation(cirq.CNOT(q0, q1)) - assert len(set(args.values())) == 2 - args.apply_operation(cirq.measure(q0, q1)) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] - assert args[q0] is not args[None] + state = create_container(qs2) + state.apply_operation(cirq.CNOT(q0, q1)) + assert len(set(state.values())) == 2 + state.apply_operation(cirq.measure(q0, q1)) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] + assert state[q0] is not state[None] def test_measurement_in_single_qubit_circuit_passes(): - args = create_container([q0]) - assert len(set(args.values())) == 2 - args.apply_operation(cirq.measure(q0)) - assert len(set(args.values())) == 2 - assert args[q0] is not args[None] + state = create_container([q0]) + assert len(set(state.values())) == 2 + state.apply_operation(cirq.measure(q0)) + assert len(set(state.values())) == 2 + assert state[q0] is not state[None] def test_reorder_succeeds(): - args = create_container(qs2, False) - reordered = args[q0].transpose_to_qubit_order([q1, q0]) + state = create_container(qs2, False) + reordered = state[q0].transpose_to_qubit_order([q1, q0]) assert reordered.qubits == (q1, q0) def test_copy_succeeds(): - args = create_container(qs2, False) - copied = args[q0].copy() + state = create_container(qs2, False) + copied = state[q0].copy() assert copied.qubits == (q0, q1) def test_merge_succeeds(): - args = create_container(qs2, False) - merged = args.create_merged_state() + state = create_container(qs2, False) + merged = state.create_merged_state() assert merged.qubits == (q0, q1) def test_swap_does_not_merge(): - args = create_container(qs2) - old_q0 = args[q0] - old_q1 = args[q1] - args.apply_operation(cirq.SWAP(q0, q1)) - assert len(set(args.values())) == 3 - assert args[q0] is not old_q0 - assert args[q1] is old_q0 - assert args[q1] is not old_q1 - assert args[q0] is old_q1 - assert args[q0].qubits == (q0,) - assert args[q1].qubits == (q1,) + state = create_container(qs2) + old_q0 = state[q0] + old_q1 = state[q1] + state.apply_operation(cirq.SWAP(q0, q1)) + assert len(set(state.values())) == 3 + assert state[q0] is not old_q0 + assert state[q1] is old_q0 + assert state[q1] is not old_q1 + assert state[q0] is old_q1 + assert state[q0].qubits == (q0,) + assert state[q1].qubits == (q1,) def test_half_swap_does_merge(): - args = create_container(qs2) - args.apply_operation(cirq.SWAP(q0, q1) ** 0.5) - assert len(set(args.values())) == 2 - assert args[q0] is args[q1] + state = create_container(qs2) + state.apply_operation(cirq.SWAP(q0, q1) ** 0.5) + assert len(set(state.values())) == 2 + assert state[q0] is state[q1] def test_swap_after_entangle_reorders(): - args = create_container(qs2) - args.apply_operation(cirq.CX(q0, q1)) - assert len(set(args.values())) == 2 - assert args[q0].qubits == (q0, q1) - args.apply_operation(cirq.SWAP(q0, q1)) - assert len(set(args.values())) == 2 - assert args[q0] is args[q1] - assert args[q0].qubits == (q1, q0) + state = create_container(qs2) + state.apply_operation(cirq.CX(q0, q1)) + assert len(set(state.values())) == 2 + assert state[q0].qubits == (q0, q1) + state.apply_operation(cirq.SWAP(q0, q1)) + assert len(set(state.values())) == 2 + assert state[q0] is state[q1] + assert state[q0].qubits == (q1, q0) def test_act_on_gate_does_not_join(): - args = create_container(qs2) - assert len(set(args.values())) == 3 - cirq.act_on(cirq.X, args, [q0]) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] - assert args[q0] is not args[None] + state = create_container(qs2) + assert len(set(state.values())) == 3 + cirq.act_on(cirq.X, state, [q0]) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] + assert state[q0] is not state[None] def test_field_getters(): - args = create_container(qs2) - assert args.args.keys() == set(qs2) | {None} - assert args.split_untangled_states + state = create_container(qs2) + assert state.sim_states.keys() == set(qs2) | {None} + assert state.split_untangled_states + + +def test_deprecated_args(): + state = create_container(qs2) + with cirq.testing.assert_deprecated(deadline='v0.16'): + _ = state.args + with cirq.testing.assert_deprecated(deadline='v0.16'): + _ = cirq.SimulationProductState(args={}, qubits=[], split_untangled_states=False) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 1c9ed276151..3e09a1e8cc2 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -684,10 +684,9 @@ def _base_iterator( StepResults from simulating a Moment of the Circuit. """ qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(circuit.all_qubits()) - sim_state = self._create_act_on_args(initial_state, qubits) + sim_state = self._create_simulation_state(initial_state, qubits) return self._core_iterator(circuit, sim_state) - @abc.abstractmethod def _create_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'] ) -> TSimulatorState: @@ -706,6 +705,36 @@ def _create_act_on_args( Returns: The `TSimulatorState` for this simulator. """ + raise NotImplementedError() + + def _create_simulation_state( + self, initial_state: Any, qubits: Sequence['cirq.Qid'] + ) -> TSimulatorState: + """Creates the state for a simulator. + + Custom simulators should implement this method. + + Args: + initial_state: The initial state for the simulation. The form of + this state depends on the simulation implementation. See + documentation of the implementing class for details. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + + Returns: + The `TSimulatorState` for this simulator. + """ + _compat._warn_or_error( + '`_create_act_on_args` has been renamed to `_create_simulation_state` in the' + ' SimulatesIntermediateState interface, so simulators need to rename that method' + f' implementation as well before v0.16. {type(self)}' + ' has no `_create_simulation_state` method, so falling back to `_create_act_on_args`.' + ' This fallback functionality will be removed in v0.16.' + ) + # When cleaning this up in v0.16, mark `_create_simulation_state` as @abc.abstractmethod, + # remove this implementation, and delete `_create_act_on_args` entirely. + return self._create_act_on_args(initial_state, qubits) @abc.abstractmethod def _core_iterator( diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 74a1fba4f84..fd222bcae24 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -33,7 +33,7 @@ import numpy as np -from cirq import ops, protocols, study, value, devices +from cirq import _compat, devices, ops, protocols, study, value from cirq.sim import simulator from cirq.sim.simulation_product_state import SimulationProductState from cirq.sim.simulation_state import TSimulationState @@ -66,10 +66,10 @@ class SimulatorBase( """A base class for the built-in simulators. Most implementors of this interface should implement the - `_create_partial_act_on_args` and `_create_step_result` methods. The first - one creates the simulator's quantum state representation at the beginning - of the simulation. The second creates the step result emitted after each - `Moment` in the simulation. + `_create_partial_simulation_state` and `_create_step_result` methods. The + first one creates the simulator's quantum state representation at the + beginning of the simulation. The second creates the step result emitted + after each `Moment` in the simulation. Iteration in the subclass is handled by the `_core_iterator` implementation here, which handles moment stepping, application of operations, measurement @@ -113,7 +113,6 @@ def __init__( self.noise = devices.NoiseModel.from_noise_model_like(noise) self._split_untangled_states = split_untangled_states - @abc.abstractmethod def _create_partial_act_on_args( self, initial_state: Any, @@ -132,6 +131,37 @@ def _create_partial_act_on_args( classical_data: The shared classical data container for this simulation. """ + raise NotImplementedError() + + def _create_partial_simulation_state( + self, + initial_state: Any, + qubits: Sequence['cirq.Qid'], + classical_data: 'cirq.ClassicalDataStore', + ) -> TSimulationState: + """Creates an instance of the TSimulationState class for the simulator. + + It represents the supplied qubits initialized to the provided state. + + Args: + initial_state: The initial state to represent. An integer state is + understood to be a pure state. Other state representations are + simulator-dependent. + qubits: The sequence of qubits to represent. + classical_data: The shared classical data container for this + simulation. + """ + _compat._warn_or_error( + '`_create_partial_act_on_args` has been renamed to `_create_partial_simulation_state`' + ' in the SimulatorBase class, so simulators need to rename that method' + f' implementation as well before v0.16. {type(self)}' + ' has no `_create_partial_simulation_state` method, so falling back to' + ' `_create_partial_act_on_args`. This fallback functionality will be removed in v0.16.' + ) + # When cleaning this up in v0.16, mark `_create_partial_simulation_state` as + # @abc.abstractmethod, remove this implementation, and delete `_create_partial_act_on_args` + # entirely. + return self._create_partial_act_on_args(initial_state, qubits, classical_data) @abc.abstractmethod def _create_step_result( @@ -226,7 +256,7 @@ def _run( resolved_circuit = protocols.resolve_parameters(circuit, param_resolver) check_all_resolved(resolved_circuit) qubits = tuple(sorted(resolved_circuit.all_qubits())) - sim_state = self._create_act_on_args(0, qubits) + sim_state = self._create_simulation_state(0, qubits) prefix, general_suffix = ( split_into_matching_protocol_then_general(resolved_circuit, self._can_be_in_run_prefix) @@ -302,7 +332,7 @@ def sweep_prefixable(op: 'cirq.Operation'): qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(program.all_qubits()) initial_state = 0 if initial_state is None else initial_state - sim_state = self._create_act_on_args(initial_state, qubits) + sim_state = self._create_simulation_state(initial_state, qubits) prefix, suffix = ( split_into_matching_protocol_then_general(program, sweep_prefixable) if self._can_be_in_run_prefix(self.noise) @@ -314,7 +344,7 @@ def sweep_prefixable(op: 'cirq.Operation'): sim_state = step_result._sim_state yield from super().simulate_sweep_iter(suffix, params, qubit_order, sim_state) - def _create_act_on_args( + def _create_simulation_state( self, initial_state: Any, qubits: Sequence['cirq.Qid'] ) -> SimulationStateBase[TSimulationState]: if isinstance(initial_state, SimulationStateBase): @@ -325,24 +355,24 @@ def _create_act_on_args( args_map: Dict[Optional['cirq.Qid'], TSimulationState] = {} if isinstance(initial_state, int): for q in reversed(qubits): - args_map[q] = self._create_partial_act_on_args( + args_map[q] = self._create_partial_simulation_state( initial_state=initial_state % q.dimension, qubits=[q], classical_data=classical_data, ) initial_state = int(initial_state / q.dimension) else: - args = self._create_partial_act_on_args( + args = self._create_partial_simulation_state( initial_state=initial_state, qubits=qubits, classical_data=classical_data ) for q in qubits: args_map[q] = args - args_map[None] = self._create_partial_act_on_args(0, (), classical_data) + args_map[None] = self._create_partial_simulation_state(0, (), classical_data) return SimulationProductState( args_map, qubits, self._split_untangled_states, classical_data=classical_data ) else: - return self._create_partial_act_on_args( + return self._create_partial_simulation_state( initial_state=initial_state, qubits=qubits, classical_data=classical_data ) diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index ec2042b673b..047e0bf9b9f 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -22,8 +22,8 @@ class CountingState(cirq.qis.QuantumStateRepresentation): - def __init__(self, state, gate_count=0, measurement_count=0): - self.state = state + def __init__(self, data, gate_count=0, measurement_count=0): + self.data = data self.gate_count = gate_count self.measurement_count = measurement_count @@ -35,7 +35,7 @@ def measure( def kron(self: 'CountingState', other: 'CountingState') -> 'CountingState': return CountingState( - self.state, + self.data, self.gate_count + other.gate_count, self.measurement_count + other.measurement_count, ) @@ -43,8 +43,8 @@ def kron(self: 'CountingState', other: 'CountingState') -> 'CountingState': def factor( self: 'CountingState', axes: Sequence[int], *, validate=True, atol=1e-07 ) -> Tuple['CountingState', 'CountingState']: - return CountingState(self.state, self.gate_count, self.measurement_count), CountingState( - self.state + return CountingState(self.data, self.gate_count, self.measurement_count), CountingState( + self.data ) def reindex(self: 'CountingState', axes: Sequence[int]) -> 'CountingState': @@ -52,7 +52,7 @@ def reindex(self: 'CountingState', axes: Sequence[int]) -> 'CountingState': def copy(self, deep_copy_buffers: bool = True) -> 'CountingState': return CountingState( - state=self.state, gate_count=self.gate_count, measurement_count=self.measurement_count + data=self.data, gate_count=self.gate_count, measurement_count=self.measurement_count ) @@ -68,8 +68,8 @@ def _act_on_fallback_( return True @property - def state(self): - return self._state.state + def data(self): + return self._state.data @property def gate_count(self): @@ -112,7 +112,7 @@ class CountingSimulator( def __init__(self, noise=None, split_untangled_states=False): super().__init__(noise=noise, split_untangled_states=split_untangled_states) - def _create_partial_act_on_args( + def _create_partial_simulation_state( self, initial_state: Any, qubits: Sequence['cirq.Qid'], @@ -142,7 +142,7 @@ class SplittableCountingSimulator(CountingSimulator): def __init__(self, noise=None, split_untangled_states=True): super().__init__(noise=noise, split_untangled_states=split_untangled_states) - def _create_partial_act_on_args( + def _create_partial_simulation_state( self, initial_state: Any, qubits: Sequence['cirq.Qid'], @@ -251,95 +251,95 @@ def test_run_non_terminal_measurement(): def test_integer_initial_state_is_split(): sim = SplittableCountingSimulator() - args = sim._create_act_on_args(2, (q0, q1)) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] - assert args[q0].state == 1 - assert args[q1].state == 0 - assert args[None].state == 0 + state = sim._create_simulation_state(2, (q0, q1)) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] + assert state[q0].data == 1 + assert state[q1].data == 0 + assert state[None].data == 0 def test_integer_initial_state_is_not_split_if_disabled(): sim = SplittableCountingSimulator(split_untangled_states=False) - args = sim._create_act_on_args(2, (q0, q1)) - assert isinstance(args, SplittableCountingSimulationState) - assert args[q0] is args[q1] - assert args.state == 2 + state = sim._create_simulation_state(2, (q0, q1)) + assert isinstance(state, SplittableCountingSimulationState) + assert state[q0] is state[q1] + assert state.data == 2 def test_integer_initial_state_is_not_split_if_impossible(): sim = CountingSimulator() - args = sim._create_act_on_args(2, (q0, q1)) - assert isinstance(args, CountingSimulationState) - assert not isinstance(args, SplittableCountingSimulationState) - assert args[q0] is args[q1] - assert args.state == 2 + state = sim._create_simulation_state(2, (q0, q1)) + assert isinstance(state, CountingSimulationState) + assert not isinstance(state, SplittableCountingSimulationState) + assert state[q0] is state[q1] + assert state.data == 2 def test_non_integer_initial_state_is_not_split(): sim = SplittableCountingSimulator() - args = sim._create_act_on_args(entangled_state_repr, (q0, q1)) - assert len(set(args.values())) == 2 - assert (args[q0].state == entangled_state_repr).all() - assert args[q1] is args[q0] - assert args[None].state == 0 + state = sim._create_simulation_state(entangled_state_repr, (q0, q1)) + assert len(set(state.values())) == 2 + assert (state[q0].data == entangled_state_repr).all() + assert state[q1] is state[q0] + assert state[None].data == 0 def test_entanglement_causes_join(): sim = SplittableCountingSimulator() - args = sim._create_act_on_args(2, (q0, q1)) - assert len(set(args.values())) == 3 - args.apply_operation(cirq.CNOT(q0, q1)) - assert len(set(args.values())) == 2 - assert args[q0] is args[q1] - assert args[None] is not args[q0] + state = sim._create_simulation_state(2, (q0, q1)) + assert len(set(state.values())) == 3 + state.apply_operation(cirq.CNOT(q0, q1)) + assert len(set(state.values())) == 2 + assert state[q0] is state[q1] + assert state[None] is not state[q0] def test_measurement_causes_split(): sim = SplittableCountingSimulator() - args = sim._create_act_on_args(entangled_state_repr, (q0, q1)) - assert len(set(args.values())) == 2 - args.apply_operation(cirq.measure(q0)) - assert len(set(args.values())) == 3 - assert args[q0] is not args[q1] - assert args[q0] is not args[None] + state = sim._create_simulation_state(entangled_state_repr, (q0, q1)) + assert len(set(state.values())) == 2 + state.apply_operation(cirq.measure(q0)) + assert len(set(state.values())) == 3 + assert state[q0] is not state[q1] + assert state[q0] is not state[None] def test_measurement_does_not_split_if_disabled(): sim = SplittableCountingSimulator(split_untangled_states=False) - args = sim._create_act_on_args(2, (q0, q1)) - assert isinstance(args, SplittableCountingSimulationState) - args.apply_operation(cirq.measure(q0)) - assert isinstance(args, SplittableCountingSimulationState) - assert args[q0] is args[q1] + state = sim._create_simulation_state(2, (q0, q1)) + assert isinstance(state, SplittableCountingSimulationState) + state.apply_operation(cirq.measure(q0)) + assert isinstance(state, SplittableCountingSimulationState) + assert state[q0] is state[q1] def test_measurement_does_not_split_if_impossible(): sim = CountingSimulator() - args = sim._create_act_on_args(2, (q0, q1)) - assert isinstance(args, CountingSimulationState) - assert not isinstance(args, SplittableCountingSimulationState) - args.apply_operation(cirq.measure(q0)) - assert isinstance(args, CountingSimulationState) - assert not isinstance(args, SplittableCountingSimulationState) - assert args[q0] is args[q1] + state = sim._create_simulation_state(2, (q0, q1)) + assert isinstance(state, CountingSimulationState) + assert not isinstance(state, SplittableCountingSimulationState) + state.apply_operation(cirq.measure(q0)) + assert isinstance(state, CountingSimulationState) + assert not isinstance(state, SplittableCountingSimulationState) + assert state[q0] is state[q1] def test_reorder_succeeds(): sim = SplittableCountingSimulator() - args = sim._create_act_on_args(entangled_state_repr, (q0, q1)) - reordered = args[q0].transpose_to_qubit_order([q1, q0]) + state = sim._create_simulation_state(entangled_state_repr, (q0, q1)) + reordered = state[q0].transpose_to_qubit_order([q1, q0]) assert reordered.qubits == (q1, q0) @pytest.mark.parametrize('split', [True, False]) def test_sim_state_instance_unchanged_during_normal_sim(split: bool): sim = SplittableCountingSimulator(split_untangled_states=split) - args = sim._create_act_on_args(0, (q0, q1)) + state = sim._create_simulation_state(0, (q0, q1)) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1), cirq.reset(q1)) - for step in sim.simulate_moment_steps(circuit, initial_state=args): - assert step._sim_state is args - assert (step._merged_sim_state is not args) == split + for step in sim.simulate_moment_steps(circuit, initial_state=state): + assert step._sim_state is state + assert (step._merged_sim_state is not state) == split def test_measurements_retained_in_step_results(): @@ -409,3 +409,19 @@ def _create_simulator_trial_result( # type: ignore r = sim.simulate(cirq.Circuit()) assert r._final_simulator_state.gate_count == 0 assert r._final_simulator_state.measurement_count == 0 + + +def test_deprecated_create_partial_act_on_args(): + class DeprecatedSim(cirq.SimulatorBase): + def _create_partial_act_on_args(self, initial_state, qubits, classical_data): + return 0 + + def _create_step_result(self): + pass + + def _create_simulator_trial_result(self): + pass + + sim = DeprecatedSim() + with cirq.testing.assert_deprecated(deadline='v0.16'): + sim.simulate_moment_steps(cirq.Circuit()) diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index 0e7e7a25db9..59b4d99e5d9 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -539,3 +539,19 @@ def test_trial_result_initializer(): assert x._final_simulator_state == 3 x = SimulationTrialResult(resolver, {}, final_simulator_state=state) assert x._final_simulator_state == 3 + + +def test_deprecated_create_act_on_args(): + class DeprecatedSim(cirq.SimulatesIntermediateState): + def _create_act_on_args(self, initial_state, qubits): + return 0 + + def _core_iterator(self, circuit, sim_state): + pass + + def _create_simulator_trial_result(self): + pass + + sim = DeprecatedSim() + with cirq.testing.assert_deprecated(deadline='v0.16'): + sim.simulate_moment_steps(cirq.Circuit()) diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index fbc8eeae29f..f15a9f44703 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -153,7 +153,7 @@ def __init__( dtype=dtype, noise=noise, seed=seed, split_untangled_states=split_untangled_states ) - def _create_partial_act_on_args( + def _create_partial_simulation_state( self, initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.StateVectorSimulationState'], qubits: Sequence['cirq.Qid'], diff --git a/cirq-core/cirq/sim/sparse_simulator_test.py b/cirq-core/cirq/sim/sparse_simulator_test.py index 5a9c6112928..b63e6ce0b57 100644 --- a/cirq-core/cirq/sim/sparse_simulator_test.py +++ b/cirq-core/cirq/sim/sparse_simulator_test.py @@ -459,7 +459,7 @@ def test_simulation_state(dtype: Type[np.number], split: bool): for b0 in [0, 1]: for b1 in [0, 1]: circuit = cirq.Circuit((cirq.X**b0)(q0), (cirq.X**b1)(q1)) - args = simulator._create_act_on_args(initial_state=1, qubits=(q0, q1)) + args = simulator._create_simulation_state(initial_state=1, qubits=(q0, q1)) result = simulator.simulate(circuit, initial_state=args) expected_state = np.zeros(shape=(2, 2)) expected_state[b0][1 - b1] = 1.0 @@ -1335,7 +1335,7 @@ def test_pure_state_creation(): sim = cirq.Simulator() qids = cirq.LineQubit.range(3) shape = cirq.qid_shape(qids) - args = sim._create_act_on_args(1, qids) + args = sim._create_simulation_state(1, qids) values = list(args.values()) arg = ( values[0] diff --git a/cirq-core/cirq/sim/state_vector_test.py b/cirq-core/cirq/sim/state_vector_test.py index 7c5928d1948..d8b4df663de 100644 --- a/cirq-core/cirq/sim/state_vector_test.py +++ b/cirq-core/cirq/sim/state_vector_test.py @@ -378,7 +378,7 @@ def test_step_result_bloch_vector(): def test_factor_validation(): - args = cirq.Simulator()._create_act_on_args(0, qubits=cirq.LineQubit.range(2)) + args = cirq.Simulator()._create_simulation_state(0, qubits=cirq.LineQubit.range(2)) args.apply_operation(cirq.H(cirq.LineQubit(0)) ** 0.7) t = args.create_merged_state().target_tensor cirq.linalg.transformations.factor_state_vector(t, [0]) diff --git a/cirq-google/cirq_google/calibration/engine_simulator.py b/cirq-google/cirq_google/calibration/engine_simulator.py index 61e9a37e479..58986a6eb24 100644 --- a/cirq-google/cirq_google/calibration/engine_simulator.py +++ b/cirq-google/cirq_google/calibration/engine_simulator.py @@ -478,7 +478,7 @@ def simulate( converted = _convert_to_circuit_with_drift(self, program) return self._simulator.simulate(converted, param_resolver, qubit_order, initial_state) - def _create_partial_act_on_args( + def _create_partial_simulation_state( self, initial_state: Union[int, cirq.StateVectorSimulationState], qubits: Sequence[cirq.Qid],