@@ -152,6 +152,7 @@ class KubernetesDecorator(StepDecorator):
152152 "namespace" : None ,
153153 "gpu" : None , # value of 0 implies that the scheduled node should not have GPUs
154154 "gpu_vendor" : None ,
155+ "trainium" : None , # number of AWS Trainium/Inferentia Neuron devices
155156 "tolerations" : None , # e.g., [{"key": "arch", "operator": "Equal", "value": "amd"},
156157 # {"key": "foo", "operator": "Equal", "value": "bar"}]
157158 "labels" : None , # e.g. {"test-label": "value", "another-label":"value2"}
@@ -382,6 +383,17 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge
382383 max (float (my_val or 0 ), float (v or 0 ))
383384 )
384385
386+ # gpu and trainium are mutually exclusive; reject steps that set both (NOTE(review): gpu=0 also counts as "set" here - confirm intended).
387+ if (
388+ self .attributes ["trainium" ] is not None
389+ and self .attributes ["gpu" ] is not None
390+ ):
391+ raise KubernetesException (
392+ "Cannot specify both 'gpu' and 'trainium' for step *{step}*." .format (
393+ step = step
394+ )
395+ )
396+
385397 # Check GPU vendor.
386398 if self .attributes ["gpu_vendor" ].lower () not in ("amd" , "nvidia" ):
387399 raise KubernetesException (
@@ -412,6 +424,16 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge
412424 )
413425 )
414426
427+ if self .attributes ["trainium" ] is not None and not (
428+ isinstance (self .attributes ["trainium" ], (int , unicode , basestring ))
429+ and float (self .attributes ["trainium" ]).is_integer ()
430+ ):
431+ raise KubernetesException (
432+ "Invalid trainium value *{}* for step *{step}*; it should be an integer" .format (
433+ self .attributes ["trainium" ], step = step
434+ )
435+ )
436+
415437 if self .attributes ["tmpfs_size" ]:
416438 if not (
417439 isinstance (self .attributes ["tmpfs_size" ], (int , unicode , basestring ))
0 commit comments