Skip to content

Support custom python libraries in dev shell nixland #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions examples/activation/flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@
kernel-builder.lib.genFlakeOutputs {
path = ./.;
rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
# Example of adding Python test dependencies directly in the flake
pythonTestDeps = [ "pytest-benchmark" ];
};
}
30 changes: 29 additions & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@
in
builtins.toJSON (nixpkgs.lib.foldl' (acc: system: acc // buildVariants system) { } systems);
genFlakeOutputs =
{ path, rev }:
{
path,
rev,
pythonTestDeps ? [ ],
pythonDevDeps ? [ ],
customPythonPackages ? { },
}:
flake-utils.lib.eachSystem systems (
system:
let
Expand All @@ -71,10 +77,14 @@
devShells = build.torchDevShells {
inherit path;
rev = revUnderscored;
extraPythonPackages = pythonDevDeps;
inherit customPythonPackages;
};
testShells = build.torchExtensionShells {
inherit path;
rev = revUnderscored;
extraPythonPackages = pythonTestDeps;
inherit customPythonPackages;
};
};
packages = rec {
Expand Down Expand Up @@ -182,5 +192,23 @@
)
// {
inherit lib;
# Export the helper function
mkShellsWithExtraPackages =
{
pythonTestDeps ? [ ],
pythonDevDeps ? [ ],
}:
{
genFlakeOutputs =
{ path, rev }:
lib.genFlakeOutputs {
inherit
path
rev
pythonTestDeps
pythonDevDeps
;
};
};
};
}
78 changes: 73 additions & 5 deletions lib/build.nix
Original file line number Diff line number Diff line change
Expand Up @@ -195,34 +195,74 @@ rec {
# Get a development shell with the extension in PYTHONPATH. Handy
# for running tests.
torchExtensionShells =
{ path, rev }:
{
path,
rev,
extraPythonPackages ? [ ],
customPythonPackages ? { },
}:
let
buildConfig = readBuildConfig path;

shellForBuildSet =
{ path, rev }:
buildSet: {
name = torchBuildVersion buildSet;
value =
with buildSet.pkgs;
let
# Function to resolve nixpkgs packages or custom packages
resolvePythonPackage =
name:
if builtins.hasAttr name customPythonPackages then
python3.pkgs.buildPythonPackage {
pname = name;
version = "custom";
src = customPythonPackages.${name};
doCheck = false;
nativeBuildInputs = [
python3.pkgs.setuptools
python3.pkgs.wheel
];
}
else if lib.hasPrefix "git+" name then
throw "ERROR: Git packages like '${name}' require flake inputs. Add as input: ${name} = { url = \"${lib.removePrefix "git+" name}\"; flake = false; } then use customPythonPackages = { ${name} = ${name}; }."
else if builtins.hasAttr name python3.pkgs then
python3.pkgs.${name}
else
throw "Python package '${name}' not found in nixpkgs or customPythonPackages";

# Resolve all packages
allPackages = map resolvePythonPackage extraPythonPackages;
in
mkShell {
buildInputs = [
(python3.withPackages (
ps: with ps; [
ps:
with ps;
[
buildSet.torch
pytest
]
++ allPackages
))
];
shellHook = ''
export PYTHONPATH=${buildTorchExtension buildSet { inherit path rev; }}
'';
};
};
filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
filteredBuildSets = applicableBuildSets buildConfig buildSets;
in
builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) filteredBuildSets);

torchDevShells =
{ path, rev }:
{
path,
rev,
extraPythonPackages ? [ ],
customPythonPackages ? { },
}:
let
shellForBuildSet =
buildSet:
Expand All @@ -239,7 +279,35 @@ rec {
build2cmake
kernel-abi-check
];
buildInputs = with pkgs; [ python3.pkgs.pytest ];
buildInputs = with pkgs; [
(python3.withPackages (
ps:
with ps;
[
pytest
]
++ (map (
name:
if builtins.hasAttr name customPythonPackages then
python3.pkgs.buildPythonPackage {
pname = name;
version = "custom";
src = customPythonPackages.${name};
doCheck = false;
nativeBuildInputs = [
python3.pkgs.setuptools
python3.pkgs.wheel
];
}
else if lib.hasPrefix "git+" name then
throw "ERROR: Git packages like '${name}' require flake inputs. Add as input: ${name} = { url = \"${lib.removePrefix "git+" name}\"; flake = false; } then use customPythonPackages = { ${name} = ${name}; }."
else if builtins.hasAttr name python3.pkgs then
python3.pkgs.${name}
else
throw "Python package '${name}' not found in nixpkgs or customPythonPackages"
) extraPythonPackages)
))
];
inputsFrom = [ (buildTorchExtension buildSet { inherit path rev; }) ];
env = lib.optionalAttrs rocmSupport {
PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" buildSet.torch.rocmArchs;
Expand Down