Skip to content

Commit 98a85a8

Browse files
authored
[dask] Drop aliases of core network parameters (#3843)
* Update dask.py * Update basic.py * hotfix pop
1 parent b7ccdaf commit 98a85a8

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

python-package/lightgbm/basic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,10 @@ class _ConfigAliases:
298298
"local_listen_port": {"local_listen_port",
299299
"local_port",
300300
"port"},
301+
"machine_list_filename": {"machine_list_filename",
302+
"machine_list_file",
303+
"machine_list",
304+
"mlist"},
301305
"machines": {"machines",
302306
"workers",
303307
"nodes"},
@@ -315,6 +319,8 @@ class _ConfigAliases:
315319
"num_rounds",
316320
"num_boost_round",
317321
"n_estimators"},
322+
"num_machines": {"num_machines",
323+
"num_machine"},
318324
"num_threads": {"num_threads",
319325
"num_thread",
320326
"nthread",

python-package/lightgbm/dask.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
230230
return part # trigger error locally
231231

232232
# Find locations of all parts and map them to particular Dask workers
233-
key_to_part_dict = dict([(part.key, part) for part in parts])
233+
key_to_part_dict = {part.key: part for part in parts}
234234
who_has = client.who_has(parts)
235235
worker_map = defaultdict(list)
236236
for key, workers in who_has.items():
@@ -280,6 +280,18 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
280280
for num_thread_alias in _ConfigAliases.get('num_threads'):
281281
params.pop(num_thread_alias, None)
282282

283+
# machines is constructed manually, so remove it and all aliases of it from params
284+
for machine_alias in _ConfigAliases.get('machines'):
285+
params.pop(machine_alias, None)
286+
287+
# machines is constructed manually, so remove machine_list_filename and all aliases of it from params
288+
for machine_list_filename_alias in _ConfigAliases.get('machine_list_filename'):
289+
params.pop(machine_list_filename_alias, None)
290+
291+
# machines is constructed manually, so remove num_machines and all aliases of it from params
292+
for num_machine_alias in _ConfigAliases.get('num_machines'):
293+
params.pop(num_machine_alias, None)
294+
283295
# Tell each worker to train on the parts that it has locally
284296
futures_classifiers = [
285297
client.submit(

0 commit comments

Comments
 (0)