Skip to content

Create a new condition that allows easy control by bitmasks and Add a new classical Update the notebook for 'Classical control' to reflect new features" #7166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@
canonicalize_half_turns as canonicalize_half_turns,
chosen_angle_to_canonical_half_turns as chosen_angle_to_canonical_half_turns,
chosen_angle_to_half_turns as chosen_angle_to_half_turns,
BitMaskKeyCondition as BitMaskKeyCondition,
ClassicalDataDictionaryStore as ClassicalDataDictionaryStore,
ClassicalDataStore as ClassicalDataStore,
ClassicalDataStoreReader as ClassicalDataStoreReader,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def _symmetricalqidpair(qids):
'AnyUnitaryGateFamily': cirq.AnyUnitaryGateFamily,
'AsymmetricDepolarizingChannel': cirq.AsymmetricDepolarizingChannel,
'BitFlipChannel': cirq.BitFlipChannel,
'BitMaskKeyCondition': cirq.BitMaskKeyCondition,
'BitstringAccumulator': cirq.work.BitstringAccumulator,
'BooleanHamiltonianGate': cirq.BooleanHamiltonianGate,
'CCNotPowGate': cirq.CCNotPowGate,
Expand Down
7 changes: 7 additions & 0 deletions cirq-core/cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Union,
)

import attrs
import numpy as np
import pandas as pd
import sympy
Expand Down Expand Up @@ -182,6 +183,12 @@ def dataclass_json_dict(obj: Any) -> Dict[str, Any]:
return obj_to_dict_helper(obj, attribute_names)


def attrs_json_dict(obj: Any) -> Dict[str, Any]:
"""Return a dictionary suitable for `_json_dict_` from an attrs dataclass."""
attribute_names = [f.name for f in attrs.fields(type(obj))]
return obj_to_dict_helper(obj, attribute_names)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a unit test for this function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done



def _json_dict_with_cirq_type(obj: Any):
base_dict = obj._json_dict_()
if 'cirq_type' in base_dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
[
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "a",
"path": []
},
"index": 59,
"target_value": 0,
"equal_target": false,
"bitmask": null
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "b",
"path": []
},
"index": 58,
"target_value": 3,
"equal_target": false,
"bitmask": null
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "c",
"path": []
},
"index": 57,
"target_value": 0,
"equal_target": false,
"bitmask": 13
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "d",
"path": []
},
"index": 56,
"target_value": 12,
"equal_target": false,
"bitmask": 13
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "d",
"path": []
},
"index": 55,
"target_value": 12,
"equal_target": true,
"bitmask": 13
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "e",
"path": []
},
"index": 54,
"target_value": 11,
"equal_target": true,
"bitmask": 11
},
{
"cirq_type": "BitMaskKeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "e",
"path": []
},
"index": 53,
"target_value": 9,
"equal_target": false,
"bitmask": 9
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='a'), index=59, target_value=0, equal_target=False, bitmask=None),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='b'), index=58, target_value=3, equal_target=False, bitmask=None),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='c'), index=57, target_value=0, equal_target=False, bitmask=13),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='d'), index=56, target_value=12, equal_target=False, bitmask=13),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='d'), index=55, target_value=12, equal_target=True, bitmask=13),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='e'), index=54, target_value=11, equal_target=True, bitmask=11),
cirq.BitMaskKeyCondition(key=cirq.MeasurementKey(name='e'), index=53, target_value=9, equal_target=False, bitmask=9)]
1 change: 1 addition & 0 deletions cirq-core/cirq/value/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Condition as Condition,
KeyCondition as KeyCondition,
SympyCondition as SympyCondition,
BitMaskKeyCondition as BitMaskKeyCondition,
)

from cirq.value.digits import (
Expand Down
102 changes: 102 additions & 0 deletions cirq-core/cirq/value/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dataclasses
from typing import Any, Dict, FrozenSet, Mapping, Optional, Tuple, TYPE_CHECKING

import attrs
import sympy

from cirq._compat import proper_repr
Expand Down Expand Up @@ -135,6 +136,107 @@ def _qasm_(self, args: 'cirq.QasmArgs', **kwargs) -> Optional[str]:
return f'{key}==1'


@attrs.frozen
class BitMaskKeyCondition(Condition):
"""A multiqubit classical control condition with a bitmask.

The control is based on a single measurement key and allows comparing equality or inequality
after taking the bitwise and with a bitmask.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have a short blurb about the bit order of the measurement (or a reference to the measurement code that explains)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


Examples:
- BitMaskKeycondition('a') -> a != 0
- BitMaskKeyCondition('a', bitmask=13) -> (a & 13) != 0
- BitMaskKeyCondition('a', bitmask=13, target_value=9) -> (a & 13) != 9
- BitMaskKeyCondition('a', bitmask=13, target_value=9, equal_target=True) -> (a & 13) == 9
- BitMaskKeyCondition.create_equal_mask('a', 13) -> (a & 13) == 13
- BitMaskKeyCondition.create_not_equal_mask('a', 13) -> (a & 13) != 13

Attributes:
- key: Measurement key.
- index: integer index (same as KeyCondition.index).
- target_value: The value we compare with.
- equal_target: Whether to comapre with == or !=.
- bitmask: Optional bitmask to apply before doing the comparison.
"""

key: 'cirq.MeasurementKey' = attrs.field(
converter=lambda x: (
x
if isinstance(x, measurement_key.MeasurementKey)
else measurement_key.MeasurementKey(x)
)
)
index: int = -1
target_value: int = 0
equal_target: bool = False
bitmask: Optional[int] = None

@property
def keys(self):
return (self.key,)

@staticmethod
def create_equal_mask(
key: 'cirq.MeasurementKey', bitmask: int, *, index: int = -1
) -> 'BitMaskKeyCondition':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docstring.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return BitMaskKeyCondition(
key, index, target_value=bitmask, equal_target=True, bitmask=bitmask
)

@staticmethod
def create_not_equal_mask(
key: 'cirq.MeasurementKey', bitmask: int, *, index: int = -1
) -> 'BitMaskKeyCondition':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docstring.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return BitMaskKeyCondition(
key, index, target_value=bitmask, equal_target=False, bitmask=bitmask
)

def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'):
return BitMaskKeyCondition(replacement) if self.key == current else self

def __str__(self):
s = str(self.key) if self.index == -1 else f'{self.key}[{self.index}]'
if self.bitmask is not None:
s = f'{s} & {self.bitmask}'
if self.equal_target:
if self.bitmask is not None:
s = f'({s})'
s = f'{s} == {self.target_value}'
elif self.target_value != 0:
if self.bitmask is not None:
s = f'({s})'
s = f'{s} != {self.target_value}'
return s

def __repr__(self):
values = attrs.asdict(self)
parameters = ', '.join(f'{f.name}={repr(values[f.name])}' for f in attrs.fields(type(self)))
return f'cirq.BitMaskKeyCondition({parameters})'

def resolve(self, classical_data: 'cirq.ClassicalDataStoreReader') -> bool:
if self.key not in classical_data.keys():
raise ValueError(f'Measurement key {self.key} missing when testing classical control')
value = classical_data.get_int(self.key, self.index)
if self.bitmask is not None:
value &= self.bitmask
if self.equal_target:
return value == self.target_value
return value != self.target_value

def _json_dict_(self):
return json_serialization.attrs_json_dict(self)

@classmethod
def _from_json_dict_(cls, key, **kwargs):
parameter_names = [f.name for f in attrs.fields(cls)[1:]]
parameters = {k: kwargs[k] for k in parameter_names if k in kwargs}
return cls(key=key, **parameters)

@property
def qasm(self):
raise NotImplementedError() # pragma: no cover


@dataclasses.dataclass(frozen=True)
class SympyCondition(Condition):
"""A classical control condition based on a sympy expression.
Expand Down
Loading