Skip to content

Fix call signature on Gate to show that it takes Qids. #5235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cirq-core/cirq/contrib/quil_import/quil.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, Union
from typing import Callable, cast, Dict, Union

import numpy as np
from pyquil.parser import parse
Expand Down Expand Up @@ -258,7 +258,7 @@ def circuit_from_quil(quil: str) -> Circuit:
raise UndefinedQuilGate(f"Quil gate {quil_gate_name} not supported in Cirq.")
cirq_gate_fn = defined_gates[quil_gate_name]
if quil_gate_params:
circuit += cirq_gate_fn(*quil_gate_params)(*line_qubits)
circuit += cast(Callable[..., Gate], cirq_gate_fn)(*quil_gate_params)(*line_qubits)
else:
circuit += cirq_gate_fn(*line_qubits)

Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _decompose_(self, qubits):
)
return decomposed

def on(self, *qubits: 'cirq.Qid') -> cop.ControlledOperation:
def on(self, *qubits: 'cirq.Qid', **kwargs) -> cop.ControlledOperation:
if len(qubits) == 0:
raise ValueError(f"Applied a gate to an empty set of qubits. Gate: {self!r}")
self.validate_args(qubits)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def tensor_product(self, other: 'BaseDensePauliString') -> 'DensePauliString':
def __abs__(self):
return DensePauliString(coefficient=abs(self.coefficient), pauli_mask=self.pauli_mask)

def on(self, *qubits) -> 'cirq.PauliString':
def on(self, *qubits: 'cirq.Qid', **kwargs) -> 'cirq.PauliString':
return self.sparse(qubits)

def sparse(self, qubits: Optional[Sequence['cirq.Qid']] = None) -> 'cirq.PauliString':
Expand Down
3 changes: 2 additions & 1 deletion cirq-core/cirq/ops/pauli_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ def __lt__(self, other):
return NotImplemented
return (other._index - self._index) % 3 == 1

def on(self, *qubits: 'cirq.Qid') -> 'SingleQubitPauliStringGateOperation':
def on(self, *qubits: 'cirq.Qid', **kwargs) -> 'SingleQubitPauliStringGateOperation':
"""Returns an application of this gate to the given qubits.

Args:
*qubits: The collection of qubits to potentially apply the gate to.
**kwargs: Unused.

Raises:
ValueError: If more than one qubit is acted upon.
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/pauli_string_phasor.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def num_qubits(self) -> int:
"""The number of qubits for the gate."""
return len(self.dense_pauli_string)

def on(self, *qubits: 'cirq.Qid') -> 'cirq.PauliStringPhasor':
def on(self, *qubits: 'cirq.Qid', **kwargs) -> 'cirq.PauliStringPhasor':
"""Creates a PauliStringPhasor on the qubits."""
return PauliStringPhasor(
self.dense_pauli_string.on(*qubits),
Expand Down
8 changes: 5 additions & 3 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,13 @@ def validate_args(self, qubits: Sequence['cirq.Qid']) -> None:
"""
_validate_qid_shape(self, qubits)

def on(self, *qubits: Qid) -> 'Operation':
def on(self, *qubits: Qid, **kwargs) -> 'Operation':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change because all user defined classes which override the on method as def on(self, *qubits: Qid) will now not satisfy LSP and would need to be updated?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah agree this is less than happy, let me see if I can work around the other issue (we currently call on in call with the wrong signature and get away with it)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think it wasn't an issue earlier because mypy was not type-checking the method at all, due to no type annotations. But now that's not the case and hence it raises an error.

We can also decide to make __call__ a more restrictive type on the base class, so if user overrides on, they must also override __call__ -- but that would also be a breaking change.

Maybe @maffoo knows some mypy magic to get around the situation without having to deal with breaking changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe @maffoo knows some mypy magic to get around the situation without having to deal with breaking changes.

I got nothin', sorry :-)

"""Returns an application of this gate to the given qubits.

Args:
*qubits: The collection of qubits to potentially apply the gate to.
**kwargs: Extra kwargs, can be used to explicitly labeled qubits like
control or target in controlled gates.
"""
return ops.gate_operation.GateOperation(self, list(qubits))

Expand Down Expand Up @@ -319,8 +321,8 @@ def __pow__(self, power):

return NotImplemented

def __call__(self, *args, **kwargs):
return self.on(*args, **kwargs)
def __call__(self, *qubits: Qid, **kwargs):
return self.on(*qubits, **kwargs)

def with_probability(self, probability: 'cirq.TParamVal') -> 'cirq.Gate':

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _eigen_components(self) -> List[Tuple[float, np.ndarray]]:
(1, np.diag([0, 1])),
]

def on(self, *qubits: 'cirq.Qid') -> 'cirq.Operation':
def on(self, *qubits: 'cirq.Qid', **kwargs) -> 'cirq.Operation':
return BadGateOperation(self, list(qubits))

def controlled(
Expand Down