
Add trainium parameter to @kubernetes decorator #3086

Draft

emattia wants to merge 3 commits into Netflix:master from emattia:trn-k8s

Conversation

@emattia (Contributor) commented Apr 8, 2026

PR Type

  • Bug fix
  • New feature
  • Core Runtime change (higher bar -- see CONTRIBUTING.md)
  • Docs / tooling
  • Refactoring

Summary

Users will be able to specify @kubernetes(trainium=X) in the same way the @batch UX works today. Note: this PR should be paired with support for @kubernetes(efa=X), which will land in a separate PR.

@emattia emattia marked this pull request as draft April 8, 2026 00:32
greptile-apps bot (Contributor) commented Apr 8, 2026

Greptile Summary

This PR adds a trainium parameter to @kubernetes, mirroring the existing @batch UX, so users can request AWS Trainium/Inferentia Neuron devices via @kubernetes(trainium=X). The core implementation in kubernetes_job.py and kubernetes_jobsets.py is correct — both the aws.amazon.com/neuron resource limit and the automatic aws.amazon.com/neuron:NoSchedule toleration are injected. However, the same toleration is missing from two execution paths:

  • Argo Workflows (non-parallel path): argo_workflows.py adds the Neuron resource limit but calls .tolerations(resources.get("tolerations")) without appending the Neuron toleration, so pods will stay Pending on tainted Neuron nodes.
  • Airflow: airflow.py adds the Neuron resource limit to the resources dict but never passes a matching toleration to KubernetesPodOperator, causing the same scheduling failure.
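The injection that the correct paths perform can be sketched in isolation. Below is a minimal, hypothetical stand-in for the logic described above: the function and field names are illustrative (plain dicts, not Metaflow's actual kubernetes_job.py internals), but the resource key and toleration follow the Kubernetes conventions the review cites.

```python
# Illustrative sketch of the Neuron resource + toleration injection.
# Names here are hypothetical; this is not Metaflow's actual code.
NEURON_TOLERATION = {
    "key": "aws.amazon.com/neuron",
    "operator": "Exists",
    "effect": "NoSchedule",
}

def build_pod_spec(trainium=None, tolerations=None):
    """Return a minimal pod-spec fragment with Neuron settings injected."""
    spec = {"resources": {"limits": {}}, "tolerations": list(tolerations or [])}
    if trainium is not None:
        # Request the Neuron devices ...
        spec["resources"]["limits"]["aws.amazon.com/neuron"] = str(trainium)
        # ... and tolerate the taint that Trainium/Inferentia nodes carry;
        # without this the pod stays Pending on tainted nodes.
        spec["tolerations"].append(NEURON_TOLERATION)
    return spec
```

Both missing pieces must travel together: the limit alone (as in the two broken paths) produces a pod the scheduler can never place on a tainted Neuron node.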

Confidence Score: 4/5

Two execution paths (Argo Workflows non-parallel and Airflow) will silently fail to schedule on Neuron nodes because the required toleration is not injected; the core Kubernetes job/jobset paths work correctly.

Two confirmed P1 bugs mean pods will not schedule on Trainium nodes when using Argo Workflows non-parallel path or Airflow. The core kubernetes_job and kubernetes_jobsets paths are correct. Score is 4 rather than lower because the feature works for the primary direct-Kubernetes execution path.

metaflow/plugins/argo/argo_workflows.py and metaflow/plugins/airflow/airflow.py need the automatic Neuron toleration added to their pod specs.
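For reference, a step pod that schedules correctly on a tainted Neuron node needs both pieces in its spec. The YAML fragment below is illustrative (example values, standard Kubernetes field names), not output taken from this PR:

```yaml
# Illustrative pod-spec fragment; field names follow the Kubernetes API.
spec:
  containers:
    - name: step
      resources:
        limits:
          aws.amazon.com/neuron: "2"   # e.g. trainium=2
  tolerations:
    - key: aws.amazon.com/neuron
      operator: Exists
      effect: NoSchedule
```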

Vulnerabilities

No security concerns identified. The change adds a resource-limit annotation and a toleration to Kubernetes pod specs; there is no user-controlled input reaching a sensitive API without validation.

Important Files Changed

| Filename | Overview |
| --- | --- |
| metaflow/plugins/kubernetes/kubernetes_decorator.py | Adds trainium as a new decorator attribute with a mutual-exclusion check against gpu, integer validation, and CLI forwarding. Correct, though the validator allows trainium=0, which would spuriously add a Neuron toleration. |
| metaflow/plugins/kubernetes/kubernetes_job.py | Correctly adds the aws.amazon.com/neuron resource limit and automatically injects the aws.amazon.com/neuron:NoSchedule toleration when trainium is set. |
| metaflow/plugins/kubernetes/kubernetes_jobsets.py | Correctly adds the aws.amazon.com/neuron resource limit and auto-injects the Neuron toleration for the parallel JobSet path, consistent with kubernetes_job.py. |
| metaflow/plugins/argo/argo_workflows.py | Adds the Neuron resource limit to the non-parallel pod spec and threads trainium through to the JobSet path, but the non-parallel path omits the required aws.amazon.com/neuron:NoSchedule toleration, so pods will fail to schedule on Neuron nodes. |
| metaflow/plugins/airflow/airflow.py | Adds the Neuron resource limit to the resources dict but never adds a matching aws.amazon.com/neuron:NoSchedule toleration to the Airflow operator args, so pods will remain Pending on tainted Neuron nodes. |
| metaflow/plugins/kubernetes/kubernetes_cli.py | Adds the --trainium CLI option and threads it through to the step command correctly. |
| metaflow/plugins/kubernetes/kubernetes.py | Adds the trainium parameter to both create_job and create_jobset and forwards it to the job/jobset constructors correctly. |
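The trainium=0 edge case flagged for kubernetes_decorator.py could be closed with a stricter check. The sketch below is hypothetical (function name and error messages are illustrative, not Metaflow's actual validator): it accepts only positive integers, so a zero value can never trigger a spurious Neuron toleration.

```python
def validate_trainium(value):
    """Validate a trainium device count (illustrative sketch).

    Accepts None (attribute unset) or a positive integer; rejects 0 so
    that no Neuron toleration is injected for a zero-device request.
    """
    if value is None:
        return None
    try:
        count = int(value)
    except (TypeError, ValueError):
        raise ValueError("trainium must be an integer, got %r" % (value,))
    if count <= 0:
        raise ValueError("trainium must be a positive integer, got %d" % count)
    return count
```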

Comments Outside Diff (1)

  1. metaflow/plugins/argo/argo_workflows.py, line 2796

    P1 Missing automatic Neuron toleration in non-parallel Argo Workflows path

    The non-parallel (non-JobSet) Argo Workflows pod spec adds the aws.amazon.com/neuron resource limit (line ~2871) but does not inject the corresponding aws.amazon.com/neuron:NoSchedule toleration. Trainium/Inferentia nodes carry that taint by default, so any pod that reaches this code path with trainium=N will remain in Pending state — it will never be scheduled.

    The JobSet path correctly auto-injects the toleration (via kubernetes_jobsets.py), and kubernetes_job.py does the same. The fix is to extend the toleration list here analogously:

    .tolerations(
        (resources.get("tolerations") or [])
        + (
            [{"key": "aws.amazon.com/neuron", "operator": "Exists", "effect": "NoSchedule"}]
            if resources.get("trainium") is not None
            else []
        )
    )
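The suggested merge expression can be exercised on its own. In the stdlib sketch below, resources is a plain dict standing in for the real resources mapping and the helper name is hypothetical; the body is the same expression as the fix above.

```python
def merged_tolerations(resources):
    """Append the Neuron toleration to any user-supplied tolerations
    when trainium is requested (resources is a plain-dict stand-in)."""
    return (resources.get("tolerations") or []) + (
        [{"key": "aws.amazon.com/neuron", "operator": "Exists", "effect": "NoSchedule"}]
        if resources.get("trainium") is not None
        else []
    )
```

With trainium set, the Neuron toleration is appended after any user-supplied tolerations; with it unset, the list passes through unchanged, so existing behavior is preserved.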

Reviews (1): Last reviewed commit: "Add trainium parameter to @kubernetes de..."

codecov bot commented Apr 8, 2026

Welcome to Codecov 🎉

Once you merge this PR into your default branch, you're all set! Codecov will compare coverage reports and display results in all future pull requests.

Thanks for integrating Codecov - We've got you covered ☂️
