Skip to content

Use pytest fixture for multiprocessing Pool in XEB tests #6766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions cirq-core/cirq/experiments/xeb_fitting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import itertools
import multiprocessing
from typing import Iterable
from typing import Iterable, Iterator

import networkx as nx
import numpy as np
Expand All @@ -40,6 +42,13 @@
_POOL_NUM_PROCESSES = min(4, multiprocessing.cpu_count())


@pytest.fixture
def pool() -> Iterator[multiprocessing.pool.Pool]:
ctx = multiprocessing.get_context()
with ctx.Pool(_POOL_NUM_PROCESSES) as pool:
yield pool


@pytest.fixture(scope='module')
def circuits_cycle_depths_sampled_df():
q0, q1 = cirq.LineQubit.range(2)
Expand Down Expand Up @@ -207,7 +216,7 @@ def test_get_initial_simplex():
assert simplex.shape[1] == len(names)


def test_characterize_phased_fsim_parameters_with_xeb():
def test_characterize_phased_fsim_parameters_with_xeb(pool):
q0, q1 = cirq.LineQubit.range(2)
rs = np.random.RandomState(52)
circuits = [
Expand All @@ -232,17 +241,16 @@ def test_characterize_phased_fsim_parameters_with_xeb():
characterize_phi=False,
)
p_circuits = [parameterize_circuit(circuit, options) for circuit in circuits]
with multiprocessing.Pool(_POOL_NUM_PROCESSES) as pool:
result = characterize_phased_fsim_parameters_with_xeb(
sampled_df=sampled_df,
parameterized_circuits=p_circuits,
cycle_depths=cycle_depths,
options=options,
# speed up with looser tolerances:
fatol=1e-2,
xatol=1e-2,
pool=pool,
)
result = characterize_phased_fsim_parameters_with_xeb(
sampled_df=sampled_df,
parameterized_circuits=p_circuits,
cycle_depths=cycle_depths,
options=options,
# speed up with looser tolerances:
fatol=1e-2,
xatol=1e-2,
pool=pool,
)
opt_res = result.optimization_results[(q0, q1)]
assert np.abs(opt_res.x[0] + np.pi / 4) < 0.1
assert np.abs(opt_res.fun) < 0.1 # noiseless simulator
Expand All @@ -252,7 +260,7 @@ def test_characterize_phased_fsim_parameters_with_xeb():


@pytest.mark.parametrize('use_pool', (True, False))
def test_parallel_full_workflow(use_pool):
def test_parallel_full_workflow(request, use_pool):
circuits = rqcg.generate_library_of_2q_circuits(
n_library_circuits=5,
two_qubit_gate=cirq.ISWAP**0.5,
Expand All @@ -272,10 +280,8 @@ def test_parallel_full_workflow(use_pool):
combinations_by_layer=combs,
)

if use_pool:
pool = multiprocessing.Pool(_POOL_NUM_PROCESSES)
else:
pool = None
# avoid starting worker pool if it is not needed
pool = request.getfixturevalue("pool") if use_pool else None

fids_df_0 = benchmark_2q_xeb_fidelities(
sampled_df=sampled_df, circuits=circuits, cycle_depths=cycle_depths, pool=pool
Expand All @@ -296,8 +302,6 @@ def test_parallel_full_workflow(use_pool):
xatol=5e-2,
pool=pool,
)
if pool is not None:
pool.terminate()

assert len(result.optimization_results) == graph.number_of_edges()
for opt_res in result.optimization_results.values():
Expand Down
29 changes: 16 additions & 13 deletions cirq-core/cirq/experiments/xeb_simulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import multiprocessing
from typing import Dict, Any, Optional
from typing import Sequence
from typing import Any, Dict, Iterator, Optional, Sequence

import numpy as np
import pandas as pd
Expand All @@ -27,7 +28,14 @@
_POOL_NUM_PROCESSES = min(4, multiprocessing.cpu_count())


def test_simulate_2q_xeb_circuits():
@pytest.fixture
def pool() -> Iterator[multiprocessing.pool.Pool]:
ctx = multiprocessing.get_context()
with ctx.Pool(_POOL_NUM_PROCESSES) as pool:
yield pool


def test_simulate_2q_xeb_circuits(pool):
q0, q1 = cirq.LineQubit.range(2)
circuits = [
rqcg.random_rotations_between_two_qubit_circuit(
Expand All @@ -45,8 +53,7 @@ def test_simulate_2q_xeb_circuits():
assert len(row['pure_probs']) == 4
assert np.isclose(np.sum(row['pure_probs']), 1)

with multiprocessing.Pool(_POOL_NUM_PROCESSES) as pool:
df2 = simulate_2q_xeb_circuits(circuits, cycle_depths, pool=pool)
df2 = simulate_2q_xeb_circuits(circuits, cycle_depths, pool=pool)

pd.testing.assert_frame_equal(df, df2)

Expand Down Expand Up @@ -121,8 +128,8 @@ def _ref_simulate_2q_xeb_circuits(
return pd.DataFrame(records).set_index(['circuit_i', 'cycle_depth']).sort_index()


@pytest.mark.parametrize('multiprocess', (True, False))
def test_incremental_simulate(multiprocess):
@pytest.mark.parametrize('use_pool', (True, False))
def test_incremental_simulate(request, use_pool):
q0, q1 = cirq.LineQubit.range(2)
circuits = [
rqcg.random_rotations_between_two_qubit_circuit(
Expand All @@ -132,16 +139,12 @@ def test_incremental_simulate(multiprocess):
]
cycle_depths = np.arange(3, 100, 9, dtype=np.int64)

if multiprocess:
pool = multiprocessing.Pool(_POOL_NUM_PROCESSES)
else:
pool = None
# avoid starting worker pool if it is not needed
pool = request.getfixturevalue("pool") if use_pool else None

df_ref = _ref_simulate_2q_xeb_circuits(circuits=circuits, cycle_depths=cycle_depths, pool=pool)

df = simulate_2q_xeb_circuits(circuits=circuits, cycle_depths=cycle_depths, pool=pool)
if pool is not None:
pool.terminate()
pd.testing.assert_frame_equal(df_ref, df)

# Use below for approximate equality, if e.g. you're using qsim:
Expand Down