diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 2a39efe6f62..b8ea3f1f751 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -26,7 +26,7 @@ from cirq import devices, protocols, qis, value from cirq._compat import deprecated_parameter -from cirq.sim import simulator_base +from cirq.sim import simulator, simulator_base from cirq.sim.act_on_args import ActOnArgs if TYPE_CHECKING: @@ -122,7 +122,7 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'MPSSimulatorStepResult', + final_simulator_state: 'cirq.OperationTarget[MPSState]', ) -> 'MPSTrialResult': """Creates a single trial results with the measurements. @@ -130,27 +130,28 @@ def _create_simulator_trial_result( params: A ParamResolver for determining values of Symbols. measurements: A dictionary from measurement key (e.g. qubit) to the actual measurement array. - final_step_result: The final step result of the simulation. + final_simulator_state: The final state of the simulation. Returns: A single result. """ return MPSTrialResult( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) class MPSTrialResult(simulator_base.SimulationTrialResultBase['MPSState']): """A single trial reult""" + @simulator._deprecated_step_result_parameter(old_position=3) def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'MPSSimulatorStepResult', + final_simulator_state: 'cirq.OperationTarget[MPSState]', ) -> None: super().__init__( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) @property diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index 5b4ce85d802..dfd241f6d27 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -1,7 +1,6 @@ # pylint: disable=wrong-or-nonexistent-copyright-notice import itertools import math -from unittest import mock import numpy as np import pytest @@ -263,8 +262,7 @@ def test_measurement_str(): def test_trial_result_str(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.StepResult) - final_step_result._simulator_state.return_value = ccq.mps_simulator.MPSState( + final_simulator_state = ccq.mps_simulator.MPSState( qubits=(q0,), prng=value.parse_random_state(0), simulation_options=ccq.mps_simulator.MPSOptions(), @@ -274,7 +272,7 @@ def test_trial_result_str(): ccq.mps_simulator.MPSTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) == """measurements: m=1 @@ -286,8 +284,7 @@ def test_trial_result_str(): def test_trial_result_repr_pretty(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.StepResult) - final_step_result._simulator_state.return_value = ccq.mps_simulator.MPSState( + final_simulator_state = ccq.mps_simulator.MPSState( qubits=(q0,), prng=value.parse_random_state(0), simulation_options=ccq.mps_simulator.MPSOptions(), @@ -295,7 +292,7 @@ def test_trial_result_repr_pretty(): result = ccq.mps_simulator.MPSTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) cirq.testing.assert_repr_pretty( result, diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 55a8a17c6fd..66069f7ea4f 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -16,6 +16,7 @@ import numpy as np +from cirq._compat import proper_repr from cirq.sim.clifford import stabilizer_state_ch_form from cirq.sim.clifford.act_on_stabilizer_args import ActOnStabilizerArgs @@ -65,3 +66,11 @@ def __init__( super().__init__( state=initial_state, prng=prng, qubits=qubits, classical_data=classical_data ) + + def __repr__(self) -> str: + return ( + 'cirq.ActOnStabilizerCHFormArgs(' + f'initial_state={proper_repr(self.state)},' + f' qubits={self.qubits!r},' + f' classical_data={self.classical_data!r})' + ) diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index b12137a1f54..1919222d618 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -36,7 +36,7 @@ import cirq from cirq import protocols, value from cirq.protocols import act_on -from cirq.sim import clifford, simulator_base +from cirq.sim import clifford, simulator, simulator_base class CliffordSimulator( @@ -107,25 +107,26 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'CliffordSimulatorStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStabilizerCHFormArgs]', ): return CliffordTrialResult( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) class CliffordTrialResult( simulator_base.SimulationTrialResultBase['clifford.ActOnStabilizerCHFormArgs'] ): + @simulator._deprecated_step_result_parameter(old_position=3) def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'cirq.CliffordSimulatorStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStabilizerCHFormArgs]', ) -> None: super().__init__( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) @property @@ -137,7 +138,7 @@ def final_state(self) -> 'cirq.CliffordState': def __str__(self) -> str: samples = super().__str__() - final = self._final_simulator_state + final = self._get_merged_sim_state().state return f'measurements: {samples}\noutput state: {final}' def _repr_pretty_(self, p: Any, cycle: bool): diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py index 1c7c6a5a17d..c3bfca63d90 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py @@ -1,6 +1,5 @@ # pylint: disable=wrong-or-nonexistent-copyright-notice import itertools -from unittest import mock import numpy as np import pytest @@ -212,32 +211,33 @@ def test_clifford_state_initial_state(): def test_clifford_trial_result_repr(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult) - final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0}) + final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0]) assert ( repr( cirq.CliffordTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) == "cirq.SimulationTrialResult(params=cirq.ParamResolver({}), " "measurements={'m': array([[1]])}, " - "final_simulator_state=StabilizerStateChForm(num_qubits=1))" + "final_simulator_state=cirq.ActOnStabilizerCHFormArgs(" + "initial_state=StabilizerStateChForm(num_qubits=1), " + "qubits=(cirq.LineQubit(0),), " + "classical_data=cirq.ClassicalDataDictionaryStore()))" ) def test_clifford_trial_result_str(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult) - final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0}) + final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0]) assert ( str( cirq.CliffordTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) == "measurements: m=1\n" @@ -247,12 +247,11 @@ def test_clifford_trial_result_str(): def test_clifford_trial_result_repr_pretty(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult) - final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0}) + final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0]) result = cirq.CliffordTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) cirq.testing.assert_repr_pretty(result, "measurements: m=1\n" "output state: |0⟩") diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index caf140f7c01..1a3188e654d 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -17,7 +17,7 @@ import numpy as np from cirq import ops, protocols, study, value -from cirq._compat import deprecated_parameter, proper_repr +from cirq._compat import deprecated_class, deprecated_parameter, proper_repr from cirq.sim import simulator, act_on_density_matrix_args, simulator_base if TYPE_CHECKING: @@ -186,10 +186,10 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'cirq.DensityMatrixStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]', ) -> 'cirq.DensityMatrixTrialResult': return DensityMatrixTrialResult( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) # TODO(#4209): Deduplicate with identical code in sparse_simulator. @@ -310,6 +310,7 @@ def __repr__(self) -> str: ) +@deprecated_class(deadline='v0.16', fix='This class is no longer used.') @value.value_equality(unhashable=True) class DensityMatrixSimulatorState: """The simulator state for DensityMatrixSimulator @@ -380,14 +381,15 @@ class DensityMatrixTrialResult( trial finishes. """ + @simulator._deprecated_step_result_parameter(old_position=3) def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'cirq.DensityMatrixStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]', ) -> None: super().__init__( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) self._final_density_matrix: Optional[np.ndarray] = None @@ -418,7 +420,7 @@ def __repr__(self) -> str: return ( 'cirq.DensityMatrixTrialResult(' f'params={self.params!r}, measurements={proper_repr(self.measurements)}, ' - f'final_step_result={self._final_step_result!r})' + f'final_simulator_state={self._final_simulator_state!r})' ) def _repr_pretty_(self, p: Any, cycle: bool): diff --git a/cirq-core/cirq/sim/density_matrix_simulator_test.py b/cirq-core/cirq/sim/density_matrix_simulator_test.py index 8e8ca9db3e4..fc805a70554 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulator_test.py @@ -921,98 +921,114 @@ def test_simulate_expectation_values_qubit_order(dtype): assert cirq.approx_eq(result_flipped[0], 3, atol=1e-6) -def test_density_matrix_simulator_state_eq(): - q0, q1 = cirq.LineQubit.range(2) - eq = cirq.testing.EqualsTester() - eq.add_equality_group( - cirq.DensityMatrixSimulatorState(density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0}), - cirq.DensityMatrixSimulatorState(density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0}), - ) - eq.add_equality_group( - cirq.DensityMatrixSimulatorState(density_matrix=np.eye(2) * 0.5, qubit_map={q0: 0}) - ) - eq.add_equality_group( - cirq.DensityMatrixSimulatorState(density_matrix=np.eye(2) * 0.5, qubit_map={q0: 0, q1: 1}) - ) +def test_density_matrix_simulator_state_eq_deprecated(): + with cirq.testing.assert_deprecated('no longer used', deadline='v0.16', count=4): + q0, q1 = cirq.LineQubit.range(2) + eq = cirq.testing.EqualsTester() + eq.add_equality_group( + cirq.DensityMatrixSimulatorState( + density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} + ), + cirq.DensityMatrixSimulatorState( + density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} + ), + ) + eq.add_equality_group( + cirq.DensityMatrixSimulatorState(density_matrix=np.eye(2) * 0.5, qubit_map={q0: 0}) + ) + eq.add_equality_group( + cirq.DensityMatrixSimulatorState( + density_matrix=np.eye(2) * 0.5, qubit_map={q0: 0, q1: 1} + ) + ) def test_density_matrix_simulator_state_qid_shape(): - q0, q1 = cirq.LineQubit.range(2) - assert cirq.qid_shape( - cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((4, 4)) / 4, qubit_map={q0: 0, q1: 1} - ) - ) == (2, 2) - q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) - assert cirq.qid_shape( - cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((12, 12)) / 12, qubit_map={q0: 0, q1: 1} - ) - ) == (3, 4) + with cirq.testing.assert_deprecated('no longer used', deadline='v0.16', count=2): + q0, q1 = cirq.LineQubit.range(2) + assert cirq.qid_shape( + cirq.DensityMatrixSimulatorState( + density_matrix=np.ones((4, 4)) / 4, qubit_map={q0: 0, q1: 1} + ) + ) == (2, 2) + q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) + assert cirq.qid_shape( + cirq.DensityMatrixSimulatorState( + density_matrix=np.ones((12, 12)) / 12, qubit_map={q0: 0, q1: 1} + ) + ) == (3, 4) def test_density_matrix_simulator_state_repr(): - q0 = cirq.LineQubit(0) - assert ( - repr( - cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} + with cirq.testing.assert_deprecated('no longer used', deadline='v0.16'): + q0 = cirq.LineQubit(0) + assert ( + repr( + cirq.DensityMatrixSimulatorState( + density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} + ) ) + == "cirq.DensityMatrixSimulatorState(density_matrix=" + "np.array([[0.5, 0.5], [0.5, 0.5]]), " + "qubit_map={cirq.LineQubit(0): 0})" ) - == "cirq.DensityMatrixSimulatorState(density_matrix=" - "np.array([[0.5, 0.5], [0.5, 0.5]]), " - "qubit_map={cirq.LineQubit(0): 0})" - ) def test_density_matrix_trial_result_eq(): q0 = cirq.LineQubit(0) - final_step_result = cirq.DensityMatrixStepResult( - cirq.ActOnDensityMatrixArgs(initial_state=np.ones((2, 2)) * 0.5, qubits=[q0]) + final_simulator_state = cirq.ActOnDensityMatrixArgs( + initial_state=np.ones((2, 2)) * 0.5, qubits=[q0] ) eq = cirq.testing.EqualsTester() eq.add_equality_group( cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ), cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ), ) eq.add_equality_group( cirq.DensityMatrixTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) eq.add_equality_group( cirq.DensityMatrixTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) def test_density_matrix_trial_result_qid_shape(): q0, q1 = cirq.LineQubit.range(2) - final_step_result = mock.Mock(cirq.StepResult) - final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((4, 4)) / 4, qubit_map={q0: 0, q1: 1} + final_simulator_state = cirq.ActOnDensityMatrixArgs( + initial_state=np.ones((4, 4)) / 4, qubits=[q0, q1] ) assert cirq.qid_shape( cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ) ) == (2, 2) q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) - final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((12, 12)) / 12, qubit_map={q0: 0, q1: 1} + final_simulator_state = cirq.ActOnDensityMatrixArgs( + initial_state=np.ones((12, 12)) / 12, qubits=[q0, q1] ) assert cirq.qid_shape( cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ) ) == (3, 4) @@ -1020,7 +1036,7 @@ def test_density_matrix_trial_result_qid_shape(): def test_density_matrix_trial_result_repr(): q0 = cirq.LineQubit(0) dtype = np.complex64 - args = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.ActOnDensityMatrixArgs( available_buffer=[], qid_shape=(2,), prng=np.random.RandomState(0), @@ -1028,23 +1044,20 @@ def test_density_matrix_trial_result_repr(): initial_state=np.ones((2, 2), dtype=dtype) * 0.5, dtype=dtype, ) - final_step_result = cirq.DensityMatrixStepResult(args) trial_result = cirq.DensityMatrixTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]], dtype=np.int32)}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) expected_repr = ( "cirq.DensityMatrixTrialResult(" "params=cirq.ParamResolver({'s': 1}), " "measurements={'m': np.array([[1]], dtype=np.int32)}, " - "final_step_result=cirq.DensityMatrixStepResult(" - "sim_state=cirq.ActOnDensityMatrixArgs(" + "final_simulator_state=cirq.ActOnDensityMatrixArgs(" "initial_state=np.array([[(0.5+0j), (0.5+0j)], [(0.5+0j), (0.5+0j)]], dtype=np.complex64), " "qid_shape=(2,), " "qubits=(cirq.LineQubit(0),), " - "classical_data=cirq.ClassicalDataDictionaryStore()), " - "dtype=np.complex64))" + "classical_data=cirq.ClassicalDataDictionaryStore()))" ) assert repr(trial_result) == expected_repr assert eval(expected_repr) == trial_result @@ -1111,7 +1124,7 @@ def test_works_on_pauli_string(): def test_density_matrix_trial_result_str(): q0 = cirq.LineQubit(0) dtype = np.complex64 - args = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.ActOnDensityMatrixArgs( available_buffer=[], qid_shape=(2,), prng=np.random.RandomState(0), @@ -1119,9 +1132,8 @@ def test_density_matrix_trial_result_str(): initial_state=np.ones((2, 2), dtype=dtype) * 0.5, dtype=dtype, ) - final_step_result = cirq.DensityMatrixStepResult(args) result = cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state ) # numpy varies whitespace in its representation for different versions @@ -1137,7 +1149,7 @@ def test_density_matrix_trial_result_str(): def test_density_matrix_trial_result_repr_pretty(): q0 = cirq.LineQubit(0) dtype = np.complex64 - args = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.ActOnDensityMatrixArgs( available_buffer=[], qid_shape=(2,), prng=np.random.RandomState(0), @@ -1145,9 +1157,8 @@ def test_density_matrix_trial_result_repr_pretty(): initial_state=np.ones((2, 2), dtype=dtype) * 0.5, dtype=dtype, ) - final_step_result = cirq.DensityMatrixStepResult(args) result = cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state ) fake_printer = cirq.testing.FakePrinter() @@ -1523,7 +1534,7 @@ def test_large_untangled_okay(): # Validate a simulation run result = cirq.DensityMatrixSimulator().simulate(circuit) - assert set(result._final_step_result._qubits) == set(cirq.LineQubit.range(59)) + assert set(result._final_simulator_state.qubits) == set(cirq.LineQubit.range(59)) # _ = result.final_density_matrix hangs (as expected) # Validate a trial run and sampling diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index b1f3bed9483..371b3569003 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -29,6 +29,7 @@ import abc import collections +import inspect from typing import ( Any, Callable, @@ -37,7 +38,7 @@ Generic, Iterator, List, - Optional, + Mapping, Sequence, Set, Tuple, @@ -48,7 +49,7 @@ import numpy as np -from cirq import circuits, ops, protocols, study, value, work +from cirq import _compat, circuits, ops, protocols, study, value, work from cirq.sim.act_on_args import ActOnArgs from cirq.sim.operation_target import OperationTarget @@ -609,9 +610,21 @@ def simulate_sweep_iter( for step_result in all_step_results: for k, v in step_result.measurements.items(): measurements[k] = np.array(v, dtype=np.uint8) - yield self._create_simulator_trial_result( - params=param_resolver, measurements=measurements, final_step_result=step_result - ) + if ( + 'final_simulator_state' + in inspect.signature(self._create_simulator_trial_result).parameters + ): + yield self._create_simulator_trial_result( + params=param_resolver, + measurements=measurements, + final_simulator_state=step_result._simulator_state(), + ) + else: + yield self._create_simulator_trial_result( # pylint: disable=no-value-for-parameter, unexpected-keyword-arg, line-too-long + params=param_resolver, + measurements=measurements, + final_step_result=step_result, # type: ignore + ) def simulate_moment_steps( self, @@ -724,14 +737,14 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: TStepResult, + final_simulator_state: 'cirq.OperationTarget[TActOnArgs]', ) -> TSimulationTrialResult: """This method can be implemented to create a trial result. Args: params: The ParamResolver for this trial. measurements: The measurement results for this trial. - final_step_result: The final step result of the simulation. + final_simulator_state: The final state of the simulation. Returns: The SimulationTrialResult. @@ -876,6 +889,55 @@ def sample_measurement_ops( ) +# When removing this, also remove the check in simulate_sweep_iter. +# Basically there should be no "final_step_result" anywhere in the project afterwards. +def _deprecated_step_result_parameter( + old_position: int = 4, new_position: int = 3 +) -> Callable[[Callable], Callable]: + assert old_position >= new_position + + def rewrite_deprecated_step_result_param(args, kwargs): + args = list(args) + state = ( + kwargs['final_simulator_state'] + if 'final_simulator_state' in kwargs + else args[new_position] + if len(args) > new_position and not isinstance(args[new_position], StepResult) + else None + ) + step_result = ( + kwargs['final_step_result'] + if 'final_step_result' in kwargs + else args[old_position] + if len(args) > old_position and isinstance(args[old_position], StepResult) + else None + ) + if (step_result is None) == (state is None): + raise ValueError( + 'Exactly one of final_simulator_state and final_step_result should be provided' + ) + if len(args) > old_position and isinstance(args[old_position], StepResult): + args[new_position] = args[old_position]._simulator_state() + if old_position > new_position: + del args[old_position] + elif 'final_step_result' in kwargs: + sim_state = kwargs['final_step_result']._simulator_state() + if len(args) > new_position: + args[new_position] = sim_state + else: + kwargs['final_simulator_state'] = sim_state + del kwargs['final_step_result'] + return tuple(args), kwargs + + return _compat.deprecated_parameter( + deadline='v0.16', + fix='', + parameter_desc='final_step_result', + match=lambda args, kwargs: 'final_step_result' in kwargs or len(args) > old_position, + rewrite=rewrite_deprecated_step_result_param, + ) + + @value.value_equality(unhashable=True) class SimulationTrialResult(Generic[TSimulatorState]): """Results of a simulation by a SimulatesFinalState. @@ -893,12 +955,12 @@ class SimulationTrialResult(Generic[TSimulatorState]): measurement gate.) """ + @_deprecated_step_result_parameter() def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: Optional[TSimulatorState] = None, - final_step_result: Optional['cirq.StepResult[TSimulatorState]'] = None, + final_simulator_state: TSimulatorState, ) -> None: """Initializes the `SimulationTrialResult` class. @@ -909,28 +971,10 @@ def __init__( boolean measurement results (ordered by the qubits acted on by the measurement gate.) final_simulator_state: The final simulator state. - final_step_result: The step result coming from the simulation, that - can be used to get the final simulator state. This is primarily - for cases when calculating simulator state may be expensive and - unneeded. If this is provided, then final_simulator_state - should not be, and vice versa. - - Raises: - ValueError: If `final_step_result` and `final_simulator_state` are both - None or both not None. """ - if [final_step_result, final_simulator_state].count(None) != 1: - raise ValueError( - 'Exactly one of final_simulator_state and final_step_result should be provided' - ) self.params = params self.measurements = measurements - self._final_step_result = final_step_result - self._final_simulator_state: TSimulatorState = ( - final_simulator_state - if final_simulator_state is not None - else cast('cirq.StepResult[TSimulatorState]', final_step_result)._simulator_state() - ) + self._final_simulator_state = final_simulator_state def __repr__(self) -> str: return ( @@ -962,7 +1006,7 @@ def _value_equality_values_(self) -> Any: return self.params, measurements, self._final_simulator_state @property - def qubit_map(self) -> Dict['cirq.Qid', int]: + def qubit_map(self) -> Mapping['cirq.Qid', int]: """A map from Qid to index used to define the ordering of the basis in the result. """ @@ -972,7 +1016,7 @@ def _qid_shape_(self) -> Tuple[int, ...]: return _qubit_map_to_shape(self.qubit_map) -def _qubit_map_to_shape(qubit_map: Dict['cirq.Qid', int]) -> Tuple[int, ...]: +def _qubit_map_to_shape(qubit_map: Mapping['cirq.Qid', int]) -> Tuple[int, ...]: qid_shape: List[int] = [-1] * len(qubit_map) try: for q, i in qubit_map.items(): diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index f5ab34fd2f4..c2d07b0fe46 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -36,6 +36,7 @@ from cirq import ops, protocols, study, value, devices from cirq.sim import ActOnArgsContainer from cirq.sim.operation_target import OperationTarget +from cirq.sim import simulator from cirq.sim.simulator import ( TSimulationTrialResult, TActOnArgs, @@ -385,11 +386,12 @@ class SimulationTrialResultBase( ): """A base class for trial results.""" + @simulator._deprecated_step_result_parameter(old_position=3) def __init__( self, params: study.ParamResolver, measurements: Dict[str, np.ndarray], - final_step_result: StepResultBase[TActOnArgs], + final_simulator_state: 'cirq.OperationTarget[TActOnArgs]', ) -> None: """Initializes the `SimulationTrialResultBase` class. @@ -399,10 +401,10 @@ def __init__( results. Measurement results are a numpy ndarray of actual boolean measurement results (ordered by the qubits acted on by the measurement gate.) - final_step_result: The step result coming from the simulation, that - can be used to get the final simulator state. + final_simulator_state: The final simulator state of the system after the + trial finishes. """ - super().__init__(params, measurements, final_step_result=final_step_result) + super().__init__(params, measurements, final_simulator_state=final_simulator_state) self._merged_sim_state_cache: Optional[TActOnArgs] = None def get_state_containing_qubit(self, qubit: 'cirq.Qid') -> TActOnArgs: diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 0fed8a667ad..47e3170fc24 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -124,9 +124,11 @@ def _create_simulator_trial_result( self, params: cirq.ParamResolver, measurements: Dict[str, np.ndarray], - final_step_result: CountingStepResult, + final_simulator_state: 'cirq.OperationTarget[CountingActOnArgs]', ) -> CountingTrialResult: - return CountingTrialResult(params, measurements, final_step_result=final_step_result) + return CountingTrialResult( + params, measurements, final_simulator_state=final_simulator_state + ) def _create_step_result( self, sim_state: cirq.OperationTarget[CountingActOnArgs] @@ -388,3 +390,20 @@ def _has_unitary_(self): simulator.simulate_sweep(program=circuit, params=params) assert op1.count == 2 assert op2.count == 2 + + +def test_deprecated_final_step_result(): + class OldCountingSimulator(CountingSimulator): + def _create_simulator_trial_result( # type: ignore + self, + params: cirq.ParamResolver, + measurements: Dict[str, np.ndarray], + final_step_result: CountingStepResult, + ) -> CountingTrialResult: + return CountingTrialResult(params, measurements, final_step_result=final_step_result) + + sim = OldCountingSimulator() + with cirq.testing.assert_deprecated('final_step_result', deadline='0.16'): + r = sim.simulate(cirq.Circuit()) + assert r._final_simulator_state.gate_count == 0 + assert r._final_simulator_state.measurement_count == 0 diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index f7063167471..26afb2a8e28 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -73,20 +73,20 @@ def _create_simulator_trial_result( self, params: study.ParamResolver, measurements: Dict[str, np.ndarray], - final_step_result: TStepResult, + final_simulator_state: 'cirq.OperationTarget[TActOnArgs]', ) -> 'SimulationTrialResult': """This method creates a default trial result. Args: params: The ParamResolver for this trial. measurements: The measurement results for this trial. - final_step_result: The final step result of the simulation. + final_simulator_state: The final state of the simulation. Returns: The SimulationTrialResult. """ return SimulationTrialResult( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) @@ -171,17 +171,16 @@ def steps(*args, **kwargs): program=circuit, params=param_resolvers, qubit_order=qubit_order, initial_state=2 ) - final_step_result = FakeStepResult(final_state=final_state) expected_results = [ cirq.SimulationTrialResult( measurements={'a': np.array([True, True])}, params=param_resolvers[0], - final_step_result=final_step_result, + final_simulator_state=final_state, ), cirq.SimulationTrialResult( measurements={'a': np.array([True, True])}, params=param_resolvers[1], - final_step_result=final_step_result, + final_simulator_state=final_state, ), ] assert results == expected_results @@ -242,46 +241,41 @@ def test_step_sample_measurement_ops_repeated_qubit(): def test_simulation_trial_result_equality(): eq = cirq.testing.EqualsTester() - final_step_result = FakeStepResult(final_state=()) eq.add_equality_group( cirq.SimulationTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=() ), cirq.SimulationTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=() ), ) eq.add_equality_group( cirq.SimulationTrialResult( - params=cirq.ParamResolver({'s': 1}), - measurements={}, - final_step_result=final_step_result, + params=cirq.ParamResolver({'s': 1}), measurements={}, final_simulator_state=() ) ) eq.add_equality_group( cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([1])}, - final_step_result=final_step_result, + final_simulator_state=(), ) ) - final_step_result._final_state = (0, 1) eq.add_equality_group( cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([1])}, - final_step_result=final_step_result, + final_simulator_state=(0, 1), ) ) def test_simulation_trial_result_repr(): - final_step_result = FakeStepResult(final_state=(0, 1)) assert repr( cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([1])}, - final_step_result=final_step_result, + final_simulator_state=(0, 1), ) ) == ( "cirq.SimulationTrialResult(" @@ -292,13 +286,10 @@ def test_simulation_trial_result_repr(): def test_simulation_trial_result_str(): - final_step_result = FakeStepResult(final_state=(0, 1)) assert ( str( cirq.SimulationTrialResult( - params=cirq.ParamResolver({'s': 1}), - measurements={}, - final_step_result=final_step_result, + params=cirq.ParamResolver({'s': 1}), measurements={}, final_simulator_state=(0, 1) ) ) == '(no measurements)' @@ -309,7 +300,7 @@ def test_simulation_trial_result_str(): cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([1])}, - final_step_result=final_step_result, + final_simulator_state=(0, 1), ) ) == 'm=1' @@ -320,7 +311,7 @@ def test_simulation_trial_result_str(): cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([1, 2, 3])}, - final_step_result=final_step_result, + final_simulator_state=(0, 1), ) ) == 'm=123' @@ -331,7 +322,7 @@ def test_simulation_trial_result_str(): cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([9, 10, 11])}, - final_step_result=final_step_result, + final_simulator_state=(0, 1), ) ) == 'm=9 10 11' @@ -447,9 +438,7 @@ def _kraus_(self): def test_iter_definitions(): - dummy_trial_result = SimulationTrialResult( - params={}, measurements={}, final_step_result=FakeStepResult(final_state=[]) - ) + dummy_trial_result = SimulationTrialResult(params={}, measurements={}, final_simulator_state=[]) class FakeNonIterSimulatorImpl( SimulatesAmplitudes, SimulatesExpectationValues, SimulatesFinalState @@ -539,7 +528,31 @@ class FakeMissingIterSimulatorImpl( def test_trial_result_initializer(): + resolver = cirq.ParamResolver() + step = mock.Mock(cirq.StepResultBase) + step._simulator_state.return_value = 1 + state = 3 + with pytest.raises(ValueError, match='Exactly one of'): + _ = SimulationTrialResult(resolver, {}, None, None) + with pytest.raises(ValueError, match='Exactly one of'): + _ = SimulationTrialResult(resolver, {}, state, step) with pytest.raises(ValueError, match='Exactly one of'): - _ = SimulationTrialResult(cirq.ParamResolver(), {}, None, None) + _ = SimulationTrialResult(resolver, {}, final_simulator_state=None, final_step_result=None) with pytest.raises(ValueError, match='Exactly one of'): - _ = SimulationTrialResult(cirq.ParamResolver(), {}, object(), mock.Mock(TStepResult)) + _ = SimulationTrialResult(resolver, {}, final_simulator_state=state, final_step_result=step) + with cirq.testing.assert_deprecated(deadline='v0.16'): + x = SimulationTrialResult(resolver, {}, final_step_result=step) + assert x._final_simulator_state == 1 + with cirq.testing.assert_deprecated(deadline='v0.16'): + x = SimulationTrialResult(resolver, {}, None, final_step_result=step) + assert x._final_simulator_state == 1 + with cirq.testing.assert_deprecated(deadline='v0.16'): + x = SimulationTrialResult(resolver, {}, None, step) + assert x._final_simulator_state == 1 + with cirq.testing.assert_deprecated(deadline='v0.16'): + x = SimulationTrialResult(resolver, {}, final_simulator_state=None, final_step_result=step) + assert x._final_simulator_state == 1 + x = SimulationTrialResult(resolver, {}, state) + assert x._final_simulator_state == 3 + x = SimulationTrialResult(resolver, {}, final_simulator_state=state) + assert x._final_simulator_state == 3 diff --git a/cirq-core/cirq/sim/state_vector.py b/cirq-core/cirq/sim/state_vector.py index 7364cbcd352..ebccd000ca0 100644 --- a/cirq-core/cirq/sim/state_vector.py +++ b/cirq-core/cirq/sim/state_vector.py @@ -14,7 +14,7 @@ """Helpers for handling quantum state vectors.""" import abc -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Sequence +from typing import List, Mapping, Optional, Tuple, TYPE_CHECKING, Sequence import numpy as np @@ -31,7 +31,7 @@ class StateVectorMixin: """A mixin that provide methods for objects that have a state vector.""" - def __init__(self, qubit_map: Optional[Dict['cirq.Qid', int]] = None, *args, **kwargs): + def __init__(self, qubit_map: Optional[Mapping['cirq.Qid', int]] = None, *args, **kwargs): """Inits StateVectorMixin. Args: @@ -48,7 +48,7 @@ def __init__(self, qubit_map: Optional[Dict['cirq.Qid', int]] = None, *args, **k self._qid_shape = None if qubit_map is None else qid_shape @property - def qubit_map(self) -> Dict['cirq.Qid', int]: + def qubit_map(self) -> Mapping['cirq.Qid', int]: return self._qubit_map def _qid_shape_(self) -> Tuple[int, ...]: diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index a2a0e442998..a473565ba3e 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -30,7 +30,7 @@ import numpy as np from cirq import ops, value, qis -from cirq._compat import proper_repr +from cirq._compat import deprecated_class, proper_repr from cirq.sim import simulator, state_vector, simulator_base if TYPE_CHECKING: @@ -69,10 +69,10 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'cirq.StateVectorStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStateVectorArgs]', ) -> 'cirq.StateVectorTrialResult': return StateVectorTrialResult( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) def compute_amplitudes_sweep_iter( @@ -107,6 +107,7 @@ class StateVectorStepResult( pass +@deprecated_class(deadline='v0.16', fix='This class is no longer used.') @value.value_equality(unhashable=True) class StateVectorSimulatorState: def __init__(self, state_vector: np.ndarray, qubit_map: Dict[ops.Qid, int]) -> None: @@ -143,13 +144,13 @@ def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'cirq.StateVectorStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStateVectorArgs]', ) -> None: super().__init__( params=params, measurements=measurements, - final_step_result=final_step_result, - qubit_map=final_step_result._qubit_mapping, + final_simulator_state=final_simulator_state, + qubit_map=final_simulator_state.qubit_map, ) self._final_state_vector: Optional[np.ndarray] = None @@ -222,5 +223,5 @@ def __repr__(self) -> str: return ( 'cirq.StateVectorTrialResult(' f'params={self.params!r}, measurements={proper_repr(self.measurements)}, ' - f'final_step_result={self._final_step_result!r})' + f'final_simulator_state={self._final_simulator_state!r})' ) diff --git a/cirq-core/cirq/sim/state_vector_simulator_test.py b/cirq-core/cirq/sim/state_vector_simulator_test.py index 50d3063f398..fdae8ece86e 100644 --- a/cirq-core/cirq/sim/state_vector_simulator_test.py +++ b/cirq-core/cirq/sim/state_vector_simulator_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest import mock - import numpy as np import cirq @@ -22,89 +20,87 @@ def test_state_vector_trial_result_repr(): q0 = cirq.NamedQubit('a') - args = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.ActOnStateVectorArgs( available_buffer=np.array([0, 1], dtype=np.complex64), prng=np.random.RandomState(0), qubits=[q0], initial_state=np.array([0, 1], dtype=np.complex64), dtype=np.complex64, ) - final_step_result = cirq.SparseSimulatorStep(args) trial_result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]], dtype=np.int32)}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) expected_repr = ( "cirq.StateVectorTrialResult(" "params=cirq.ParamResolver({'s': 1}), " "measurements={'m': np.array([[1]], dtype=np.int32)}, " - "final_step_result=cirq.SparseSimulatorStep(" - "sim_state=cirq.ActOnStateVectorArgs(" + "final_simulator_state=cirq.ActOnStateVectorArgs(" "initial_state=np.array([0j, (1+0j)], dtype=np.complex64), " "qubits=(cirq.NamedQubit('a'),), " - "classical_data=cirq.ClassicalDataDictionaryStore()), " - "dtype=np.complex64))" + "classical_data=cirq.ClassicalDataDictionaryStore()))" ) assert repr(trial_result) == expected_repr assert eval(expected_repr) == trial_result def test_state_vector_simulator_state_repr(): - final_simulator_state = cirq.StateVectorSimulatorState( - qubit_map={cirq.NamedQubit('a'): 0}, state_vector=np.array([0, 1]) - ) - cirq.testing.assert_equivalent_repr(final_simulator_state) + with cirq.testing.assert_deprecated('no longer used', deadline='v0.16', count=4): + final_simulator_state = cirq.StateVectorSimulatorState( + qubit_map={cirq.NamedQubit('a'): 0}, state_vector=np.array([0, 1]) + ) + cirq.testing.assert_equivalent_repr(final_simulator_state) def test_state_vector_trial_result_equality(): eq = cirq.testing.EqualsTester() - final_step_result = cirq.StateVectorStepResult( - cirq.ActOnStateVectorArgs(initial_state=np.array([])) - ) + final_simulator_state = cirq.ActOnStateVectorArgs(initial_state=np.array([])) eq.add_equality_group( cirq.StateVectorTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ), cirq.StateVectorTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ), ) eq.add_equality_group( cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) eq.add_equality_group( cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) - final_step_result = cirq.StateVectorStepResult( - cirq.ActOnStateVectorArgs(initial_state=np.array([1])) - ) + final_simulator_state = cirq.ActOnStateVectorArgs(initial_state=np.array([1])) eq.add_equality_group( cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) def test_state_vector_trial_result_state_mixin(): qubits = cirq.LineQubit.range(2) - final_step_result = cirq.StateVectorStepResult( - cirq.ActOnStateVectorArgs(qubits=qubits, initial_state=np.array([0, 1, 0, 0])) + final_simulator_state = cirq.ActOnStateVectorArgs( + qubits=qubits, initial_state=np.array([0, 1, 0, 0]) ) result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({'a': 2}), measurements={'m': np.array([1, 2])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) rho = np.array([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) np.testing.assert_array_almost_equal(rho, result.density_matrix_of(qubits)) @@ -114,70 +110,60 @@ def test_state_vector_trial_result_state_mixin(): def test_state_vector_trial_result_qid_shape(): - qubit_map = {cirq.NamedQubit('a'): 0} - final_step_result = mock.Mock(cirq.StateVectorStepResult) - final_step_result._qubit_mapping = qubit_map - final_step_result._simulator_state.return_value = cirq.StateVectorSimulatorState( - qubit_map=qubit_map, state_vector=np.array([0, 1]) + final_simulator_state = cirq.ActOnStateVectorArgs( + qubits=[cirq.NamedQubit('a')], initial_state=np.array([0, 1]) ) trial_result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) - assert cirq.qid_shape(final_step_result._simulator_state()) == (2,) assert cirq.qid_shape(trial_result) == (2,) - q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) - qubit_map = {q0: 1, q1: 0} - final_step_result._qubit_mapping = qubit_map - final_step_result._simulator_state.return_value = cirq.StateVectorSimulatorState( - qubit_map=qubit_map, state_vector=np.array([0, 0, 0, 0, 1, 0]) + final_simulator_state = cirq.ActOnStateVectorArgs( + qubits=cirq.LineQid.for_qid_shape((3, 2)), initial_state=np.array([0, 0, 0, 0, 1, 0]) ) trial_result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[2, 0]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) - assert cirq.qid_shape(final_step_result._simulator_state()) == (3, 2) assert cirq.qid_shape(trial_result) == (3, 2) def test_state_vector_trial_state_vector_is_copy(): final_state_vector = np.array([0, 1], dtype=np.complex64) qubit_map = {cirq.NamedQubit('a'): 0} - final_step_result = cirq.StateVectorStepResult( - cirq.ActOnStateVectorArgs(qubits=list(qubit_map), initial_state=final_state_vector) + final_simulator_state = cirq.ActOnStateVectorArgs( + qubits=list(qubit_map), initial_state=final_state_vector ) trial_result = cirq.StateVectorTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state ) - assert trial_result.state_vector() is not final_step_result._simulator_state().target_tensor + assert trial_result.state_vector() is not final_simulator_state.target_tensor def test_str_big(): qs = cirq.LineQubit.range(10) - args = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.ActOnStateVectorArgs( prng=np.random.RandomState(0), qubits=qs, initial_state=np.array([1] * 2**10, dtype=np.complex64) * 0.03125, dtype=np.complex64, ) - final_step_result = cirq.SparseSimulatorStep(args) - result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_step_result) + result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state) assert 'output vector: [0.03125+0.j 0.03125+0.j 0.03125+0.j ..' in str(result) def test_pretty_print(): - args = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.ActOnStateVectorArgs( available_buffer=np.array([1]), prng=np.random.RandomState(0), qubits=[], initial_state=np.array([1], dtype=np.complex64), dtype=np.complex64, ) - final_step_result = cirq.SparseSimulatorStep(args) - result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_step_result) + result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state) # Test Jupyter console output from class FakePrinter: