Skip to content

Cleanup classical simulator code and fix a couple of bugs #6344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 41 additions & 61 deletions cirq-core/cirq/sim/classical_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@
import numpy as np


def _is_identity(op: ops.Operation) -> bool:
if isinstance(op.gate, (ops.XPowGate, ops.CXPowGate, ops.CCXPowGate, ops.SwapPowGate)):
return op.gate.exponent % 2 == 0
return False


class ClassicalStateSimulator(SimulatesSamples):
"""A simulator that only accepts only gates with classical counterparts.
"""A simulator that accepts only gates with classical counterparts.

This simulator evolves a single state, using only gates that output a single state for each
input state. The simulator runs in linear time, at the cost of not supporting superposition.
Expand All @@ -47,7 +53,9 @@ class ClassicalStateSimulator(SimulatesSamples):
A dictionary mapping measurement keys to measurement results.

Raises:
ValueError: If one of the gates is not an X, CNOT, SWAP, TOFFOLI or a measurement.
ValueError: If
- one of the gates is not an X, CNOT, SWAP, TOFFOLI or a measurement.
- A measurement key is used for measurements on different numbers of qubits.
"""

def _run(
Expand All @@ -60,68 +68,40 @@ def _run(

for moment in resolved_circuit:
for op in moment:
gate = op.gate
if gate == ops.X:
values_dict[op.qubits[0]] = 1 - values_dict[op.qubits[0]]

elif (
isinstance(gate, ops.CNotPowGate)
and gate.exponent == 1
and gate.global_shift == 0
):
if values_dict[op.qubits[0]] == 1:
values_dict[op.qubits[1]] = 1 - values_dict[op.qubits[1]]

elif (
isinstance(gate, ops.SwapPowGate)
and gate.exponent == 1
and gate.global_shift == 0
):
hold_qubit = values_dict[op.qubits[1]]
values_dict[op.qubits[1]] = values_dict[op.qubits[0]]
values_dict[op.qubits[0]] = hold_qubit

elif (
isinstance(gate, ops.CCXPowGate)
and gate.exponent == 1
and gate.global_shift == 0
):
if (values_dict[op.qubits[0]] == 1) and (values_dict[op.qubits[1]] == 1):
values_dict[op.qubits[2]] = 1 - values_dict[op.qubits[2]]

elif isinstance(gate, ops.MeasurementGate):
qubits_in_order = op.qubits
# add the new instance of a key to the numpy array in results dictionary
if gate.key in results_dict:
shape = len(qubits_in_order)
current_array = results_dict[gate.key]
new_instance = np.zeros(shape, dtype=np.uint8)
for bits in range(0, len(qubits_in_order)):
new_instance[bits] = values_dict[qubits_in_order[bits]]
results_dict[gate.key] = np.insert(
current_array, len(current_array[0]), new_instance, axis=1
if _is_identity(op):
continue
if op.gate == ops.X:
(q,) = op.qubits
values_dict[q] ^= 1
elif op.gate == ops.CNOT:
c, q = op.qubits
values_dict[q] ^= values_dict[c]
elif op.gate == ops.SWAP:
a, b = op.qubits
values_dict[a], values_dict[b] = values_dict[b], values_dict[a]
elif op.gate == ops.TOFFOLI:
c1, c2, q = op.qubits
values_dict[q] ^= values_dict[c1] & values_dict[c2]
elif protocols.is_measurement(op):
measurement_values = np.array(
[[[values_dict[q] for q in op.qubits]]] * repetitions, dtype=np.uint8
)
key = op.gate.key # type: ignore
if key in results_dict:
if op._num_qubits_() != results_dict[key].shape[-1]:
raise ValueError(
f'Measurement shape {len(measurement_values)} does not match '
f'{results_dict[key].shape[-1]} in {key}.'
)
results_dict[key] = np.concatenate(
(results_dict[key], measurement_values), axis=1
)
else:
# create the array for the results dictionary
new_array_shape = (repetitions, 1, len(qubits_in_order))
new_array = np.zeros(new_array_shape, dtype=np.uint8)
for reps in range(0, repetitions):
for instances in range(1):
for bits in range(0, len(qubits_in_order)):
new_array[reps][instances][bits] = values_dict[
qubits_in_order[bits]
]
results_dict[gate.key] = new_array

elif not (
(isinstance(gate, ops.XPowGate) and gate.exponent == 0)
or (isinstance(gate, ops.CCXPowGate) and gate.exponent == 0)
or (isinstance(gate, ops.SwapPowGate) and gate.exponent == 0)
or (isinstance(gate, ops.CNotPowGate) and gate.exponent == 0)
):
results_dict[key] = measurement_values
else:
raise ValueError(
"Can not simulate gates other than cirq.XGate, "
+ "cirq.CNOT, cirq.SWAP, and cirq.CCNOT"
f'{op} is not one of cirq.X, cirq.CNOT, cirq.SWAP, '
'cirq.CCNOT, or a measurement'
)

return results_dict
Loading