Skip to content

Commit 84ba8e2

Browse files
committed
pythongh-116738: Make _heapq module thread-safe
1 parent ac75110 commit 84ba8e2

File tree

3 files changed

+321
-11
lines changed

3 files changed

+321
-11
lines changed
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
import unittest
2+
3+
import heapq
4+
import operator
5+
6+
from enum import Enum
7+
from threading import Thread, Barrier
8+
from random import shuffle, randint
9+
10+
from test.support import threading_helper
11+
12+
13+
# Number of threads that race on the shared heap in every test.
NTHREADS: int = 10
# Number of items used to build each heap / item sequence.  The pop
# tests assert that this is an exact multiple of NTHREADS.
OBJECT_COUNT: int = 5_000
15+
16+
17+
class HeapKind(Enum):
    """Which ordering a heap is expected to maintain."""

    # Parent <= children.
    MIN = 1
    # Parent >= children.
    MAX = 2
20+
21+
22+
@threading_helper.requires_working_threading()
23+
class TestHeapq(unittest.TestCase):
24+
def test_racing_heapify(self):
    """Heapify one shared, shuffled list from every thread at once."""
    shared = list(range(OBJECT_COUNT))
    shuffle(shared)

    # heapq.heapify takes the list as its only argument, so it can be
    # handed to the thread runner directly as the worker.
    self.run_concurrently(
        worker_func=heapq.heapify,
        args=(shared,),
        nthreads=NTHREADS,
    )

    heap_ok = self.is_min_heap_property_satisfied(shared)
    self.assertTrue(heap_ok)
35+
36+
def test_racing_heappush(self):
    """Push the same run of ints onto one shared heap from every thread."""
    shared = []

    def push_all(target: list[int]):
        # Push OBJECT_COUNT - 1 down to 0, one item at a time.
        for value in range(OBJECT_COUNT - 1, -1, -1):
            heapq.heappush(target, value)

    self.run_concurrently(
        worker_func=push_all,
        args=(shared,),
        nthreads=NTHREADS,
    )

    heap_ok = self.is_min_heap_property_satisfied(shared)
    self.assertTrue(heap_ok)
47+
48+
def test_racing_heappop(self):
    """Drain a shared min-heap, each thread popping an equal share."""
    shared = list(range(OBJECT_COUNT))
    shuffle(shared)
    heapq.heapify(shared)

    # The work is split evenly, so the item count must divide cleanly
    # by the number of threads.
    self.assertEqual(0, OBJECT_COUNT % NTHREADS)
    share = OBJECT_COUNT // NTHREADS

    def pop_share(target: list[int], how_many: int):
        popped = [heapq.heappop(target) for _ in range(how_many)]

        # Items popped by a single thread should come out in
        # ascending order.
        self.assertTrue(self.is_sorted_ascending(popped))

    self.run_concurrently(
        worker_func=pop_share,
        args=(shared, share),
        nthreads=NTHREADS,
    )

    # Every item must have been popped exactly once.
    self.assertEqual(0, len(shared))
72+
73+
def test_racing_heappushpop(self):
    """Run heappushpop() on a shared min-heap from every thread."""
    shared = list(range(OBJECT_COUNT))
    shuffle(shared)
    heapq.heapify(shared)

    # One sequence of random values, pushed by every thread.
    candidates = []
    for _ in range(OBJECT_COUNT):
        candidates.append(randint(-OBJECT_COUNT, OBJECT_COUNT))

    def pushpop_all(target: list[int], values: list[int]):
        for pushed in values:
            returned = heapq.heappushpop(target, pushed)
            # The value handed back is never larger than the one
            # that was pushed.
            self.assertTrue(returned <= pushed)

    self.run_concurrently(
        worker_func=pushpop_all,
        args=(shared, candidates),
        nthreads=NTHREADS,
    )

    # Each call adds one item and removes one: the size is unchanged.
    self.assertEqual(OBJECT_COUNT, len(shared))
    heap_ok = self.is_min_heap_property_satisfied(shared)
    self.assertTrue(heap_ok)
94+
95+
def test_racing_heapreplace(self):
    """Run heapreplace() on a shared min-heap from every thread."""
    shared = list(range(OBJECT_COUNT))
    shuffle(shared)
    heapq.heapify(shared)

    # One sequence of random values, used by every thread.
    replacements = []
    for _ in range(OBJECT_COUNT):
        replacements.append(randint(-OBJECT_COUNT, OBJECT_COUNT))

    def replace_all(target: list[int], values: list[int]):
        for value in values:
            # The popped item is not inspected; only the state of the
            # heap afterwards is checked.
            heapq.heapreplace(target, value)

    self.run_concurrently(
        worker_func=replace_all,
        args=(shared, replacements),
        nthreads=NTHREADS,
    )

    # Each call removes one item and adds one: the size is unchanged.
    self.assertEqual(OBJECT_COUNT, len(shared))
    heap_ok = self.is_min_heap_property_satisfied(shared)
    self.assertTrue(heap_ok)
115+
116+
def test_racing_heapify_max(self):
    """Max-heapify one shared, shuffled list from every thread at once."""
    shared = list(range(OBJECT_COUNT))
    shuffle(shared)

    # heapq.heapify_max takes the list as its only argument, so it can
    # be handed to the thread runner directly as the worker.
    self.run_concurrently(
        worker_func=heapq.heapify_max,
        args=(shared,),
        nthreads=NTHREADS,
    )

    heap_ok = self.is_max_heap_property_satisfied(shared)
    self.assertTrue(heap_ok)
127+
128+
def test_racing_heappush_max(self):
    """Push the same run of ints onto one shared max-heap from every thread."""
    shared = []

    def push_all(target: list[int]):
        # Push 0 up to OBJECT_COUNT - 1, one item at a time.
        value = 0
        while value < OBJECT_COUNT:
            heapq.heappush_max(target, value)
            value += 1

    self.run_concurrently(
        worker_func=push_all,
        args=(shared,),
        nthreads=NTHREADS,
    )

    heap_ok = self.is_max_heap_property_satisfied(shared)
    self.assertTrue(heap_ok)
139+
140+
def test_racing_heappop_max(self):
    """Drain a shared max-heap, each thread popping an equal share."""
    shared = list(range(OBJECT_COUNT))
    shuffle(shared)
    heapq.heapify_max(shared)

    # The work is split evenly, so the item count must divide cleanly
    # by the number of threads.
    self.assertEqual(0, OBJECT_COUNT % NTHREADS)
    share = OBJECT_COUNT // NTHREADS

    def pop_share(target: list[int], how_many: int):
        popped = [heapq.heappop_max(target) for _ in range(how_many)]

        # Items popped by a single thread should come out in
        # descending order.
        self.assertTrue(self.is_sorted_descending(popped))

    self.run_concurrently(
        worker_func=pop_share,
        args=(shared, share),
        nthreads=NTHREADS,
    )

    # Every item must have been popped exactly once.
    self.assertEqual(0, len(shared))
164+
165+
def test_racing_heappushpop_max(self):
    """Run heappushpop_max() on a shared max-heap from every thread."""
    shared = list(range(OBJECT_COUNT))
    shuffle(shared)
    heapq.heapify_max(shared)

    # One sequence of random values, pushed by every thread.
    candidates = []
    for _ in range(OBJECT_COUNT):
        candidates.append(randint(-OBJECT_COUNT, OBJECT_COUNT))

    def pushpop_all(target: list[int], values: list[int]):
        for pushed in values:
            returned = heapq.heappushpop_max(target, pushed)
            # The value handed back is never smaller than the one
            # that was pushed.
            self.assertTrue(returned >= pushed)

    self.run_concurrently(
        worker_func=pushpop_all,
        args=(shared, candidates),
        nthreads=NTHREADS,
    )

    # Each call adds one item and removes one: the size is unchanged.
    self.assertEqual(OBJECT_COUNT, len(shared))
    heap_ok = self.is_max_heap_property_satisfied(shared)
    self.assertTrue(heap_ok)
188+
189+
def test_racing_heapreplace_max(self):
    """Run heapreplace_max() on a shared max-heap from every thread."""
    shared = list(range(OBJECT_COUNT))
    shuffle(shared)
    heapq.heapify_max(shared)

    # One sequence of random values, used by every thread.
    replacements = []
    for _ in range(OBJECT_COUNT):
        replacements.append(randint(-OBJECT_COUNT, OBJECT_COUNT))

    def replace_all(target: list[int], values: list[int]):
        for value in values:
            # The popped item is not inspected; only the state of the
            # heap afterwards is checked.
            heapq.heapreplace_max(target, value)

    self.run_concurrently(
        worker_func=replace_all,
        args=(shared, replacements),
        nthreads=NTHREADS,
    )

    # Each call removes one item and adds one: the size is unchanged.
    self.assertEqual(OBJECT_COUNT, len(shared))
    heap_ok = self.is_max_heap_property_satisfied(shared)
    self.assertTrue(heap_ok)
211+
212+
def is_min_heap_property_satisfied(self, heap: list[object]) -> bool:
    """
    Return whether *heap* is laid out as a min-heap: every parent
    compares less than or equal to each of its children.
    """
    kind = HeapKind.MIN
    return self.is_heap_property_satisfied(heap, kind)
218+
219+
def is_max_heap_property_satisfied(self, heap: list[object]) -> bool:
    """
    Return whether *heap* is laid out as a max-heap: every parent
    compares greater than or equal to each of its children.
    """
    kind = HeapKind.MAX
    return self.is_heap_property_satisfied(heap, kind)
225+
226+
@staticmethod
def is_heap_property_satisfied(
    heap: list[object], heap_kind: HeapKind
) -> bool:
    """
    Check if the heap property is satisfied.

    For HeapKind.MIN every parent must be <= its children; for any
    other kind every parent must be >= its children.
    """
    if heap_kind == HeapKind.MIN:
        in_order = operator.le
    else:
        in_order = operator.ge

    # Index 0 is the root and has no parent, so start at 1.  The parent
    # of the item at index ``child`` lives at ``(child - 1) >> 1``.
    return all(
        in_order(heap[(child - 1) >> 1], heap[child])
        for child in range(1, len(heap))
    )
241+
242+
@staticmethod
243+
def is_sorted_ascending(lst: list[object]) -> bool:
244+
"""
245+
Check if the list is sorted in ascending order (non-decreasing).
246+
"""
247+
return all(lst[i - 1] <= lst[i] for i in range(1, len(lst)))
248+
249+
@staticmethod
250+
def is_sorted_descending(lst: list[object]) -> bool:
251+
"""
252+
Check if the list is sorted in descending order (non-increasing).
253+
"""
254+
return all(lst[i - 1] >= lst[i] for i in range(1, len(lst)))
255+
256+
@staticmethod
257+
def run_concurrently(worker_func, args, nthreads) -> None:
258+
"""
259+
Run the worker function concurrently in multiple threads.
260+
"""
261+
barrier = Barrier(NTHREADS)
262+
263+
def wrapper_func(*args):
264+
# Wait for all threadss to reach this point before proceeding.
265+
barrier.wait()
266+
worker_func(*args)
267+
268+
workers = []
269+
for _ in range(nthreads):
270+
worker = Thread(target=wrapper_func, args=args)
271+
workers.append(worker)
272+
worker.start()
273+
274+
for worker in workers:
275+
worker.join()
276+
277+
278+
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)