Skip to content

Commit 27bfc59

Browse files
authored
Fix phase in factor (quantumlib#5847)
Fixes quantumlib#5834 Phase in input tensor was being allocated to both output tensors. This PR readjusts to remove the phase from the remainder tensor (a nice thing about this is that if the remainder _is_ nothing but a global phase, then it's `1`). The validation is updated to check `np.all_close` rather than `allclose_up_to_global_phase` (I also ran all simulator unit tests with "validate=True" locally and they all passed with these changes), and a unit test for quantumlib#5834 is added. Also added check that we aren't factoring out the last qubit during simulation and losing the remaining phase. This isn't strictly necessary since the remainder is guaranteed to be `1`, but prevent any surprises if that changes (and may as well skip it anyway for perf sake).
1 parent cc9865c commit 27bfc59

File tree

4 files changed

+39
-4
lines changed

4 files changed

+39
-4
lines changed

cirq/linalg/transformations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -589,12 +589,12 @@ def factor_state_vector(
589589
slices1 = (slice(None),) * n_axes + pivot[n_axes:]
590590
slices2 = pivot[:n_axes] + (slice(None),) * (t1.ndim - n_axes)
591591
extracted = t1[slices1]
592-
extracted = extracted / np.sum(abs(extracted) ** 2) ** 0.5
592+
extracted = extracted / np.linalg.norm(extracted)
593593
remainder = t1[slices2]
594-
remainder = remainder / np.sum(abs(remainder) ** 2) ** 0.5
594+
remainder = remainder / (np.linalg.norm(remainder) * t1[pivot] / abs(t1[pivot]))
595595
if validate:
596596
t2 = state_vector_kronecker_product(extracted, remainder)
597-
if not predicates.allclose_up_to_global_phase(t2, t1, atol=atol):
597+
if not np.allclose(t2, t1, atol=atol):
598598
if not np.isclose(np.linalg.norm(t1), 1):
599599
raise ValueError('Input state must be normalized.')
600600
raise EntangledStateError('The tensor cannot be factored by the requested axes')

cirq/linalg/transformations_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,3 +613,22 @@ def test_default_tolerance():
613613
# Here, we do NOT specify the default tolerance. It is merely to check that the default value
614614
# is reasonable.
615615
cirq.sub_state_vector(final_state_vector, [0])
616+
617+
618+
@pytest.mark.parametrize('state_1', [0, 1])
619+
@pytest.mark.parametrize('state_2', [0, 1])
620+
def test_factor_state_vector(state_1: int, state_2: int):
621+
# Kron two state vectors and apply a phase. Factoring should produce the expected results.
622+
n = 12
623+
for i in range(n):
624+
phase = np.exp(2 * np.pi * 1j * i / n)
625+
a = cirq.to_valid_state_vector(state_1, 1)
626+
b = cirq.to_valid_state_vector(state_2, 1)
627+
c = cirq.linalg.transformations.state_vector_kronecker_product(a, b) * phase
628+
a1, b1 = cirq.linalg.transformations.factor_state_vector(c, [0], validate=True)
629+
c1 = cirq.linalg.transformations.state_vector_kronecker_product(a1, b1)
630+
assert np.allclose(c, c1)
631+
632+
# All phase goes into a1, and b1 is just the dephased state vector
633+
assert np.allclose(a1, a * phase)
634+
assert np.allclose(b1, b)

cirq/sim/simulation_product_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def _act_on_fallback_(
122122
gate_opt, (ops.ResetChannel, ops.MeasurementGate)
123123
):
124124
for q in qubits:
125-
if op_args.allows_factoring:
125+
if op_args.allows_factoring and len(op_args.qubits) > 1:
126126
q_args, op_args = op_args.factor((q,), validate=False)
127127
self._sim_states[q] = q_args
128128

cirq/sim/sparse_simulator_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,3 +1434,19 @@ def test_unseparated_states_str():
14341434
qubits: (cirq.LineQubit(0), cirq.LineQubit(1))
14351435
output vector: 0.707j|00⟩ + 0.707j|10⟩"""
14361436
)
1437+
1438+
1439+
@pytest.mark.parametrize('split', [True, False])
1440+
def test_measurement_preserves_phase(split: bool):
1441+
c1, c2, t = cirq.LineQubit.range(3)
1442+
circuit = cirq.Circuit(
1443+
cirq.H(t),
1444+
cirq.measure(t, key='t'),
1445+
cirq.CZ(c1, c2).with_classical_controls('t'),
1446+
cirq.reset(t),
1447+
)
1448+
simulator = cirq.Simulator(split_untangled_states=split)
1449+
# Run enough times that both options of |110> - |111> are likely measured.
1450+
for _ in range(20):
1451+
result = simulator.simulate(circuit, initial_state=(1, 1, 1), qubit_order=(c1, c2, t))
1452+
assert result.dirac_notation() == '|110⟩'

0 commit comments

Comments
 (0)