Skip to content

Commit 296397d

Browse files
authored
[dask] raise more informative error for duplicates in 'machines' (fixes #4057) (#4059)
* [dask] raise more informative error for duplicates in 'machines' * uncomment * avoid test failure * Revert "avoid test failure" This reverts commit 9442bdf.
1 parent b75a43a commit 296397d

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

python-package/lightgbm/dask.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
153153
Dictionary where keys are work addresses in the form expected by Dask and values are a port for LightGBM to use.
154154
"""
155155
machine_addresses = machines.split(",")
156+
157+
if len(set(machine_addresses)) != len(machine_addresses):
158+
raise ValueError(f"Found duplicates in 'machines' ({machines}). Each entry in 'machines' must be a unique IP-port combination.")
159+
156160
machine_to_port = defaultdict(set)
157161
for address in machine_addresses:
158162
host, port = address.split(":")

tests/python_package_test/test_dask.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,7 @@ def test_machines_should_be_used_if_provided(task, output):
11161116
client.rebalance()
11171117

11181118
n_workers = len(client.scheduler_info()['workers'])
1119+
assert n_workers > 1
11191120
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
11201121
dask_model = dask_model_factory(
11211122
n_estimators=5,
@@ -1134,6 +1135,17 @@ def test_machines_should_be_used_if_provided(task, output):
11341135
s.bind(('127.0.0.1', open_ports[0]))
11351136
dask_model.fit(dX, dy, group=dg)
11361137

1138+
# an informative error should be raised if "machines" has duplicates
1139+
one_open_port = lgb.dask._find_random_open_port()
1140+
dask_model.set_params(
1141+
machines=",".join([
1142+
"127.0.0.1:" + str(one_open_port)
1143+
for _ in range(n_workers)
1144+
])
1145+
)
1146+
with pytest.raises(ValueError, match="Found duplicates in 'machines'"):
1147+
dask_model.fit(dX, dy, group=dg)
1148+
11371149

11381150
@pytest.mark.parametrize(
11391151
"classes",

0 commit comments

Comments
 (0)