diff --git a/examples/activation/flake.nix b/examples/activation/flake.nix index cc1fe5c..8ddaf1b 100644 --- a/examples/activation/flake.nix +++ b/examples/activation/flake.nix @@ -14,5 +14,7 @@ kernel-builder.lib.genFlakeOutputs { path = ./.; rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + # Example of adding Python test dependencies directly in the flake + pythonTestDeps = [ "pytest-benchmark" ]; }; } diff --git a/flake.nix b/flake.nix index 3bb51d9..147493f 100644 --- a/flake.nix +++ b/flake.nix @@ -54,7 +54,13 @@ in builtins.toJSON (nixpkgs.lib.foldl' (acc: system: acc // buildVariants system) { } systems); genFlakeOutputs = - { path, rev }: + { + path, + rev, + pythonTestDeps ? [ ], + pythonDevDeps ? [ ], + customPythonPackages ? { }, + }: flake-utils.lib.eachSystem systems ( system: let @@ -71,10 +77,14 @@ devShells = build.torchDevShells { inherit path; rev = revUnderscored; + extraPythonPackages = pythonDevDeps; + inherit customPythonPackages; }; testShells = build.torchExtensionShells { inherit path; rev = revUnderscored; + extraPythonPackages = pythonTestDeps; + inherit customPythonPackages; }; }; packages = rec { @@ -182,5 +192,23 @@ ) // { inherit lib; + # Export the helper function + mkShellsWithExtraPackages = + { + pythonTestDeps ? [ ], + pythonDevDeps ? [ ], + }: + { + genFlakeOutputs = + { path, rev }: + lib.genFlakeOutputs { + inherit + path + rev + pythonTestDeps + pythonDevDeps + ; + }; + }; }; } diff --git a/lib/build.nix b/lib/build.nix index 730ff3d..1907383 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -195,21 +195,56 @@ rec { # Get a development shell with the extension in PYTHONPATH. Handy # for running tests. torchExtensionShells = - { path, rev }: + { + path, + rev, + extraPythonPackages ? [ ], + customPythonPackages ? { }, + }: let + buildConfig = readBuildConfig path; + shellForBuildSet = { path, rev }: buildSet: { name = torchBuildVersion buildSet; value = with buildSet.pkgs; + let + # Function to resolve nixpkgs packages or custom packages + resolvePythonPackage = + name: + if builtins.hasAttr name customPythonPackages then + python3.pkgs.buildPythonPackage { + pname = name; + version = "custom"; + src = customPythonPackages.${name}; + doCheck = false; + nativeBuildInputs = [ + python3.pkgs.setuptools + python3.pkgs.wheel + ]; + } + else if lib.hasPrefix "git+" name then + throw "ERROR: Git packages like '${name}' require flake inputs. Add as input: ${name} = { url = \"${lib.removePrefix "git+" name}\"; flake = false; } then use customPythonPackages = { ${name} = ${name}; }." + else if builtins.hasAttr name python3.pkgs then + python3.pkgs.${name} + else + throw "Python package '${name}' not found in nixpkgs or customPythonPackages"; + + # Resolve all packages + allPackages = map resolvePythonPackage extraPythonPackages; + in mkShell { buildInputs = [ (python3.withPackages ( - ps: with ps; [ + ps: + with ps; + [ buildSet.torch pytest ] + ++ allPackages )) ]; shellHook = '' @@ -217,12 +252,17 @@ rec { ''; }; }; - filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets; + filteredBuildSets = applicableBuildSets buildConfig buildSets; in builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) filteredBuildSets); torchDevShells = - { path, rev }: + { + path, + rev, + extraPythonPackages ? [ ], + customPythonPackages ? { }, + }: let shellForBuildSet = buildSet: @@ -239,7 +279,35 @@ rec { build2cmake kernel-abi-check ]; - buildInputs = with pkgs; [ python3.pkgs.pytest ]; + buildInputs = with pkgs; [ + (python3.withPackages ( + ps: + with ps; + [ + pytest + ] + ++ (map ( + name: + if builtins.hasAttr name customPythonPackages then + python3.pkgs.buildPythonPackage { + pname = name; + version = "custom"; + src = customPythonPackages.${name}; + doCheck = false; + nativeBuildInputs = [ + python3.pkgs.setuptools + python3.pkgs.wheel + ]; + } + else if lib.hasPrefix "git+" name then + throw "ERROR: Git packages like '${name}' require flake inputs. Add as input: ${name} = { url = \"${lib.removePrefix "git+" name}\"; flake = false; } then use customPythonPackages = { ${name} = ${name}; }." + else if builtins.hasAttr name python3.pkgs then + python3.pkgs.${name} + else + throw "Python package '${name}' not found in nixpkgs or customPythonPackages" + ) extraPythonPackages) + )) + ]; inputsFrom = [ (buildTorchExtension buildSet { inherit path rev; }) ]; env = lib.optionalAttrs rocmSupport { PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" buildSet.torch.rocmArchs;