Skip to content

Commit 237f64f

Browse files
mpharriganCirqBot
authored andcommitted
JSON Protocol (#1880)
Protocol to support serializing cirq objects - `cirq.to_json` - `cirq.read_json` The former looks for a magic method `_json_dict_(self)` which returns a dictionary keyed by strings and whose values are either basic python types or are objects that support the SupportsJSON protocol. They dictionary must have an entry with key `"cirq_type"`, to be used during loading. The latter looks for a magic classmethod `_from_json_dict_(cls, **kwargs)` or falls back to calling the constructor. `_json_dict_` is required for objects to support serialization, so `cirq.protocols.to_json_dict` is provided as convenience to implementers. You must only provide a list of attribute names which will be included in the dictionary. During deserialization, the `"cirq_type"` value is passed to an ordered list of functions. Each function either *resolves* the string name into a python type or the next function is called. A default mapping from class name is provided as a dictionary. There is a test fixture that tests any object exposed at the top `cirq`-level import by roundtripping to JSON. There is a blacklist of classes for which it doesn't make sense to serialize, and another "xfail" list for objects that haven't been upgraded to support serialization yet, but there's no reason not for them to. This pull request includes support for Qubits, Gates, Operations, Moments, Circuits, PauliString, and a couple of other things. I recommend reviewing and merging this PR which extends support for these high-priority objects and then working on the rest over time. Any new objects should implement the serialization protocol unless there's good reason not to. This pull request makes some small changes to objects, mostly adding `@property`s to provide a consistent expectation that arguments passed to the constructor are available under the same name as attributes/properties.
1 parent 6f09848 commit 237f64f

27 files changed

+913
-12
lines changed

cirq/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@
335335
qasm,
336336
QasmArgs,
337337
qid_shape,
338+
read_json,
338339
resolve_parameters,
339340
SupportsApplyChannel,
340341
SupportsConsistentApplyUnitary,
@@ -353,6 +354,8 @@
353354
SupportsQasmWithArgsAndQubits,
354355
SupportsTraceDistanceBound,
355356
SupportsUnitary,
357+
to_json,
358+
to_json_dict,
356359
trace_distance_bound,
357360
unitary,
358361
validate_mixture,

cirq/circuits/circuit.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,6 +1578,13 @@ def save_qasm(self,
15781578
"""
15791579
self._to_qasm_output(header, precision, qubit_order).save(file_path)
15801580

1581+
@property
1582+
def moments(self):
1583+
return self._moments
1584+
1585+
def _json_dict_(self):
1586+
return protocols.to_json_dict(self, ['moments', 'device'])
1587+
15811588

15821589
def _resolve_operations(
15831590
operations: Iterable[ops.Operation],

cirq/circuits/circuit_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3180,3 +3180,20 @@ def test_deprecated_to_unitary_matrix():
31803180
def test_deprecated_apply_unitary_effect_to_state():
31813181
np.testing.assert_allclose(cirq.Circuit().apply_unitary_effect_to_state(),
31823182
cirq.Circuit().final_wavefunction())
3183+
3184+
3185+
def test_moments_property():
3186+
q = cirq.NamedQubit('q')
3187+
c = cirq.Circuit.from_ops(cirq.X(q), cirq.Y(q))
3188+
assert c.moments[0] == cirq.Moment([cirq.X(q)])
3189+
assert c.moments[1] == cirq.Moment([cirq.Y(q)])
3190+
3191+
3192+
def test_json_dict():
3193+
q0, q1 = cirq.LineQubit.range(2)
3194+
c = cirq.Circuit.from_ops(cirq.CNOT(q0, q1))
3195+
assert c._json_dict_() == {
3196+
'cirq_type': 'Circuit',
3197+
'moments': [cirq.Moment([cirq.CNOT(q0, q1)])],
3198+
'device': cirq.UNCONSTRAINED_DEVICE,
3199+
}

cirq/devices/grid_qubit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from typing import Dict, List, Tuple
1717

18-
from cirq import ops
18+
from cirq import ops, protocols
1919

2020

2121
class GridQubit(ops.Qid):
@@ -139,6 +139,9 @@ def __repr__(self):
139139
def __str__(self):
140140
return '({}, {})'.format(self.row, self.col)
141141

142+
def _json_dict_(self):
143+
return protocols.to_json_dict(self, ['row', 'col'])
144+
142145
def __add__(self, other: Tuple[int, int]) -> 'GridQubit':
143146
if not (isinstance(other, tuple) and len(other) == 2 and
144147
all(isinstance(x, int) for x in other)):

cirq/devices/grid_qubit_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,13 @@ def test_from_proto_bad_dict():
167167
cirq.GridQubit.from_proto_dict({})
168168
with pytest.raises(ValueError):
169169
cirq.GridQubit.from_proto_dict({'nothing': 1})
170+
171+
172+
def test_to_json():
173+
q = cirq.GridQubit(5, 6)
174+
d = q._json_dict_()
175+
assert d == {
176+
'cirq_type': 'GridQubit',
177+
'row': 5,
178+
'col': 6,
179+
}

cirq/devices/line_qubit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import functools
1616
from typing import List
1717

18-
from cirq import ops
18+
from cirq import ops, protocols
1919

2020

2121
@functools.total_ordering
@@ -85,3 +85,6 @@ def __rsub__(self, other: int) -> 'LineQubit':
8585

8686
def __neg__(self) -> 'LineQubit':
8787
return LineQubit(-self.x)
88+
89+
def _json_dict_(self):
90+
return protocols.to_json_dict(self, ['x'])

cirq/devices/line_qubit_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,10 @@ def test_addition_subtraction_type_error():
107107

108108
def test_neg():
109109
assert -cirq.LineQubit(1) == cirq.LineQubit(-1)
110+
111+
112+
def test_json_dict():
113+
assert cirq.LineQubit(5)._json_dict_() == {
114+
'cirq_type': 'LineQubit',
115+
'x': 5,
116+
}

cirq/devices/unconstrained_device.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from cirq import value
15+
from cirq import value, protocols
1616
from cirq.devices import device
1717

1818

19+
@value.value_equality()
1920
class _UnconstrainedDevice(device.Device):
2021
"""A device that allows everything, infinitely fast."""
2122

@@ -37,5 +38,11 @@ def validate_schedule(self, schedule):
3738
def __repr__(self):
3839
return 'cirq.UNCONSTRAINED_DEVICE'
3940

41+
def _value_equality_values_(self):
42+
return ()
43+
44+
def _json_dict_(self):
45+
return protocols.to_json_dict(self, [])
46+
4047

4148
UNCONSTRAINED_DEVICE: device.Device = _UnconstrainedDevice()

cirq/ops/common_gates.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,20 @@ def __repr__(self):
552552
def _value_equality_values_(self):
553553
return self.num_qubits(), self.key, self.invert_mask
554554

555+
def _json_dict_(self):
556+
return {
557+
'cirq_type': self.__class__.__name__,
558+
'num_qubits': self.num_qubits(),
559+
'key': self.key,
560+
'invert_mask': self.invert_mask
561+
}
562+
563+
@classmethod
564+
def _from_json_dict_(cls, num_qubits, key, invert_mask, **kwargs):
565+
return cls(num_qubits=num_qubits,
566+
key=key,
567+
invert_mask=tuple(invert_mask))
568+
555569

556570
def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
557571
return ','.join(str(q) for q in qubits)
@@ -697,6 +711,12 @@ def _qasm_(self, args: protocols.QasmArgs,
697711
def _value_equality_values_(self):
698712
return self.num_qubits(),
699713

714+
def _json_dict_(self):
715+
return {
716+
'cirq_type': self.__class__.__name__,
717+
'num_qubits': self.num_qubits(),
718+
}
719+
700720

701721
class HPowGate(eigen_gate.EigenGate, gate_features.SingleQubitGate):
702722
"""A Gate that performs a rotation around the X+Z axis of the Bloch sphere.

cirq/ops/eigen_gate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ def __init__(
111111
def exponent(self) -> value.TParamVal:
112112
return self._exponent
113113

114+
@property
115+
def global_shift(self) -> float:
116+
return self._global_shift
117+
114118
# virtual method
115119
def _with_exponent(self: TSelf, exponent: value.TParamVal) -> TSelf:
116120
"""Return the same kind of gate, but with a different exponent.
@@ -319,6 +323,9 @@ def _resolve_parameters_(self: TSelf, param_resolver) -> TSelf:
319323
return self._with_exponent(
320324
exponent=param_resolver.value_of(self._exponent))
321325

326+
def _json_dict_(self):
327+
return protocols.to_json_dict(self, ['exponent', 'global_shift'])
328+
322329

323330
def _lcm(vals: Iterable[int]) -> int:
324331
t = 1

cirq/ops/fsim_gate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def __repr__(self):
154154
return 'cirq.FSimGate(theta={}, phi={})'.format(proper_repr(self.theta),
155155
proper_repr(self.phi))
156156

157+
def _json_dict_(self):
158+
return protocols.to_json_dict(self, ['theta', 'phi'])
159+
157160

158161
def _format_rads(args: 'cirq.CircuitDiagramInfoArgs', radians: float) -> str:
159162
if cirq.is_parameterized(radians):

cirq/ops/fsim_gate_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,11 @@ def test_fsim_iswap_cphase(theta, phi):
234234
iswap_cphase = cirq.Circuit.from_ops((iswap.on(q0, q1), cphase.on(q0, q1)))
235235
fsim = cirq.FSimGate(theta=theta, phi=phi)
236236
assert np.allclose(cirq.unitary(iswap_cphase), cirq.unitary(fsim))
237+
238+
239+
def test_fsim_json_dict():
240+
assert cirq.FSimGate(theta=0.123, phi=0.456)._json_dict_() == {
241+
'cirq_type': 'FSimGate',
242+
'theta': 0.123,
243+
'phi': 0.456,
244+
}

cirq/ops/gate_operation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ def __str__(self):
7171
return '{}({})'.format(self.gate,
7272
', '.join(str(e) for e in self.qubits))
7373

74+
def _json_dict_(self):
75+
return protocols.to_json_dict(self, ['gate', 'qubits'])
76+
7477
def _group_interchangeable_qubits(self) -> Tuple[
7578
Union[raw_types.Qid,
7679
Tuple[int, FrozenSet[raw_types.Qid]]],

cirq/ops/global_phase_op.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import numpy as np
1717

18-
from cirq import value
18+
from cirq import value, protocols
1919
from cirq.ops import raw_types
2020

2121

@@ -62,3 +62,6 @@ def __str__(self):
6262

6363
def __repr__(self):
6464
return 'cirq.GlobalPhaseOperation({!r})'.format(self.coefficient)
65+
66+
def _json_dict_(self):
67+
return protocols.to_json_dict(self, ['coefficient'])

cirq/ops/global_phase_op_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,10 @@ def test_diagram():
182182
]), """
183183
global phase: -0.5π
184184
""")
185+
186+
187+
def test_global_phase_op_json_dict():
188+
assert cirq.GlobalPhaseOperation(-1j)._json_dict_() == {
189+
'cirq_type': 'GlobalPhaseOperation',
190+
'coefficient': -1j,
191+
}

cirq/ops/moment.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Any, Callable, Iterable, Sequence, TypeVar, Union
1818

19+
from cirq import protocols
1920
from cirq.protocols import approx_eq
2021
from cirq.ops import raw_types
2122

@@ -147,6 +148,9 @@ def transform_qubits(self: TSelf_Moment,
147148
return self.__class__(op.transform_qubits(func)
148149
for op in self.operations)
149150

151+
def _json_dict_(self):
152+
return protocols.to_json_dict(self, ['operations'])
153+
150154

151155
def _list_repr_with_indented_item_lines(items: Sequence[Any]) -> str:
152156
block = '\n'.join([repr(op) + ',' for op in items])

cirq/ops/moment_test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def test_qubits():
203203
a = cirq.NamedQubit('a')
204204
b = cirq.NamedQubit('b')
205205

206-
assert Moment([cirq.X(a), cirq.X(b)]).qubits == {a , b}
206+
assert Moment([cirq.X(a), cirq.X(b)]).qubits == {a, b}
207207
assert Moment([cirq.X(a)]).qubits == {a}
208208
assert Moment([cirq.CZ(a, b)]).qubits == {a, b}
209209

@@ -225,3 +225,13 @@ def test_bool():
225225
assert not Moment()
226226
a = cirq.NamedQubit('a')
227227
assert Moment([cirq.X(a)])
228+
229+
230+
def test_json_dict():
231+
a = cirq.NamedQubit('a')
232+
b = cirq.NamedQubit('b')
233+
mom = Moment([cirq.CZ(a, b)])
234+
assert mom._json_dict_() == {
235+
'cirq_type': 'Moment',
236+
'operations': (cirq.CZ(a, b),)
237+
}

cirq/ops/named_qubit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
from cirq import protocols
1515
from cirq.ops import raw_types
1616

1717

@@ -60,6 +60,9 @@ def range(*args, prefix: str):
6060
"""
6161
return [NamedQubit(prefix + str(i)) for i in range(*args)]
6262

63+
def _json_dict_(self):
64+
return protocols.to_json_dict(self, ['name'])
65+
6366

6467
def _pad_digits(text: str) -> str:
6568
"""A str method with hacks to support better lexicographic ordering.

cirq/ops/pauli_gates.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ def __pow__(self: '_PauliX',
9696
exponent: value.TParamVal) -> common_gates.XPowGate:
9797
return common_gates.XPowGate(exponent=exponent)
9898

99+
@classmethod
100+
def _from_json_dict_(cls, exponent, global_shift, **kwargs):
101+
assert global_shift == 0
102+
return cls(exponent=exponent)
103+
99104

100105
class _PauliY(Pauli, common_gates.YPowGate):
101106

@@ -107,6 +112,11 @@ def __pow__(self: '_PauliY',
107112
exponent: value.TParamVal) -> common_gates.YPowGate:
108113
return common_gates.YPowGate(exponent=exponent)
109114

115+
@classmethod
116+
def _from_json_dict_(cls, exponent, global_shift, **kwargs):
117+
assert global_shift == 0
118+
return cls(exponent=exponent)
119+
110120

111121
class _PauliZ(Pauli, common_gates.ZPowGate):
112122

@@ -118,6 +128,11 @@ def __pow__(self: '_PauliZ',
118128
exponent: value.TParamVal) -> common_gates.ZPowGate:
119129
return common_gates.ZPowGate(exponent=exponent)
120130

131+
@classmethod
132+
def _from_json_dict_(cls, exponent, global_shift, **kwargs):
133+
assert global_shift == 0
134+
return cls(exponent=exponent)
135+
121136

122137
# The Pauli X gate.
123138
#
@@ -127,7 +142,6 @@ def __pow__(self: '_PauliZ',
127142
# [1, 0]]
128143
X = _PauliX()
129144

130-
131145
# The Pauli Y gate.
132146
#
133147
# Matrix:
@@ -136,7 +150,6 @@ def __pow__(self: '_PauliZ',
136150
# [i, 0]]
137151
Y = _PauliY()
138152

139-
140153
# The Pauli Z gate.
141154
#
142155
# Matrix:
@@ -145,5 +158,4 @@ def __pow__(self: '_PauliZ',
145158
# [0, -1]]
146159
Z = _PauliZ()
147160

148-
149161
Pauli._XYZ = (X, Y, Z)

0 commit comments

Comments
 (0)