Skip to content

Commit ffa52c3

Browse files
committed
Add trainium parameter to @kubernetes decorator
1 parent 8e29f77 commit ffa52c3

File tree

7 files changed

+80
-2
lines changed

7 files changed

+80
-2
lines changed

metaflow/plugins/airflow/airflow.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,11 @@ def _to_job(self, node):
449449
# Don't set GPU limits if gpu isn't specified.
450450
if k8s_deco.attributes["gpu"] is not None
451451
},
452+
**{
453+
"aws.amazon.com/neuron": str(k8s_deco.attributes["trainium"])
454+
for k in [0]
455+
if k8s_deco.attributes.get("trainium") is not None
456+
},
452457
},
453458
)
454459

metaflow/plugins/argo/argo_workflows.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2628,6 +2628,7 @@ def _container_templates(self):
26282628
disk=str(resources["disk"]),
26292629
gpu=resources["gpu"],
26302630
gpu_vendor=str(resources["gpu_vendor"]),
2631+
trainium=resources.get("trainium"),
26312632
tolerations=resources["tolerations"],
26322633
use_tmpfs=use_tmpfs,
26332634
tmpfs_tempdir=tmpfs_tempdir,
@@ -2866,6 +2867,13 @@ def _container_templates(self):
28662867
for k in [0]
28672868
if resources["gpu"] is not None
28682869
},
2870+
**{
2871+
"aws.amazon.com/neuron": str(
2872+
resources["trainium"]
2873+
)
2874+
for k in [0]
2875+
if resources.get("trainium") is not None
2876+
},
28692877
},
28702878
),
28712879
# Configure secrets

metaflow/plugins/kubernetes/kubernetes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def create_jobset(
181181
cpu=None,
182182
gpu=None,
183183
gpu_vendor=None,
184+
trainium=None,
184185
disk=None,
185186
memory=None,
186187
use_tmpfs=None,
@@ -215,6 +216,7 @@ def create_jobset(
215216
disk=disk,
216217
gpu=gpu,
217218
gpu_vendor=gpu_vendor,
219+
trainium=trainium,
218220
timeout_in_seconds=run_time_limit,
219221
# Retries are handled by Metaflow runtime
220222
retries=0,
@@ -482,6 +484,7 @@ def create_job_object(
482484
cpu=None,
483485
gpu=None,
484486
gpu_vendor=None,
487+
trainium=None,
485488
disk=None,
486489
memory=None,
487490
use_tmpfs=None,
@@ -528,6 +531,7 @@ def create_job_object(
528531
disk=disk,
529532
gpu=gpu,
530533
gpu_vendor=gpu_vendor,
534+
trainium=trainium,
531535
timeout_in_seconds=run_time_limit,
532536
# Retries are handled by Metaflow runtime
533537
retries=0,

metaflow/plugins/kubernetes/kubernetes_cli.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def kubernetes():
8989
@click.option("--memory", help="Memory requirement for Kubernetes pod.")
9090
@click.option("--gpu", help="GPU requirement for Kubernetes pod.")
9191
@click.option("--gpu-vendor", help="GPU vendor requirement for Kubernetes pod.")
92+
@click.option("--trainium", help="AWS Trainium/Inferentia Neuron device requirement for Kubernetes pod.")
9293
@click.option("--run-id", help="Passed to the top-level 'step'.")
9394
@click.option("--task-id", help="Passed to the top-level 'step'.")
9495
@click.option("--input-paths", help="Passed to the top-level 'step'.")
@@ -178,6 +179,7 @@ def step(
178179
memory=None,
179180
gpu=None,
180181
gpu_vendor=None,
182+
trainium=None,
181183
use_tmpfs=None,
182184
tmpfs_tempdir=None,
183185
tmpfs_size=None,
@@ -323,6 +325,7 @@ def _sync_metadata():
323325
memory=memory,
324326
gpu=gpu,
325327
gpu_vendor=gpu_vendor,
328+
trainium=trainium,
326329
use_tmpfs=use_tmpfs,
327330
tmpfs_tempdir=tmpfs_tempdir,
328331
tmpfs_size=tmpfs_size,

metaflow/plugins/kubernetes/kubernetes_decorator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ class KubernetesDecorator(StepDecorator):
152152
"namespace": None,
153153
"gpu": None, # value of 0 implies that the scheduled node should not have GPUs
154154
"gpu_vendor": None,
155+
"trainium": None, # number of AWS Trainium/Inferentia Neuron devices
155156
"tolerations": None, # e.g., [{"key": "arch", "operator": "Equal", "value": "amd"},
156157
# {"key": "foo", "operator": "Equal", "value": "bar"}]
157158
"labels": None, # e.g. {"test-label": "value", "another-label":"value2"}
@@ -382,6 +383,17 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge
382383
max(float(my_val or 0), float(v or 0))
383384
)
384385

386+
# Validate mutually exclusive: gpu and trainium cannot both be set.
387+
if (
388+
self.attributes["trainium"] is not None
389+
and self.attributes["gpu"] is not None
390+
):
391+
raise KubernetesException(
392+
"Cannot specify both 'gpu' and 'trainium' for step *{step}*.".format(
393+
step=step
394+
)
395+
)
396+
385397
# Check GPU vendor.
386398
if self.attributes["gpu_vendor"].lower() not in ("amd", "nvidia"):
387399
raise KubernetesException(
@@ -412,6 +424,16 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge
412424
)
413425
)
414426

427+
if self.attributes["trainium"] is not None and not (
428+
isinstance(self.attributes["trainium"], (int, unicode, basestring))
429+
and float(self.attributes["trainium"]).is_integer()
430+
):
431+
raise KubernetesException(
432+
"Invalid trainium value *{}* for step *{step}*; it should be an integer".format(
433+
self.attributes["trainium"], step=step
434+
)
435+
)
436+
415437
if self.attributes["tmpfs_size"]:
416438
if not (
417439
isinstance(self.attributes["tmpfs_size"], (int, unicode, basestring))

metaflow/plugins/kubernetes/kubernetes_job.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,13 @@ def create_job_spec(self):
182182
# Don't set GPU limits if gpu isn't specified.
183183
if self._kwargs["gpu"] is not None
184184
},
185+
**{
186+
"aws.amazon.com/neuron": str(
187+
self._kwargs["trainium"]
188+
)
189+
for k in [0]
190+
if self._kwargs.get("trainium") is not None
191+
},
185192
},
186193
),
187194
volume_mounts=(
@@ -236,7 +243,18 @@ def create_job_spec(self):
236243
tolerations=[
237244
client.V1Toleration(**toleration)
238245
for toleration in self._kwargs.get("tolerations") or []
239-
],
246+
]
247+
+ (
248+
[
249+
client.V1Toleration(
250+
key="aws.amazon.com/neuron",
251+
operator="Exists",
252+
effect="NoSchedule",
253+
)
254+
]
255+
if self._kwargs.get("trainium") is not None
256+
else []
257+
),
240258
volumes=(
241259
[
242260
client.V1Volume(

metaflow/plugins/kubernetes/kubernetes_jobsets.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,13 @@ def dump(self):
679679
# Don't set GPU limits if gpu isn't specified.
680680
if self._kwargs["gpu"] is not None
681681
},
682+
**{
683+
"aws.amazon.com/neuron": str(
684+
self._kwargs["trainium"]
685+
)
686+
for k in [0]
687+
if self._kwargs.get("trainium") is not None
688+
},
682689
},
683690
),
684691
volume_mounts=(
@@ -740,7 +747,18 @@ def dump(self):
740747
client.V1Toleration(**toleration)
741748
for toleration in self._kwargs.get("tolerations")
742749
or []
743-
],
750+
]
751+
+ (
752+
[
753+
client.V1Toleration(
754+
key="aws.amazon.com/neuron",
755+
operator="Exists",
756+
effect="NoSchedule",
757+
)
758+
]
759+
if self._kwargs.get("trainium") is not None
760+
else []
761+
),
744762
volumes=(
745763
[
746764
client.V1Volume(

0 commit comments

Comments
 (0)