Skip to content

Commit 2ef1909

Browse files
Add readout plotting tools (#6425)
1 parent ee56c59 commit 2ef1909

File tree

2 files changed

+125
-1
lines changed

2 files changed

+125
-1
lines changed

cirq-core/cirq/experiments/single_qubit_readout_calibration.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818

1919
import sympy
2020
import numpy as np
21+
import matplotlib.pyplot as plt
22+
import cirq.vis.heatmap as cirq_heatmap
23+
import cirq.vis.histogram as cirq_histogram
24+
from cirq.devices import grid_qubit
2125
from cirq import circuits, ops, study
2226

2327
if TYPE_CHECKING:
@@ -51,6 +55,124 @@ def _json_dict_(self) -> Dict[str, Any]:
5155
'timestamp': self.timestamp,
5256
}
5357

58+
def plot_heatmap(
59+
self,
60+
axs: Optional[tuple[plt.Axes, plt.Axes]] = None,
61+
annotation_format: str = '0.1%',
62+
**plot_kwargs: Any,
63+
) -> tuple[plt.Axes, plt.Axes]:
64+
"""Plot a heatmap of the readout errors. If qubits are not cirq.GridQubits, throws an error.
65+
66+
Args:
67+
axs: a tuple of the plt.Axes to plot on. If not given, a new figure is created,
68+
plotted on, and shown.
69+
annotation_format: The format string for the numbers in the heatmap.
70+
**plot_kwargs: Arguments to be passed to 'cirq.Heatmap.plot()'.
71+
Returns:
72+
The two plt.Axes containing the plot.
73+
74+
Raises:
75+
ValueError: axs does not contain two plt.Axes
76+
TypeError: qubits are not cirq.GridQubits
77+
"""
78+
79+
if axs is None:
80+
_, axs = plt.subplots(1, 2, dpi=200, facecolor='white', figsize=(12, 4))
81+
82+
else:
83+
if (
84+
not isinstance(axs, (tuple, list, np.ndarray))
85+
or len(axs) != 2
86+
or type(axs[0]) != plt.Axes
87+
or type(axs[1]) != plt.Axes
88+
): # pragma: no cover
89+
raise ValueError('axs should be a length-2 tuple of plt.Axes') # pragma: no cover
90+
for ax, title, data in zip(
91+
axs,
92+
['$|0\\rangle$ errors', '$|1\\rangle$ errors'],
93+
[self.zero_state_errors, self.one_state_errors],
94+
):
95+
data_with_grid_qubit_keys = {}
96+
for qubit in data:
97+
if type(qubit) != grid_qubit.GridQubit:
98+
raise TypeError(f'{qubit} must be of type cirq.GridQubit') # pragma: no cover
99+
data_with_grid_qubit_keys[qubit] = data[qubit] # just for typecheck
100+
_, _ = cirq_heatmap.Heatmap(data_with_grid_qubit_keys).plot(
101+
ax, annotation_format=annotation_format, title=title, **plot_kwargs
102+
)
103+
return axs[0], axs[1]
104+
105+
def plot_integrated_histogram(
106+
self,
107+
ax: Optional[plt.Axes] = None,
108+
cdf_on_x: bool = False,
109+
axis_label: str = 'Readout error rate',
110+
semilog: bool = True,
111+
median_line: bool = True,
112+
median_label: Optional[str] = 'median',
113+
mean_line: bool = False,
114+
mean_label: Optional[str] = 'mean',
115+
show_zero: bool = False,
116+
title: Optional[str] = None,
117+
**kwargs,
118+
) -> plt.Axes:
119+
"""Plot the readout errors using cirq.integrated_histogram().
120+
121+
Args:
122+
ax: The axis to plot on. If None, we generate one.
123+
cdf_on_x: If True, flip the axes compared the above example.
124+
axis_label: Label for x axis (y-axis if cdf_on_x is True).
125+
semilog: If True, force the x-axis to be logarithmic.
126+
median_line: If True, draw a vertical line on the median value.
127+
median_label: If drawing median line, optional label for it.
128+
mean_line: If True, draw a vertical line on the mean value.
129+
mean_label: If drawing mean line, optional label for it.
130+
title: Title of the plot. If None, we assign "N={len(data)}".
131+
show_zero: If True, moves the step plot up by one unit by prepending 0
132+
to the data.
133+
**kwargs: Kwargs to forward to `ax.step()`. Some examples are
134+
color: Color of the line.
135+
linestyle: Linestyle to use for the plot.
136+
lw: linewidth for integrated histogram.
137+
ms: marker size for a histogram trace.
138+
Returns:
139+
The axis that was plotted on.
140+
"""
141+
142+
ax = cirq_histogram.integrated_histogram(
143+
data=self.zero_state_errors,
144+
ax=ax,
145+
cdf_on_x=cdf_on_x,
146+
semilog=semilog,
147+
median_line=median_line,
148+
median_label=median_label,
149+
mean_line=mean_line,
150+
mean_label=mean_label,
151+
show_zero=show_zero,
152+
color='C0',
153+
label='$|0\\rangle$ errors',
154+
**kwargs,
155+
)
156+
ax = cirq_histogram.integrated_histogram(
157+
data=self.one_state_errors,
158+
ax=ax,
159+
cdf_on_x=cdf_on_x,
160+
axis_label=axis_label,
161+
semilog=semilog,
162+
median_line=median_line,
163+
median_label=median_label,
164+
mean_line=mean_line,
165+
mean_label=mean_label,
166+
show_zero=show_zero,
167+
title=title,
168+
color='C1',
169+
label='$|1\\rangle$ errors',
170+
**kwargs,
171+
)
172+
ax.legend(loc='best')
173+
ax.set_ylabel('Percentile')
174+
return ax
175+
54176
@classmethod
55177
def _from_json_dict_(
56178
cls, zero_state_errors, one_state_errors, repetitions, timestamp, **kwargs

cirq-core/cirq/experiments/single_qubit_readout_calibration_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_estimate_single_qubit_readout_errors_with_noise():
8787

8888

8989
def test_estimate_parallel_readout_errors_no_noise():
90-
qubits = cirq.LineQubit.range(10)
90+
qubits = [cirq.GridQubit(i, 0) for i in range(10)]
9191
sampler = cirq.Simulator()
9292
repetitions = 1000
9393
result = cirq.estimate_parallel_single_qubit_readout_errors(
@@ -97,6 +97,8 @@ def test_estimate_parallel_readout_errors_no_noise():
9797
assert result.one_state_errors == {q: 0 for q in qubits}
9898
assert result.repetitions == repetitions
9999
assert isinstance(result.timestamp, float)
100+
_ = result.plot_integrated_histogram()
101+
_, _ = result.plot_heatmap()
100102

101103

102104
def test_estimate_parallel_readout_errors_all_zeros():

0 commit comments

Comments
 (0)