@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import functools
 import gc
 import os
 import pkgutil
@@ -16,11 +17,12 @@
 from ._wrapper import get_lib
 
 if TYPE_CHECKING:
-    from typing import Any, Callable, TypeVar
+    from typing import Any, Callable, ParamSpec, TypeVar
 
     from ._wrapper import LibType
 
     T = TypeVar("T")
+    P = ParamSpec("P")
 
 IS_PYTEST_BENCHMARK_INSTALLED = pkgutil.find_loader("pytest_benchmark") is not None
 SUPPORTS_PERF_TRAMPOLINE = sys.version_info >= (3, 12)
@@ -172,86 +174,73 @@ def pytest_collection_modifyitems(
 
 def _run_with_instrumentation(
     lib: LibType,
-    nodeId: str,
+    nodeid: str,
     config: pytest.Config,
-    fn: Callable[..., Any],
-    *args,
-    **kwargs,
-):
+    fn: Callable[P, T],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> T:
     is_gc_enabled = gc.isenabled()
     if is_gc_enabled:
         gc.collect()
         gc.disable()
 
-    result = None
-
-    def __codspeed_root_frame__():
-        nonlocal result
-        result = fn(*args, **kwargs)
-
-    if SUPPORTS_PERF_TRAMPOLINE:
-        # Warmup CPython performance map cache
-        __codspeed_root_frame__()
-    lib.zero_stats()
-    lib.start_instrumentation()
-    __codspeed_root_frame__()
-    lib.stop_instrumentation()
-    uri = get_git_relative_uri(nodeId, config.rootpath)
-    lib.dump_stats_at(uri.encode("ascii"))
-    if is_gc_enabled:
-        gc.enable()
+    def __codspeed_root_frame__() -> T:
+        return fn(*args, **kwargs)
+
+    try:
+        if SUPPORTS_PERF_TRAMPOLINE:
+            # Warmup CPython performance map cache
+            __codspeed_root_frame__()
+
+        lib.zero_stats()
+        lib.start_instrumentation()
+        try:
+            return __codspeed_root_frame__()
+        finally:
+            # Ensure instrumentation is stopped even if the test failed
+            lib.stop_instrumentation()
+            uri = get_git_relative_uri(nodeid, config.rootpath)
+            lib.dump_stats_at(uri.encode("ascii"))
+    finally:
+        # Ensure GC is re-enabled even if the test failed
+        if is_gc_enabled:
+            gc.enable()
+
+
+def wrap_runtest(
+    lib: LibType,
+    nodeid: str,
+    config: pytest.Config,
+    fn: Callable[P, T],
+) -> Callable[P, T]:
+    @functools.wraps(fn)
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
+        return _run_with_instrumentation(lib, nodeid, config, fn, *args, **kwargs)
 
-    return result
+    return wrapped
 
 
 @pytest.hookimpl(tryfirst=True)
 def pytest_runtest_protocol(item: pytest.Item, nextitem: pytest.Item | None):
     plugin = get_plugin(item.config)
     if not plugin.is_codspeed_enabled or not should_benchmark_item(item):
-        return (
-            None  # Defer to the default test protocol since no benchmarking is needed
-        )
+        # Defer to the default test protocol since no benchmarking is needed
+        return None
 
     if has_benchmark_fixture(item):
-        return None  # Instrumentation is handled by the fixture
+        # Instrumentation is handled by the fixture
+        return None
 
     plugin.benchmark_count += 1
     if not plugin.should_measure:
-        return None  # Benchmark counted but will be run in the default protocol
-
-    # Setup phase
-    reports = []
-    ihook = item.ihook
-    ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
-    setup_call = pytest.CallInfo.from_call(
-        lambda: ihook.pytest_runtest_setup(item=item, nextitem=nextitem), "setup"
-    )
-    setup_report = ihook.pytest_runtest_makereport(item=item, call=setup_call)
-    ihook.pytest_runtest_logreport(report=setup_report)
-    reports.append(setup_report)
-    # Run phase
-    if setup_report.passed and not item.config.getoption("setuponly"):
-        assert plugin.lib is not None
-        runtest_call = pytest.CallInfo.from_call(
-            lambda: _run_with_instrumentation(
-                plugin.lib, item.nodeid, item.config, item.runtest
-            ),
-            "call",
-        )
-        runtest_report = ihook.pytest_runtest_makereport(item=item, call=runtest_call)
-        ihook.pytest_runtest_logreport(report=runtest_report)
-        reports.append(runtest_report)
-
-    # Teardown phase
-    teardown_call = pytest.CallInfo.from_call(
-        lambda: ihook.pytest_runtest_teardown(item=item, nextitem=nextitem), "teardown"
-    )
-    teardown_report = ihook.pytest_runtest_makereport(item=item, call=teardown_call)
-    ihook.pytest_runtest_logreport(report=teardown_report)
-    reports.append(teardown_report)
-    ihook.pytest_runtest_logfinish(nodeid=item.nodeid, location=item.location)
+        # Benchmark counted but will be run in the default protocol
+        return None
 
-    return reports  # Deny further protocol hooks execution
+    # Wrap runtest and defer to default protocol
+    assert plugin.lib is not None
+    item.runtest = wrap_runtest(plugin.lib, item.nodeid, item.config, item.runtest)
+    return None
 
 
 class BenchmarkFixture:
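The new wrap_runtest helper is the standard typed-decorator pattern: ParamSpec threads the wrapped callable's exact signature through the wrapper, and functools.wraps preserves its metadata (name, docstring, attributes) so pytest still sees the original test function. A minimal standalone sketch of the same pattern, with hypothetical names (traced, add) and assuming Python 3.10+ so ParamSpec can be imported at runtime:

import functools
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")


def traced(fn: Callable[P, T]) -> Callable[P, T]:
    # functools.wraps copies __name__, __doc__, __wrapped__, etc., and
    # Callable[P, T] -> Callable[P, T] keeps the exact signature visible
    # to type checkers, unlike the old Callable[..., Any] annotation.
    @functools.wraps(fn)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)

    return wrapped


@traced
def add(a: int, b: int) -> int:
    return a + b


assert add(1, 2) == 3  # still type-checks as (int, int) -> int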
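The nested try/finally in the reworked _run_with_instrumentation is what guarantees cleanup: the inner block stops instrumentation and dumps stats even when the benchmarked function raises, and the outer block restores the GC state in every case. A standalone sketch of that control flow, with hypothetical print-based stand-ins for the lib calls (not the plugin's real bindings):

import gc


def start_instrumentation() -> None:
    print("instrumentation started")  # stand-in for lib.start_instrumentation()


def stop_instrumentation() -> None:
    print("instrumentation stopped")  # stand-in for lib.stop_instrumentation()


def dump_stats() -> None:
    print("stats dumped")  # stand-in for lib.dump_stats_at(...)


def measure(fn):
    was_enabled = gc.isenabled()
    if was_enabled:
        gc.collect()
        gc.disable()
    try:
        start_instrumentation()
        try:
            return fn()
        finally:
            # Runs on both the return and the exception path
            stop_instrumentation()
            dump_stats()
    finally:
        # GC state is restored no matter what happened above
        if was_enabled:
            gc.enable()


def failing_benchmark():
    raise RuntimeError("boom")


try:
    measure(failing_benchmark)
except RuntimeError:
    pass

assert gc.isenabled()  # GC was re-enabled despite the exception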