Skip to content

Commit 269b454

Browse files
committed
pythongh-116738: Make grp module thread-safe
1 parent d447129 commit 269b454

File tree

5 files changed

+130
-35
lines changed

5 files changed

+130
-35
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import unittest
2+
3+
from threading import Thread, Barrier
4+
from test.support import threading_helper
5+
6+
7+
def run_concurrently(worker_func, args, nthreads):
8+
"""
9+
Run the worker function concurrently in multiple threads.
10+
"""
11+
barrier = Barrier(nthreads)
12+
13+
def wrapper_func(*args):
14+
# Wait for all threads to reach this point before proceeding.
15+
barrier.wait()
16+
worker_func(*args)
17+
18+
with threading_helper.catch_threading_exception() as cm:
19+
workers = (
20+
Thread(target=wrapper_func, args=args) for _ in range(nthreads)
21+
)
22+
with threading_helper.start_threads(workers):
23+
pass
24+
25+
# If a worker thread raises an exception, re-raise it.
26+
if cm.exc_value is not None:
27+
raise cm.exc_value
28+
29+
30+
@threading_helper.requires_working_threading()
31+
class TestFTUtils(unittest.TestCase):
32+
def test_run_concurrently(self):
33+
lst = []
34+
35+
def worker(lst):
36+
lst.append(42)
37+
38+
nthreads = 10
39+
run_concurrently(worker, (lst,), nthreads)
40+
self.assertEqual(lst, [42] * nthreads)
41+
42+
def test_run_concurrently_raise(self):
43+
def worker():
44+
raise RuntimeError("Error")
45+
46+
nthreads = 3
47+
with self.assertRaises(RuntimeError):
48+
run_concurrently(worker, (), nthreads)
49+
50+
51+
if __name__ == "__main__":
52+
unittest.main()
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import unittest
2+
3+
from test.support import import_helper, threading_helper
4+
from test.test_free_threading.test_ft import run_concurrently
5+
6+
grp = import_helper.import_module("grp")
7+
8+
from test import test_grp
9+
10+
11+
NTHREADS = 10
12+
13+
14+
@threading_helper.requires_working_threading()
15+
class TestGrp(unittest.TestCase):
16+
def setUp(self):
17+
self.test_grp = test_grp.GroupDatabaseTestCase()
18+
19+
def test_racing_test_values(self):
20+
# test_grp.test_values() calls grp.getgrall() and checks the entries
21+
run_concurrently(
22+
worker_func=self.test_grp.test_values, args=(), nthreads=NTHREADS
23+
)
24+
25+
def test_racing_test_values_extended(self):
26+
# test_grp.test_values_extended() calls grp.getgrall(), grp.getgrgid(),
27+
# grp.getgrnam() and checks the entries
28+
run_concurrently(
29+
worker_func=self.test_grp.test_values_extended,
30+
args=(),
31+
nthreads=NTHREADS,
32+
)
33+
34+
35+
if __name__ == "__main__":
36+
unittest.main()

Lib/test/test_free_threading/test_heapq.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import heapq
44

55
from enum import Enum
6-
from threading import Thread, Barrier
76
from random import shuffle, randint
87

98
from test.support import threading_helper
9+
from test.test_free_threading.test_ft import run_concurrently
1010
from test import test_heapq
1111

1212

@@ -28,7 +28,7 @@ def test_racing_heapify(self):
2828
heap = list(range(OBJECT_COUNT))
2929
shuffle(heap)
3030

31-
self.run_concurrently(
31+
run_concurrently(
3232
worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
3333
)
3434
self.test_heapq.check_invariant(heap)
@@ -40,7 +40,7 @@ def heappush_func(heap):
4040
for item in reversed(range(OBJECT_COUNT)):
4141
heapq.heappush(heap, item)
4242

43-
self.run_concurrently(
43+
run_concurrently(
4444
worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
4545
)
4646
self.test_heapq.check_invariant(heap)
@@ -61,7 +61,7 @@ def heappop_func(heap, pop_count):
6161
# Each local list should be sorted
6262
self.assertTrue(self.is_sorted_ascending(local_list))
6363

64-
self.run_concurrently(
64+
run_concurrently(
6565
worker_func=heappop_func,
6666
args=(heap, per_thread_pop_count),
6767
nthreads=NTHREADS,
@@ -77,7 +77,7 @@ def heappushpop_func(heap, pushpop_items):
7777
popped_item = heapq.heappushpop(heap, item)
7878
self.assertTrue(popped_item <= item)
7979

80-
self.run_concurrently(
80+
run_concurrently(
8181
worker_func=heappushpop_func,
8282
args=(heap, pushpop_items),
8383
nthreads=NTHREADS,
@@ -93,7 +93,7 @@ def heapreplace_func(heap, replace_items):
9393
for item in replace_items:
9494
heapq.heapreplace(heap, item)
9595

96-
self.run_concurrently(
96+
run_concurrently(
9797
worker_func=heapreplace_func,
9898
args=(heap, replace_items),
9999
nthreads=NTHREADS,
@@ -105,7 +105,7 @@ def test_racing_heapify_max(self):
105105
max_heap = list(range(OBJECT_COUNT))
106106
shuffle(max_heap)
107107

108-
self.run_concurrently(
108+
run_concurrently(
109109
worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
110110
)
111111
self.test_heapq.check_max_invariant(max_heap)
@@ -117,7 +117,7 @@ def heappush_max_func(max_heap):
117117
for item in range(OBJECT_COUNT):
118118
heapq.heappush_max(max_heap, item)
119119

120-
self.run_concurrently(
120+
run_concurrently(
121121
worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
122122
)
123123
self.test_heapq.check_max_invariant(max_heap)
@@ -138,7 +138,7 @@ def heappop_max_func(max_heap, pop_count):
138138
# Each local list should be sorted
139139
self.assertTrue(self.is_sorted_descending(local_list))
140140

141-
self.run_concurrently(
141+
run_concurrently(
142142
worker_func=heappop_max_func,
143143
args=(max_heap, per_thread_pop_count),
144144
nthreads=NTHREADS,
@@ -154,7 +154,7 @@ def heappushpop_max_func(max_heap, pushpop_items):
154154
popped_item = heapq.heappushpop_max(max_heap, item)
155155
self.assertTrue(popped_item >= item)
156156

157-
self.run_concurrently(
157+
run_concurrently(
158158
worker_func=heappushpop_max_func,
159159
args=(max_heap, pushpop_items),
160160
nthreads=NTHREADS,
@@ -170,7 +170,7 @@ def heapreplace_max_func(max_heap, replace_items):
170170
for item in replace_items:
171171
heapq.heapreplace_max(max_heap, item)
172172

173-
self.run_concurrently(
173+
run_concurrently(
174174
worker_func=heapreplace_max_func,
175175
args=(max_heap, replace_items),
176176
nthreads=NTHREADS,
@@ -214,27 +214,6 @@ def create_random_list(a, b, size):
214214
"""
215215
return [randint(-a, b) for _ in range(size)]
216216

217-
def run_concurrently(self, worker_func, args, nthreads):
218-
"""
219-
Run the worker function concurrently in multiple threads.
220-
"""
221-
barrier = Barrier(nthreads)
222-
223-
def wrapper_func(*args):
224-
# Wait for all threads to reach this point before proceeding.
225-
barrier.wait()
226-
worker_func(*args)
227-
228-
with threading_helper.catch_threading_exception() as cm:
229-
workers = (
230-
Thread(target=wrapper_func, args=args) for _ in range(nthreads)
231-
)
232-
with threading_helper.start_threads(workers):
233-
pass
234-
235-
# Worker threads should not raise any exceptions
236-
self.assertIsNone(cm.exc_value)
237-
238217

239218
if __name__ == "__main__":
240219
unittest.main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make methods in :mod:`grp` thread-safe on the :term:`free threaded <free threading>` build.

Modules/grpmodule.c

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
132132
if (!_Py_Gid_Converter(id, &gid)) {
133133
return NULL;
134134
}
135-
#ifdef HAVE_GETGRGID_R
135+
#if defined(HAVE_GETGRGID_R)
136136
int status;
137137
Py_ssize_t bufsize;
138138
/* Note: 'grp' will be used via pointer 'p' on getgrgid_r success. */
@@ -167,6 +167,17 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
167167
}
168168

169169
Py_END_ALLOW_THREADS
170+
#elif defined(Py_GIL_DISABLED)
171+
static PyMutex getgrgid_mutex = {0};
172+
PyMutex_Lock(&getgrgid_mutex);
173+
// The getgrgid() function need not be thread-safe.
174+
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrgid.html
175+
p = getgrgid(gid);
176+
if (p == NULL) {
177+
// Unlock the mutex on error. The following error handling block will
178+
// handle the rest.
179+
PyMutex_Unlock(&getgrgid_mutex);
180+
}
170181
#else
171182
p = getgrgid(gid);
172183
#endif
@@ -183,8 +194,10 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
183194
return NULL;
184195
}
185196
retval = mkgrent(module, p);
186-
#ifdef HAVE_GETGRGID_R
197+
#if defined(HAVE_GETGRGID_R)
187198
PyMem_RawFree(buf);
199+
#elif defined(Py_GIL_DISABLED)
200+
PyMutex_Unlock(&getgrgid_mutex);
188201
#endif
189202
return retval;
190203
}
@@ -213,7 +226,7 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
213226
/* check for embedded null bytes */
214227
if (PyBytes_AsStringAndSize(bytes, &name_chars, NULL) == -1)
215228
goto out;
216-
#ifdef HAVE_GETGRNAM_R
229+
#if defined(HAVE_GETGRNAM_R)
217230
int status;
218231
Py_ssize_t bufsize;
219232
/* Note: 'grp' will be used via pointer 'p' on getgrnam_r success. */
@@ -248,6 +261,17 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
248261
}
249262

250263
Py_END_ALLOW_THREADS
264+
#elif defined(Py_GIL_DISABLED)
265+
static PyMutex getgrnam_mutex = {0};
266+
PyMutex_Lock(&getgrnam_mutex);
267+
// The getgrnam() function need not be thread-safe.
268+
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrnam.html
269+
p = getgrnam(name_chars);
270+
if (p == NULL) {
271+
// Unlock the mutex on error. The following error handling block will
272+
// handle the rest.
273+
PyMutex_Unlock(&getgrnam_mutex);
274+
}
251275
#else
252276
p = getgrnam(name_chars);
253277
#endif
@@ -261,6 +285,9 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
261285
goto out;
262286
}
263287
retval = mkgrent(module, p);
288+
#if !defined(HAVE_GETGRNAM_R) && defined(Py_GIL_DISABLED)
289+
PyMutex_Unlock(&getgrnam_mutex);
290+
#endif
264291
out:
265292
PyMem_RawFree(buf);
266293
Py_DECREF(bytes);

0 commit comments

Comments
 (0)