Skip to content

Deprecate the final_step_result parameter of TrialResult #5281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 22, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from cirq import devices, protocols, qis, value
from cirq._compat import deprecated_parameter
from cirq.sim import simulator_base
from cirq.sim import simulator, simulator_base
from cirq.sim.act_on_args import ActOnArgs

if TYPE_CHECKING:
Expand Down Expand Up @@ -122,35 +122,36 @@ def _create_simulator_trial_result(
self,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_step_result: 'MPSSimulatorStepResult',
final_simulator_state: 'cirq.OperationTarget[MPSState]',
) -> 'MPSTrialResult':
"""Creates a single trial results with the measurements.

Args:
params: A ParamResolver for determining values of Symbols.
measurements: A dictionary from measurement key (e.g. qubit) to the
actual measurement array.
final_step_result: The final step result of the simulation.
final_simulator_state: The final state of the simulation.

Returns:
A single result.
"""
return MPSTrialResult(
params=params, measurements=measurements, final_step_result=final_step_result
params=params, measurements=measurements, final_simulator_state=final_simulator_state
)


class MPSTrialResult(simulator_base.SimulationTrialResultBase['MPSState']):
"""A single trial reult"""

@simulator._deprecated_step_result_parameter(old_position=3)
def __init__(
self,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_step_result: 'MPSSimulatorStepResult',
final_simulator_state: 'cirq.OperationTarget[MPSState]',
) -> None:
super().__init__(
params=params, measurements=measurements, final_step_result=final_step_result
params=params, measurements=measurements, final_simulator_state=final_simulator_state
)

@property
Expand Down
11 changes: 4 additions & 7 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# pylint: disable=wrong-or-nonexistent-copyright-notice
import itertools
import math
from unittest import mock

import numpy as np
import pytest
Expand Down Expand Up @@ -263,8 +262,7 @@ def test_measurement_str():

def test_trial_result_str():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.StepResult)
final_step_result._simulator_state.return_value = ccq.mps_simulator.MPSState(
final_simulator_state = ccq.mps_simulator.MPSState(
qubits=(q0,),
prng=value.parse_random_state(0),
simulation_options=ccq.mps_simulator.MPSOptions(),
Expand All @@ -274,7 +272,7 @@ def test_trial_result_str():
ccq.mps_simulator.MPSTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_step_result=final_step_result,
final_simulator_state=final_simulator_state,
)
)
== """measurements: m=1
Expand All @@ -286,16 +284,15 @@ def test_trial_result_str():

def test_trial_result_repr_pretty():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.StepResult)
final_step_result._simulator_state.return_value = ccq.mps_simulator.MPSState(
final_simulator_state = ccq.mps_simulator.MPSState(
qubits=(q0,),
prng=value.parse_random_state(0),
simulation_options=ccq.mps_simulator.MPSOptions(),
)
result = ccq.mps_simulator.MPSTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_step_result=final_step_result,
final_simulator_state=final_simulator_state,
)
cirq.testing.assert_repr_pretty(
result,
Expand Down
9 changes: 9 additions & 0 deletions cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np

from cirq._compat import proper_repr
from cirq.sim.clifford import stabilizer_state_ch_form
from cirq.sim.clifford.act_on_stabilizer_args import ActOnStabilizerArgs

Expand Down Expand Up @@ -65,3 +66,11 @@ def __init__(
super().__init__(
state=initial_state, prng=prng, qubits=qubits, classical_data=classical_data
)

def __repr__(self) -> str:
return (
'cirq.ActOnStabilizerCHFormArgs('
f'initial_state={proper_repr(self.state)},'
f' qubits={self.qubits!r},'
f' classical_data={self.classical_data!r})'
)
13 changes: 7 additions & 6 deletions cirq-core/cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import cirq
from cirq import protocols, value
from cirq.protocols import act_on
from cirq.sim import clifford, simulator_base
from cirq.sim import clifford, simulator, simulator_base


class CliffordSimulator(
Expand Down Expand Up @@ -107,25 +107,26 @@ def _create_simulator_trial_result(
self,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_step_result: 'CliffordSimulatorStepResult',
final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStabilizerCHFormArgs]',
):

return CliffordTrialResult(
params=params, measurements=measurements, final_step_result=final_step_result
params=params, measurements=measurements, final_simulator_state=final_simulator_state
)


class CliffordTrialResult(
simulator_base.SimulationTrialResultBase['clifford.ActOnStabilizerCHFormArgs']
):
@simulator._deprecated_step_result_parameter(old_position=3)
def __init__(
self,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_step_result: 'cirq.CliffordSimulatorStepResult',
final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStabilizerCHFormArgs]',
) -> None:
super().__init__(
params=params, measurements=measurements, final_step_result=final_step_result
params=params, measurements=measurements, final_simulator_state=final_simulator_state
)

@property
Expand All @@ -137,7 +138,7 @@ def final_state(self) -> 'cirq.CliffordState':

def __str__(self) -> str:
samples = super().__str__()
final = self._final_simulator_state
final = self._get_merged_sim_state().state
return f'measurements: {samples}\noutput state: {final}'

def _repr_pretty_(self, p: Any, cycle: bool):
Expand Down
21 changes: 10 additions & 11 deletions cirq-core/cirq/sim/clifford/clifford_simulator_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# pylint: disable=wrong-or-nonexistent-copyright-notice
import itertools
from unittest import mock

import numpy as np
import pytest
Expand Down Expand Up @@ -212,32 +211,33 @@ def test_clifford_state_initial_state():

def test_clifford_trial_result_repr():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult)
final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0})
final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0])
assert (
repr(
cirq.CliffordTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_step_result=final_step_result,
final_simulator_state=final_simulator_state,
)
)
== "cirq.SimulationTrialResult(params=cirq.ParamResolver({}), "
"measurements={'m': array([[1]])}, "
"final_simulator_state=StabilizerStateChForm(num_qubits=1))"
"final_simulator_state=cirq.ActOnStabilizerCHFormArgs("
"initial_state=StabilizerStateChForm(num_qubits=1), "
"qubits=(cirq.LineQubit(0),), "
"classical_data=cirq.ClassicalDataDictionaryStore()))"
)


def test_clifford_trial_result_str():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult)
final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0})
final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0])
assert (
str(
cirq.CliffordTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_step_result=final_step_result,
final_simulator_state=final_simulator_state,
)
)
== "measurements: m=1\n"
Expand All @@ -247,12 +247,11 @@ def test_clifford_trial_result_str():

def test_clifford_trial_result_repr_pretty():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult)
final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0})
final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0])
result = cirq.CliffordTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_step_result=final_step_result,
final_simulator_state=final_simulator_state,
)

cirq.testing.assert_repr_pretty(result, "measurements: m=1\n" "output state: |0⟩")
Expand Down
14 changes: 8 additions & 6 deletions cirq-core/cirq/sim/density_matrix_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np

from cirq import ops, protocols, study, value
from cirq._compat import deprecated_parameter, proper_repr
from cirq._compat import deprecated_class, deprecated_parameter, proper_repr
from cirq.sim import simulator, act_on_density_matrix_args, simulator_base

if TYPE_CHECKING:
Expand Down Expand Up @@ -186,10 +186,10 @@ def _create_simulator_trial_result(
self,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_step_result: 'cirq.DensityMatrixStepResult',
final_simulator_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]',
) -> 'cirq.DensityMatrixTrialResult':
return DensityMatrixTrialResult(
params=params, measurements=measurements, final_step_result=final_step_result
params=params, measurements=measurements, final_simulator_state=final_simulator_state
)

# TODO(#4209): Deduplicate with identical code in sparse_simulator.
Expand Down Expand Up @@ -310,6 +310,7 @@ def __repr__(self) -> str:
)


@deprecated_class(deadline='v0.16', fix='This class is no longer used.')
@value.value_equality(unhashable=True)
class DensityMatrixSimulatorState:
"""The simulator state for DensityMatrixSimulator
Expand Down Expand Up @@ -380,14 +381,15 @@ class DensityMatrixTrialResult(
trial finishes.
"""

@simulator._deprecated_step_result_parameter(old_position=3)
def __init__(
self,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_step_result: 'cirq.DensityMatrixStepResult',
final_simulator_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]',
) -> None:
super().__init__(
params=params, measurements=measurements, final_step_result=final_step_result
params=params, measurements=measurements, final_simulator_state=final_simulator_state
)
self._final_density_matrix: Optional[np.ndarray] = None

Expand Down Expand Up @@ -418,7 +420,7 @@ def __repr__(self) -> str:
return (
'cirq.DensityMatrixTrialResult('
f'params={self.params!r}, measurements={proper_repr(self.measurements)}, '
f'final_step_result={self._final_step_result!r})'
f'final_simulator_state={self._final_simulator_state!r})'
)

def _repr_pretty_(self, p: Any, cycle: bool):
Expand Down
Loading