diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 2166453b2ff..6fc0b679833 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -386,3 +386,6 @@ def strat_act_on_from_apply_decompose( operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits]) protocols.act_on(operation, args) return True + + +TActOnArgs = TypeVar('TActOnArgs', bound=ActOnArgs) diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 420293ed8bf..0b4e0856b64 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -18,8 +18,8 @@ import numpy as np from cirq import ops, protocols, value +from cirq.sim.act_on_args import TActOnArgs from cirq.sim.operation_target import OperationTarget -from cirq.sim.simulator import TActOnArgs if TYPE_CHECKING: import cirq diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 371b3569003..2d9a160474d 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -50,7 +50,6 @@ import numpy as np from cirq import _compat, circuits, ops, protocols, study, value, work -from cirq.sim.act_on_args import ActOnArgs from cirq.sim.operation_target import OperationTarget if TYPE_CHECKING: @@ -60,7 +59,6 @@ TStepResult = TypeVar('TStepResult', bound='StepResult') TSimulationTrialResult = TypeVar('TSimulationTrialResult', bound='SimulationTrialResult') TSimulatorState = TypeVar('TSimulatorState', bound=Any) -TActOnArgs = TypeVar('TActOnArgs', bound=ActOnArgs) class SimulatesSamples(work.Sampler, metaclass=abc.ABCMeta): @@ -551,7 +549,7 @@ def simulate_sweep_iter( class SimulatesIntermediateState( - Generic[TStepResult, TSimulationTrialResult, TActOnArgs], + Generic[TStepResult, TSimulationTrialResult, TSimulatorState], SimulatesFinalState[TSimulationTrialResult], metaclass=abc.ABCMeta, ): @@ -692,8 +690,8 @@ def _base_iterator( @abc.abstractmethod def _create_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'] - ) -> 'cirq.OperationTarget[TActOnArgs]': - """Creates the OperationTarget state for a simulator. + ) -> TSimulatorState: + """Creates the state for a simulator. Custom simulators should implement this method. @@ -706,14 +704,14 @@ def _create_act_on_args( ordering of the computational basis states. Returns: - The `OperationTarget` for this simulator. + The `TSimulatorState` for this simulator. """ @abc.abstractmethod def _core_iterator( self, circuit: 'cirq.AbstractCircuit', - sim_state: 'cirq.OperationTarget[TActOnArgs]', + sim_state: TSimulatorState, all_measurements_are_terminal: bool = False, ) -> Iterator[TStepResult]: """Iterator over StepResult from Moments of a Circuit. @@ -722,7 +720,7 @@ def _core_iterator( Args: circuit: The circuit to simulate. - sim_state: The initial args for the simulation. The form of + sim_state: The initial state for the simulation. The form of this state depends on the simulation implementation. See documentation of the implementing class for details. all_measurements_are_terminal: Whether all measurements in @@ -737,7 +735,7 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[TActOnArgs]', + final_simulator_state: TSimulatorState, ) -> TSimulationTrialResult: """This method can be implemented to create a trial result. @@ -760,11 +758,10 @@ class StepResult(Generic[TSimulatorState], metaclass=abc.ABCMeta): results, ordered by the qubits that the measurement operates on. """ - def __init__(self, sim_state: 'cirq.OperationTarget') -> None: + def __init__(self, sim_state: TSimulatorState) -> None: + self._sim_state = sim_state self.measurements = sim_state.log_of_measurement_results - self._classical_data = sim_state.classical_data - @abc.abstractmethod def _simulator_state(self) -> TSimulatorState: """Returns the simulator state of the simulator after this step. @@ -775,6 +772,7 @@ def _simulator_state(self) -> TSimulatorState: simulation,see documentation for the implementing class for the form of details. """ + return self._sim_state @abc.abstractmethod def sample( diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index c2d07b0fe46..e147a708c1a 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -35,11 +35,11 @@ from cirq import ops, protocols, study, value, devices from cirq.sim import ActOnArgsContainer -from cirq.sim.operation_target import OperationTarget from cirq.sim import simulator +from cirq.sim.act_on_args import TActOnArgs +from cirq.sim.operation_target import OperationTarget from cirq.sim.simulator import ( TSimulationTrialResult, - TActOnArgs, SimulatesIntermediateState, SimulatesSamples, StepResult, @@ -57,7 +57,9 @@ class SimulatorBase( Generic[TStepResultBase, TSimulationTrialResult, TActOnArgs], - SimulatesIntermediateState[TStepResultBase, TSimulationTrialResult, TActOnArgs], + SimulatesIntermediateState[ + TStepResultBase, TSimulationTrialResult, OperationTarget[TActOnArgs] + ], SimulatesSamples, metaclass=abc.ABCMeta, ): @@ -352,13 +354,13 @@ def __init__(self, sim_state: OperationTarget[TActOnArgs]): Args: sim_state: The `OperationTarget` for this step. """ - self._sim_state = sim_state - self._merged_sim_state_cache: Optional[TActOnArgs] = None super().__init__(sim_state) + self._merged_sim_state_cache: Optional[TActOnArgs] = None qubits = sim_state.qubits self._qubits = qubits self._qubit_mapping = {q: i for i, q in enumerate(qubits)} self._qubit_shape = tuple(q.dimension for q in qubits) + self._classical_data = sim_state.classical_data def _qid_shape_(self): return self._qubit_shape @@ -377,9 +379,6 @@ def sample( ) -> np.ndarray: return self._sim_state.sample(qubits, repetitions, seed) - def _simulator_state(self) -> 'cirq.OperationTarget[TActOnArgs]': - return self._sim_state - class SimulationTrialResultBase( SimulationTrialResult[OperationTarget[TActOnArgs]], Generic[TActOnArgs], abc.ABC diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index 26afb2a8e28..52b886196b8 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -22,6 +22,7 @@ import cirq from cirq import study +from cirq.sim.act_on_args import TActOnArgs from cirq.sim.simulator import ( TStepResult, SimulatesAmplitudes, @@ -30,7 +31,6 @@ SimulatesIntermediateState, SimulatesSamples, SimulationTrialResult, - TActOnArgs, )