diff --git a/cirq-core/cirq/experiments/xeb_fitting_test.py b/cirq-core/cirq/experiments/xeb_fitting_test.py index f99bc86d7ca..399499b3039 100644 --- a/cirq-core/cirq/experiments/xeb_fitting_test.py +++ b/cirq-core/cirq/experiments/xeb_fitting_test.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import itertools import multiprocessing -from typing import Iterable +from typing import Iterable, Iterator import networkx as nx import numpy as np @@ -40,6 +42,13 @@ _POOL_NUM_PROCESSES = min(4, multiprocessing.cpu_count()) +@pytest.fixture +def pool() -> Iterator[multiprocessing.pool.Pool]: + ctx = multiprocessing.get_context() + with ctx.Pool(_POOL_NUM_PROCESSES) as pool: + yield pool + + @pytest.fixture(scope='module') def circuits_cycle_depths_sampled_df(): q0, q1 = cirq.LineQubit.range(2) @@ -207,7 +216,7 @@ def test_get_initial_simplex(): assert simplex.shape[1] == len(names) -def test_characterize_phased_fsim_parameters_with_xeb(): +def test_characterize_phased_fsim_parameters_with_xeb(pool): q0, q1 = cirq.LineQubit.range(2) rs = np.random.RandomState(52) circuits = [ @@ -232,17 +241,16 @@ def test_characterize_phased_fsim_parameters_with_xeb(): characterize_phi=False, ) p_circuits = [parameterize_circuit(circuit, options) for circuit in circuits] - with multiprocessing.Pool(_POOL_NUM_PROCESSES) as pool: - result = characterize_phased_fsim_parameters_with_xeb( - sampled_df=sampled_df, - parameterized_circuits=p_circuits, - cycle_depths=cycle_depths, - options=options, - # speed up with looser tolerances: - fatol=1e-2, - xatol=1e-2, - pool=pool, - ) + result = characterize_phased_fsim_parameters_with_xeb( + sampled_df=sampled_df, + parameterized_circuits=p_circuits, + cycle_depths=cycle_depths, + options=options, + # speed up with looser tolerances: + fatol=1e-2, + xatol=1e-2, + pool=pool, + ) opt_res = result.optimization_results[(q0, q1)] assert np.abs(opt_res.x[0] + np.pi / 4) < 0.1 assert np.abs(opt_res.fun) < 0.1 # noiseless simulator @@ -252,7 +260,7 @@ def test_characterize_phased_fsim_parameters_with_xeb(): @pytest.mark.parametrize('use_pool', (True, False)) -def test_parallel_full_workflow(use_pool): +def test_parallel_full_workflow(request, use_pool): circuits = rqcg.generate_library_of_2q_circuits( n_library_circuits=5, two_qubit_gate=cirq.ISWAP**0.5, @@ -272,10 +280,8 @@ def test_parallel_full_workflow(use_pool): combinations_by_layer=combs, ) - if use_pool: - pool = multiprocessing.Pool(_POOL_NUM_PROCESSES) - else: - pool = None + # avoid starting worker pool if it is not needed + pool = request.getfixturevalue("pool") if use_pool else None fids_df_0 = benchmark_2q_xeb_fidelities( sampled_df=sampled_df, circuits=circuits, cycle_depths=cycle_depths, pool=pool @@ -296,8 +302,6 @@ def test_parallel_full_workflow(use_pool): xatol=5e-2, pool=pool, ) - if pool is not None: - pool.terminate() assert len(result.optimization_results) == graph.number_of_edges() for opt_res in result.optimization_results.values(): diff --git a/cirq-core/cirq/experiments/xeb_simulation_test.py b/cirq-core/cirq/experiments/xeb_simulation_test.py index f2c89333e5f..1ef83678049 100644 --- a/cirq-core/cirq/experiments/xeb_simulation_test.py +++ b/cirq-core/cirq/experiments/xeb_simulation_test.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import multiprocessing -from typing import Dict, Any, Optional -from typing import Sequence +from typing import Any, Dict, Iterator, Optional, Sequence import numpy as np import pandas as pd @@ -27,7 +28,14 @@ _POOL_NUM_PROCESSES = min(4, multiprocessing.cpu_count()) -def test_simulate_2q_xeb_circuits(): +@pytest.fixture +def pool() -> Iterator[multiprocessing.pool.Pool]: + ctx = multiprocessing.get_context() + with ctx.Pool(_POOL_NUM_PROCESSES) as pool: + yield pool + + +def test_simulate_2q_xeb_circuits(pool): q0, q1 = cirq.LineQubit.range(2) circuits = [ rqcg.random_rotations_between_two_qubit_circuit( @@ -45,8 +53,7 @@ def test_simulate_2q_xeb_circuits(): assert len(row['pure_probs']) == 4 assert np.isclose(np.sum(row['pure_probs']), 1) - with multiprocessing.Pool(_POOL_NUM_PROCESSES) as pool: - df2 = simulate_2q_xeb_circuits(circuits, cycle_depths, pool=pool) + df2 = simulate_2q_xeb_circuits(circuits, cycle_depths, pool=pool) pd.testing.assert_frame_equal(df, df2) @@ -121,8 +128,8 @@ def _ref_simulate_2q_xeb_circuits( return pd.DataFrame(records).set_index(['circuit_i', 'cycle_depth']).sort_index() -@pytest.mark.parametrize('multiprocess', (True, False)) -def test_incremental_simulate(multiprocess): +@pytest.mark.parametrize('use_pool', (True, False)) +def test_incremental_simulate(request, use_pool): q0, q1 = cirq.LineQubit.range(2) circuits = [ rqcg.random_rotations_between_two_qubit_circuit( @@ -132,16 +139,12 @@ def test_incremental_simulate(multiprocess): ] cycle_depths = np.arange(3, 100, 9, dtype=np.int64) - if multiprocess: - pool = multiprocessing.Pool(_POOL_NUM_PROCESSES) - else: - pool = None + # avoid starting worker pool if it is not needed + pool = request.getfixturevalue("pool") if use_pool else None df_ref = _ref_simulate_2q_xeb_circuits(circuits=circuits, cycle_depths=cycle_depths, pool=pool) df = simulate_2q_xeb_circuits(circuits=circuits, cycle_depths=cycle_depths, pool=pool) - if pool is not None: - pool.terminate() pd.testing.assert_frame_equal(df_ref, df) # Use below for approximate equality, if e.g. you're using qsim: