Skip to content

Bump types so that SimulatesIntermediateState isn't bound to ActOnArgs #5283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,6 @@ def strat_act_on_from_apply_decompose(
operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
protocols.act_on(operation, args)
return True


TActOnArgs = TypeVar('TActOnArgs', bound=ActOnArgs)
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/act_on_args_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import numpy as np

from cirq import ops, protocols, value
from cirq.sim.act_on_args import TActOnArgs
from cirq.sim.operation_target import OperationTarget
from cirq.sim.simulator import TActOnArgs

if TYPE_CHECKING:
import cirq
Expand Down
22 changes: 10 additions & 12 deletions cirq-core/cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import numpy as np

from cirq import _compat, circuits, ops, protocols, study, value, work
from cirq.sim.act_on_args import ActOnArgs
from cirq.sim.operation_target import OperationTarget

if TYPE_CHECKING:
Expand All @@ -60,7 +59,6 @@
TStepResult = TypeVar('TStepResult', bound='StepResult')
TSimulationTrialResult = TypeVar('TSimulationTrialResult', bound='SimulationTrialResult')
TSimulatorState = TypeVar('TSimulatorState', bound=Any)
TActOnArgs = TypeVar('TActOnArgs', bound=ActOnArgs)


class SimulatesSamples(work.Sampler, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -551,7 +549,7 @@ def simulate_sweep_iter(


class SimulatesIntermediateState(
Generic[TStepResult, TSimulationTrialResult, TActOnArgs],
Generic[TStepResult, TSimulationTrialResult, TSimulatorState],
SimulatesFinalState[TSimulationTrialResult],
metaclass=abc.ABCMeta,
):
Expand Down Expand Up @@ -692,8 +690,8 @@ def _base_iterator(
@abc.abstractmethod
def _create_act_on_args(
self, initial_state: Any, qubits: Sequence['cirq.Qid']
) -> 'cirq.OperationTarget[TActOnArgs]':
"""Creates the OperationTarget state for a simulator.
) -> TSimulatorState:
"""Creates the state for a simulator.

Custom simulators should implement this method.

Expand All @@ -706,14 +704,14 @@ def _create_act_on_args(
ordering of the computational basis states.

Returns:
The `OperationTarget` for this simulator.
The `TSimulatorState` for this simulator.
"""

@abc.abstractmethod
def _core_iterator(
self,
circuit: 'cirq.AbstractCircuit',
sim_state: 'cirq.OperationTarget[TActOnArgs]',
sim_state: TSimulatorState,
all_measurements_are_terminal: bool = False,
) -> Iterator[TStepResult]:
"""Iterator over StepResult from Moments of a Circuit.
Expand All @@ -722,7 +720,7 @@ def _core_iterator(

Args:
circuit: The circuit to simulate.
sim_state: The initial args for the simulation. The form of
sim_state: The initial state for the simulation. The form of
this state depends on the simulation implementation. See
documentation of the implementing class for details.
all_measurements_are_terminal: Whether all measurements in
Expand All @@ -737,7 +735,7 @@ def _create_simulator_trial_result(
self,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_simulator_state: 'cirq.OperationTarget[TActOnArgs]',
final_simulator_state: TSimulatorState,
) -> TSimulationTrialResult:
"""This method can be implemented to create a trial result.

Expand All @@ -760,11 +758,10 @@ class StepResult(Generic[TSimulatorState], metaclass=abc.ABCMeta):
results, ordered by the qubits that the measurement operates on.
"""

def __init__(self, sim_state: 'cirq.OperationTarget') -> None:
def __init__(self, sim_state: TSimulatorState) -> None:
self._sim_state = sim_state
self.measurements = sim_state.log_of_measurement_results
self._classical_data = sim_state.classical_data

@abc.abstractmethod
def _simulator_state(self) -> TSimulatorState:
"""Returns the simulator state of the simulator after this step.

Expand All @@ -775,6 +772,7 @@ def _simulator_state(self) -> TSimulatorState:
simulation,see documentation for the implementing class for the form of
details.
"""
return self._sim_state

@abc.abstractmethod
def sample(
Expand Down
15 changes: 7 additions & 8 deletions cirq-core/cirq/sim/simulator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@

from cirq import ops, protocols, study, value, devices
from cirq.sim import ActOnArgsContainer
from cirq.sim.operation_target import OperationTarget
from cirq.sim import simulator
from cirq.sim.act_on_args import TActOnArgs
from cirq.sim.operation_target import OperationTarget
from cirq.sim.simulator import (
TSimulationTrialResult,
TActOnArgs,
SimulatesIntermediateState,
SimulatesSamples,
StepResult,
Expand All @@ -57,7 +57,9 @@

class SimulatorBase(
Generic[TStepResultBase, TSimulationTrialResult, TActOnArgs],
SimulatesIntermediateState[TStepResultBase, TSimulationTrialResult, TActOnArgs],
SimulatesIntermediateState[
TStepResultBase, TSimulationTrialResult, OperationTarget[TActOnArgs]
],
SimulatesSamples,
metaclass=abc.ABCMeta,
):
Expand Down Expand Up @@ -352,13 +354,13 @@ def __init__(self, sim_state: OperationTarget[TActOnArgs]):
Args:
sim_state: The `OperationTarget` for this step.
"""
self._sim_state = sim_state
self._merged_sim_state_cache: Optional[TActOnArgs] = None
super().__init__(sim_state)
self._merged_sim_state_cache: Optional[TActOnArgs] = None
qubits = sim_state.qubits
self._qubits = qubits
self._qubit_mapping = {q: i for i, q in enumerate(qubits)}
self._qubit_shape = tuple(q.dimension for q in qubits)
self._classical_data = sim_state.classical_data

def _qid_shape_(self):
return self._qubit_shape
Expand All @@ -377,9 +379,6 @@ def sample(
) -> np.ndarray:
return self._sim_state.sample(qubits, repetitions, seed)

def _simulator_state(self) -> 'cirq.OperationTarget[TActOnArgs]':
return self._sim_state


class SimulationTrialResultBase(
SimulationTrialResult[OperationTarget[TActOnArgs]], Generic[TActOnArgs], abc.ABC
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import cirq
from cirq import study
from cirq.sim.act_on_args import TActOnArgs
from cirq.sim.simulator import (
TStepResult,
SimulatesAmplitudes,
Expand All @@ -30,7 +31,6 @@
SimulatesIntermediateState,
SimulatesSamples,
SimulationTrialResult,
TActOnArgs,
)


Expand Down