From 9ed0fa56ebd2eac844bee62fdfa3f0442dd6e773 Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Mon, 14 Oct 2024 14:28:20 -0700 Subject: [PATCH] Use pytest fixture for multiprocessing Pool in XEB tests Direct Pool executions causes flaky test outcomes in internal test framework. This change makes it easier to work around the problem. No change in the effective test code. --- .../cirq/experiments/xeb_fitting_test.py | 44 ++++++++++--------- .../cirq/experiments/xeb_simulation_test.py | 29 ++++++------ 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/cirq-core/cirq/experiments/xeb_fitting_test.py b/cirq-core/cirq/experiments/xeb_fitting_test.py index f99bc86d7ca..399499b3039 100644 --- a/cirq-core/cirq/experiments/xeb_fitting_test.py +++ b/cirq-core/cirq/experiments/xeb_fitting_test.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import itertools import multiprocessing -from typing import Iterable +from typing import Iterable, Iterator import networkx as nx import numpy as np @@ -40,6 +42,13 @@ _POOL_NUM_PROCESSES = min(4, multiprocessing.cpu_count()) +@pytest.fixture +def pool() -> Iterator[multiprocessing.pool.Pool]: + ctx = multiprocessing.get_context() + with ctx.Pool(_POOL_NUM_PROCESSES) as pool: + yield pool + + @pytest.fixture(scope='module') def circuits_cycle_depths_sampled_df(): q0, q1 = cirq.LineQubit.range(2) @@ -207,7 +216,7 @@ def test_get_initial_simplex(): assert simplex.shape[1] == len(names) -def test_characterize_phased_fsim_parameters_with_xeb(): +def test_characterize_phased_fsim_parameters_with_xeb(pool): q0, q1 = cirq.LineQubit.range(2) rs = np.random.RandomState(52) circuits = [ @@ -232,17 +241,16 @@ def test_characterize_phased_fsim_parameters_with_xeb(): characterize_phi=False, ) p_circuits = [parameterize_circuit(circuit, options) for circuit in circuits] - with multiprocessing.Pool(_POOL_NUM_PROCESSES) as pool: - result = characterize_phased_fsim_parameters_with_xeb( - sampled_df=sampled_df, - parameterized_circuits=p_circuits, - cycle_depths=cycle_depths, - options=options, - # speed up with looser tolerances: - fatol=1e-2, - xatol=1e-2, - pool=pool, - ) + result = characterize_phased_fsim_parameters_with_xeb( + sampled_df=sampled_df, + parameterized_circuits=p_circuits, + cycle_depths=cycle_depths, + options=options, + # speed up with looser tolerances: + fatol=1e-2, + xatol=1e-2, + pool=pool, + ) opt_res = result.optimization_results[(q0, q1)] assert np.abs(opt_res.x[0] + np.pi / 4) < 0.1 assert np.abs(opt_res.fun) < 0.1 # noiseless simulator @@ -252,7 +260,7 @@ def test_characterize_phased_fsim_parameters_with_xeb(): @pytest.mark.parametrize('use_pool', (True, False)) -def test_parallel_full_workflow(use_pool): +def test_parallel_full_workflow(request, use_pool): circuits = rqcg.generate_library_of_2q_circuits( n_library_circuits=5, two_qubit_gate=cirq.ISWAP**0.5, @@ -272,10 +280,8 @@ def test_parallel_full_workflow(use_pool): combinations_by_layer=combs, ) - if use_pool: - pool = multiprocessing.Pool(_POOL_NUM_PROCESSES) - else: - pool = None + # avoid starting worker pool if it is not needed + pool = request.getfixturevalue("pool") if use_pool else None fids_df_0 = benchmark_2q_xeb_fidelities( sampled_df=sampled_df, circuits=circuits, cycle_depths=cycle_depths, pool=pool @@ -296,8 +302,6 @@ def test_parallel_full_workflow(use_pool): xatol=5e-2, pool=pool, ) - if pool is not None: - pool.terminate() assert len(result.optimization_results) == graph.number_of_edges() for opt_res in result.optimization_results.values(): diff --git a/cirq-core/cirq/experiments/xeb_simulation_test.py b/cirq-core/cirq/experiments/xeb_simulation_test.py index f2c89333e5f..1ef83678049 100644 --- a/cirq-core/cirq/experiments/xeb_simulation_test.py +++ b/cirq-core/cirq/experiments/xeb_simulation_test.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import multiprocessing -from typing import Dict, Any, Optional -from typing import Sequence +from typing import Any, Dict, Iterator, Optional, Sequence import numpy as np import pandas as pd @@ -27,7 +28,14 @@ _POOL_NUM_PROCESSES = min(4, multiprocessing.cpu_count()) -def test_simulate_2q_xeb_circuits(): +@pytest.fixture +def pool() -> Iterator[multiprocessing.pool.Pool]: + ctx = multiprocessing.get_context() + with ctx.Pool(_POOL_NUM_PROCESSES) as pool: + yield pool + + +def test_simulate_2q_xeb_circuits(pool): q0, q1 = cirq.LineQubit.range(2) circuits = [ rqcg.random_rotations_between_two_qubit_circuit( @@ -45,8 +53,7 @@ def test_simulate_2q_xeb_circuits(): assert len(row['pure_probs']) == 4 assert np.isclose(np.sum(row['pure_probs']), 1) - with multiprocessing.Pool(_POOL_NUM_PROCESSES) as pool: - df2 = simulate_2q_xeb_circuits(circuits, cycle_depths, pool=pool) + df2 = simulate_2q_xeb_circuits(circuits, cycle_depths, pool=pool) pd.testing.assert_frame_equal(df, df2) @@ -121,8 +128,8 @@ def _ref_simulate_2q_xeb_circuits( return pd.DataFrame(records).set_index(['circuit_i', 'cycle_depth']).sort_index() -@pytest.mark.parametrize('multiprocess', (True, False)) -def test_incremental_simulate(multiprocess): +@pytest.mark.parametrize('use_pool', (True, False)) +def test_incremental_simulate(request, use_pool): q0, q1 = cirq.LineQubit.range(2) circuits = [ rqcg.random_rotations_between_two_qubit_circuit( @@ -132,16 +139,12 @@ def test_incremental_simulate(multiprocess): ] cycle_depths = np.arange(3, 100, 9, dtype=np.int64) - if multiprocess: - pool = multiprocessing.Pool(_POOL_NUM_PROCESSES) - else: - pool = None + # avoid starting worker pool if it is not needed + pool = request.getfixturevalue("pool") if use_pool else None df_ref = _ref_simulate_2q_xeb_circuits(circuits=circuits, cycle_depths=cycle_depths, pool=pool) df = simulate_2q_xeb_circuits(circuits=circuits, cycle_depths=cycle_depths, pool=pool) - if pool is not None: - pool.terminate() pd.testing.assert_frame_equal(df_ref, df) # Use below for approximate equality, if e.g. you're using qsim: