diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 1a1b8c257e0..be91bf87688 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -589,12 +589,12 @@ def factor_state_vector( slices1 = (slice(None),) * n_axes + pivot[n_axes:] slices2 = pivot[:n_axes] + (slice(None),) * (t1.ndim - n_axes) extracted = t1[slices1] - extracted = extracted / np.sum(abs(extracted) ** 2) ** 0.5 + extracted = extracted / np.linalg.norm(extracted) remainder = t1[slices2] - remainder = remainder / np.sum(abs(remainder) ** 2) ** 0.5 + remainder = remainder / (np.linalg.norm(remainder) * t1[pivot] / abs(t1[pivot])) if validate: t2 = state_vector_kronecker_product(extracted, remainder) - if not predicates.allclose_up_to_global_phase(t2, t1, atol=atol): + if not np.allclose(t2, t1, atol=atol): if not np.isclose(np.linalg.norm(t1), 1): raise ValueError('Input state must be normalized.') raise EntangledStateError('The tensor cannot be factored by the requested axes') diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index 6e13fb64d62..b92ff5faa13 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -613,3 +613,22 @@ def test_default_tolerance(): # Here, we do NOT specify the default tolerance. It is merely to check that the default value # is reasonable. cirq.sub_state_vector(final_state_vector, [0]) + + +@pytest.mark.parametrize('state_1', [0, 1]) +@pytest.mark.parametrize('state_2', [0, 1]) +def test_factor_state_vector(state_1: int, state_2: int): + # Kron two state vectors and apply a phase. Factoring should produce the expected results. + n = 12 + for i in range(n): + phase = np.exp(2 * np.pi * 1j * i / n) + a = cirq.to_valid_state_vector(state_1, 1) + b = cirq.to_valid_state_vector(state_2, 1) + c = cirq.linalg.transformations.state_vector_kronecker_product(a, b) * phase + a1, b1 = cirq.linalg.transformations.factor_state_vector(c, [0], validate=True) + c1 = cirq.linalg.transformations.state_vector_kronecker_product(a1, b1) + assert np.allclose(c, c1) + + # All phase goes into a1, and b1 is just the dephased state vector + assert np.allclose(a1, a * phase) + assert np.allclose(b1, b) diff --git a/cirq-core/cirq/sim/simulation_product_state.py b/cirq-core/cirq/sim/simulation_product_state.py index 1338d7ebf7a..c4421b41ec2 100644 --- a/cirq-core/cirq/sim/simulation_product_state.py +++ b/cirq-core/cirq/sim/simulation_product_state.py @@ -122,7 +122,7 @@ def _act_on_fallback_( gate_opt, (ops.ResetChannel, ops.MeasurementGate) ): for q in qubits: - if op_args.allows_factoring: + if op_args.allows_factoring and len(op_args.qubits) > 1: q_args, op_args = op_args.factor((q,), validate=False) self._sim_states[q] = q_args diff --git a/cirq-core/cirq/sim/sparse_simulator_test.py b/cirq-core/cirq/sim/sparse_simulator_test.py index 1ebef6b4240..ee16e285313 100644 --- a/cirq-core/cirq/sim/sparse_simulator_test.py +++ b/cirq-core/cirq/sim/sparse_simulator_test.py @@ -1434,3 +1434,19 @@ def test_unseparated_states_str(): qubits: (cirq.LineQubit(0), cirq.LineQubit(1)) output vector: 0.707j|00⟩ + 0.707j|10⟩""" ) + + +@pytest.mark.parametrize('split', [True, False]) +def test_measurement_preserves_phase(split: bool): + c1, c2, t = cirq.LineQubit.range(3) + circuit = cirq.Circuit( + cirq.H(t), + cirq.measure(t, key='t'), + cirq.CZ(c1, c2).with_classical_controls('t'), + cirq.reset(t), + ) + simulator = cirq.Simulator(split_untangled_states=split) + # Run enough times that both options of |110> - |111> are likely measured. + for _ in range(20): + result = simulator.simulate(circuit, initial_state=(1, 1, 1), qubit_order=(c1, c2, t)) + assert result.dirac_notation() == '|110⟩' diff --git a/docs/experiments/textbook_algorithms.ipynb b/docs/experiments/textbook_algorithms.ipynb index 2287dcc5fad..0949421cf64 100644 --- a/docs/experiments/textbook_algorithms.ipynb +++ b/docs/experiments/textbook_algorithms.ipynb @@ -233,7 +233,7 @@ "print(np.round(bobs_bloch_vector, 3))\n", "\n", "# Verify they are the same state!\n", - "np.testing.assert_allclose(bobs_bloch_vector, message_bloch_vector, atol=1e-7)" + "np.testing.assert_allclose(bobs_bloch_vector, message_bloch_vector, atol=1e-6)" ] }, {