|
18 | 18 |
|
19 | 19 | import sympy
|
20 | 20 | import numpy as np
|
| 21 | +import matplotlib.pyplot as plt |
| 22 | +import cirq.vis.heatmap as cirq_heatmap |
| 23 | +import cirq.vis.histogram as cirq_histogram |
| 24 | +from cirq.devices import grid_qubit |
21 | 25 | from cirq import circuits, ops, study
|
22 | 26 |
|
23 | 27 | if TYPE_CHECKING:
|
@@ -51,6 +55,124 @@ def _json_dict_(self) -> Dict[str, Any]:
|
51 | 55 | 'timestamp': self.timestamp,
|
52 | 56 | }
|
53 | 57 |
|
| 58 | + def plot_heatmap( |
| 59 | + self, |
| 60 | + axs: Optional[tuple[plt.Axes, plt.Axes]] = None, |
| 61 | + annotation_format: str = '0.1%', |
| 62 | + **plot_kwargs: Any, |
| 63 | + ) -> tuple[plt.Axes, plt.Axes]: |
| 64 | + """Plot a heatmap of the readout errors. If qubits are not cirq.GridQubits, throws an error. |
| 65 | +
|
| 66 | + Args: |
| 67 | + axs: a tuple of the plt.Axes to plot on. If not given, a new figure is created, |
| 68 | + plotted on, and shown. |
| 69 | + annotation_format: The format string for the numbers in the heatmap. |
| 70 | + **plot_kwargs: Arguments to be passed to 'cirq.Heatmap.plot()'. |
| 71 | + Returns: |
| 72 | + The two plt.Axes containing the plot. |
| 73 | +
|
| 74 | + Raises: |
| 75 | + ValueError: axs does not contain two plt.Axes |
| 76 | + TypeError: qubits are not cirq.GridQubits |
| 77 | + """ |
| 78 | + |
| 79 | + if axs is None: |
| 80 | + _, axs = plt.subplots(1, 2, dpi=200, facecolor='white', figsize=(12, 4)) |
| 81 | + |
| 82 | + else: |
| 83 | + if ( |
| 84 | + not isinstance(axs, (tuple, list, np.ndarray)) |
| 85 | + or len(axs) != 2 |
| 86 | + or type(axs[0]) != plt.Axes |
| 87 | + or type(axs[1]) != plt.Axes |
| 88 | + ): # pragma: no cover |
| 89 | + raise ValueError('axs should be a length-2 tuple of plt.Axes') # pragma: no cover |
| 90 | + for ax, title, data in zip( |
| 91 | + axs, |
| 92 | + ['$|0\\rangle$ errors', '$|1\\rangle$ errors'], |
| 93 | + [self.zero_state_errors, self.one_state_errors], |
| 94 | + ): |
| 95 | + data_with_grid_qubit_keys = {} |
| 96 | + for qubit in data: |
| 97 | + if type(qubit) != grid_qubit.GridQubit: |
| 98 | + raise TypeError(f'{qubit} must be of type cirq.GridQubit') # pragma: no cover |
| 99 | + data_with_grid_qubit_keys[qubit] = data[qubit] # just for typecheck |
| 100 | + _, _ = cirq_heatmap.Heatmap(data_with_grid_qubit_keys).plot( |
| 101 | + ax, annotation_format=annotation_format, title=title, **plot_kwargs |
| 102 | + ) |
| 103 | + return axs[0], axs[1] |
| 104 | + |
| 105 | + def plot_integrated_histogram( |
| 106 | + self, |
| 107 | + ax: Optional[plt.Axes] = None, |
| 108 | + cdf_on_x: bool = False, |
| 109 | + axis_label: str = 'Readout error rate', |
| 110 | + semilog: bool = True, |
| 111 | + median_line: bool = True, |
| 112 | + median_label: Optional[str] = 'median', |
| 113 | + mean_line: bool = False, |
| 114 | + mean_label: Optional[str] = 'mean', |
| 115 | + show_zero: bool = False, |
| 116 | + title: Optional[str] = None, |
| 117 | + **kwargs, |
| 118 | + ) -> plt.Axes: |
| 119 | + """Plot the readout errors using cirq.integrated_histogram(). |
| 120 | +
|
| 121 | + Args: |
| 122 | + ax: The axis to plot on. If None, we generate one. |
| 123 | + cdf_on_x: If True, flip the axes compared the above example. |
| 124 | + axis_label: Label for x axis (y-axis if cdf_on_x is True). |
| 125 | + semilog: If True, force the x-axis to be logarithmic. |
| 126 | + median_line: If True, draw a vertical line on the median value. |
| 127 | + median_label: If drawing median line, optional label for it. |
| 128 | + mean_line: If True, draw a vertical line on the mean value. |
| 129 | + mean_label: If drawing mean line, optional label for it. |
| 130 | + title: Title of the plot. If None, we assign "N={len(data)}". |
| 131 | + show_zero: If True, moves the step plot up by one unit by prepending 0 |
| 132 | + to the data. |
| 133 | + **kwargs: Kwargs to forward to `ax.step()`. Some examples are |
| 134 | + color: Color of the line. |
| 135 | + linestyle: Linestyle to use for the plot. |
| 136 | + lw: linewidth for integrated histogram. |
| 137 | + ms: marker size for a histogram trace. |
| 138 | + Returns: |
| 139 | + The axis that was plotted on. |
| 140 | + """ |
| 141 | + |
| 142 | + ax = cirq_histogram.integrated_histogram( |
| 143 | + data=self.zero_state_errors, |
| 144 | + ax=ax, |
| 145 | + cdf_on_x=cdf_on_x, |
| 146 | + semilog=semilog, |
| 147 | + median_line=median_line, |
| 148 | + median_label=median_label, |
| 149 | + mean_line=mean_line, |
| 150 | + mean_label=mean_label, |
| 151 | + show_zero=show_zero, |
| 152 | + color='C0', |
| 153 | + label='$|0\\rangle$ errors', |
| 154 | + **kwargs, |
| 155 | + ) |
| 156 | + ax = cirq_histogram.integrated_histogram( |
| 157 | + data=self.one_state_errors, |
| 158 | + ax=ax, |
| 159 | + cdf_on_x=cdf_on_x, |
| 160 | + axis_label=axis_label, |
| 161 | + semilog=semilog, |
| 162 | + median_line=median_line, |
| 163 | + median_label=median_label, |
| 164 | + mean_line=mean_line, |
| 165 | + mean_label=mean_label, |
| 166 | + show_zero=show_zero, |
| 167 | + title=title, |
| 168 | + color='C1', |
| 169 | + label='$|1\\rangle$ errors', |
| 170 | + **kwargs, |
| 171 | + ) |
| 172 | + ax.legend(loc='best') |
| 173 | + ax.set_ylabel('Percentile') |
| 174 | + return ax |
| 175 | + |
54 | 176 | @classmethod
|
55 | 177 | def _from_json_dict_(
|
56 | 178 | cls, zero_state_errors, one_state_errors, repetitions, timestamp, **kwargs
|
|
0 commit comments