Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions crates/uv-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5371,6 +5371,21 @@ pub struct ToolRunArgs {
#[arg(long)]
pub python_platform: Option<TargetTriple>,

/// The backend to use when fetching packages in the PyTorch ecosystem (e.g., `cpu`, `cu126`, or `auto`)
///
/// When set, uv will ignore the configured index URLs for packages in the PyTorch ecosystem,
/// and will instead use the defined backend.
///
/// For example, when set to `cpu`, uv will use the CPU-only PyTorch index; when set to `cu126`,
/// uv will use the PyTorch index for CUDA 12.6.
///
/// The `auto` mode will attempt to detect the appropriate PyTorch index based on the currently
/// installed CUDA drivers.
///
/// This option is in preview and may change in any future release.
#[arg(long, value_enum, env = EnvVars::UV_TORCH_BACKEND)]
pub torch_backend: Option<TorchMode>,

#[arg(long, hide = true)]
pub generate_shell_completion: Option<clap_complete_command::Shell>,
}
Expand Down Expand Up @@ -5547,6 +5562,21 @@ pub struct ToolInstallArgs {
/// `--python-platform` option is intended for advanced use cases.
#[arg(long)]
pub python_platform: Option<TargetTriple>,

/// The backend to use when fetching packages in the PyTorch ecosystem (e.g., `cpu`, `cu126`, or `auto`)
///
/// When set, uv will ignore the configured index URLs for packages in the PyTorch ecosystem,
/// and will instead use the defined backend.
///
/// For example, when set to `cpu`, uv will use the CPU-only PyTorch index; when set to `cu126`,
/// uv will use the PyTorch index for CUDA 12.6.
///
/// The `auto` mode will attempt to detect the appropriate PyTorch index based on the currently
/// installed CUDA drivers.
///
/// This option is in preview and may change in any future release.
#[arg(long, value_enum, env = EnvVars::UV_TORCH_BACKEND)]
pub torch_backend: Option<TorchMode>,
}

#[derive(Args)]
Expand Down
1 change: 1 addition & 0 deletions crates/uv-cli/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ pub fn resolver_options(
exclude_newer_package.unwrap_or_default(),
),
link_mode,
torch_backend: None,
no_build: flag(no_build, build, "build"),
no_build_package: Some(no_build_package),
no_binary: flag(no_binary, binary, "binary"),
Expand Down
14 changes: 9 additions & 5 deletions crates/uv-settings/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ pub struct ResolverOptions {
pub config_settings_package: Option<PackageConfigSettings>,
pub exclude_newer: ExcludeNewer,
pub link_mode: Option<LinkMode>,
pub torch_backend: Option<TorchMode>,
pub upgrade: Option<Upgrade>,
pub build_isolation: Option<BuildIsolation>,
pub no_build: Option<bool>,
Expand Down Expand Up @@ -404,6 +405,7 @@ pub struct ResolverInstallerOptions {
pub exclude_newer: Option<ExcludeNewerValue>,
pub exclude_newer_package: Option<ExcludeNewerPackage>,
pub link_mode: Option<LinkMode>,
pub torch_backend: Option<TorchMode>,
pub compile_bytecode: Option<bool>,
pub no_sources: Option<bool>,
pub upgrade: Option<Upgrade>,
Expand All @@ -412,7 +414,6 @@ pub struct ResolverInstallerOptions {
pub no_build_package: Option<Vec<PackageName>>,
pub no_binary: Option<bool>,
pub no_binary_package: Option<Vec<PackageName>>,
pub torch_backend: Option<TorchMode>,
}

impl From<ResolverInstallerSchema> for ResolverInstallerOptions {
Expand All @@ -438,6 +439,7 @@ impl From<ResolverInstallerSchema> for ResolverInstallerOptions {
exclude_newer,
exclude_newer_package,
link_mode,
torch_backend,
compile_bytecode,
no_sources,
upgrade,
Expand All @@ -448,7 +450,6 @@ impl From<ResolverInstallerSchema> for ResolverInstallerOptions {
no_build_package,
no_binary,
no_binary_package,
torch_backend,
} = value;
Self {
index,
Expand All @@ -473,6 +474,7 @@ impl From<ResolverInstallerSchema> for ResolverInstallerOptions {
exclude_newer,
exclude_newer_package,
link_mode,
torch_backend,
compile_bytecode,
no_sources,
upgrade: Upgrade::from_args(
Expand All @@ -488,7 +490,6 @@ impl From<ResolverInstallerSchema> for ResolverInstallerOptions {
no_build_package,
no_binary,
no_binary_package,
torch_backend,
}
}
}
Expand Down Expand Up @@ -1925,6 +1926,7 @@ impl From<ResolverInstallerSchema> for ResolverOptions {
extra_build_dependencies: value.extra_build_dependencies,
extra_build_variables: value.extra_build_variables,
no_sources: value.no_sources,
torch_backend: value.torch_backend,
}
}
}
Expand Down Expand Up @@ -2004,6 +2006,7 @@ pub struct ToolOptions {
pub no_build_package: Option<Vec<PackageName>>,
pub no_binary: Option<bool>,
pub no_binary_package: Option<Vec<PackageName>>,
pub torch_backend: Option<TorchMode>,
}

impl From<ResolverInstallerOptions> for ToolOptions {
Expand Down Expand Up @@ -2034,6 +2037,7 @@ impl From<ResolverInstallerOptions> for ToolOptions {
no_build_package: value.no_build_package,
no_binary: value.no_binary,
no_binary_package: value.no_binary_package,
torch_backend: value.torch_backend,
}
}
}
Expand Down Expand Up @@ -2068,7 +2072,7 @@ impl From<ToolOptions> for ResolverInstallerOptions {
no_build_package: value.no_build_package,
no_binary: value.no_binary,
no_binary_package: value.no_binary_package,
torch_backend: None,
torch_backend: value.torch_backend,
}
}
}
Expand Down Expand Up @@ -2150,7 +2154,7 @@ pub struct OptionsWire {
// `crates/uv-workspace/src/pyproject.rs`. The documentation lives on that struct.
// They're respected in both `pyproject.toml` and `uv.toml` files.
override_dependencies: Option<Vec<Requirement<VerbatimParsedUrl>>>,
exclude_dependencies: Option<Vec<uv_normalize::PackageName>>,
exclude_dependencies: Option<Vec<PackageName>>,
constraint_dependencies: Option<Vec<Requirement<VerbatimParsedUrl>>>,
build_constraint_dependencies: Option<Vec<Requirement<VerbatimParsedUrl>>>,
environments: Option<SupportedEnvironments>,
Expand Down
1 change: 1 addition & 0 deletions crates/uv/src/commands/build_frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ async fn build_impl(
upgrade: _,
build_options,
sources,
torch_backend: _,
} = settings;

// Determine the source to build.
Expand Down
1 change: 1 addition & 0 deletions crates/uv/src/commands/project/lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ async fn do_lock(
upgrade,
build_options,
sources,
torch_backend: _,
} = settings;

if !preview.is_enabled(PreviewFeatures::EXTRA_BUILD_DEPENDENCIES)
Expand Down
70 changes: 70 additions & 0 deletions crates/uv/src/commands/project/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use uv_resolver::{
use uv_scripts::Pep723ItemRef;
use uv_settings::PythonInstallMirrors;
use uv_static::EnvVars;
use uv_torch::{TorchSource, TorchStrategy};
use uv_types::{BuildIsolation, EmptyInstalledPackages, HashStrategy};
use uv_virtualenv::remove_virtualenv;
use uv_warnings::{warn_user, warn_user_once};
Expand Down Expand Up @@ -278,6 +279,9 @@ pub(crate) enum ProjectError {
#[error(transparent)]
RetryParsing(#[from] uv_client::RetryParsingError),

#[error(transparent)]
Accelerator(#[from] uv_torch::AcceleratorError),

#[error(transparent)]
Anyhow(#[from] anyhow::Error),
}
Expand Down Expand Up @@ -1723,6 +1727,7 @@ pub(crate) async fn resolve_names(
prerelease: _,
resolution: _,
sources,
torch_backend,
upgrade: _,
},
compile_bytecode: _,
Expand All @@ -1731,10 +1736,27 @@ pub(crate) async fn resolve_names(

let client_builder = client_builder.clone().keyring(*keyring_provider);

// Determine the PyTorch backend.
let torch_backend = torch_backend
.map(|mode| {
let source = if uv_auth::PyxTokenStore::from_settings()
.is_ok_and(|store| store.has_credentials())
{
TorchSource::Pyx
} else {
TorchSource::default()
};
TorchStrategy::from_mode(mode, source, interpreter.platform().os())
})
.transpose()
.ok()
.flatten();

// Initialize the registry client.
let client = RegistryClientBuilder::new(client_builder, cache.clone())
.index_locations(index_locations.clone())
.index_strategy(*index_strategy)
.torch_backend(torch_backend.clone())
.markers(interpreter.markers())
.platform(interpreter.platform())
.build();
Expand Down Expand Up @@ -1880,6 +1902,7 @@ pub(crate) async fn resolve_environment(
upgrade: _,
build_options,
sources,
torch_backend,
} = settings;

// Respect all requirements from the provided sources.
Expand All @@ -1900,10 +1923,33 @@ pub(crate) async fn resolve_environment(
let marker_env = pip::resolution_markers(None, python_platform, interpreter);
let python_requirement = PythonRequirement::from_interpreter(interpreter);

// Determine the PyTorch backend.
let torch_backend = torch_backend
.map(|mode| {
let source = if uv_auth::PyxTokenStore::from_settings()
.is_ok_and(|store| store.has_credentials())
{
TorchSource::Pyx
} else {
TorchSource::default()
};
TorchStrategy::from_mode(
mode,
source,
python_platform
.map(|t| t.platform())
.as_ref()
.unwrap_or(interpreter.platform())
.os(),
)
})
.transpose()?;

// Initialize the registry client.
let client = RegistryClientBuilder::new(client_builder, cache.clone())
.index_locations(index_locations.clone())
.index_strategy(*index_strategy)
.torch_backend(torch_backend.clone())
.markers(interpreter.markers())
.platform(interpreter.platform())
.build();
Expand Down Expand Up @@ -2232,6 +2278,7 @@ pub(crate) async fn update_environment(
prerelease,
resolution,
sources,
torch_backend,
upgrade,
},
compile_bytecode,
Expand Down Expand Up @@ -2302,10 +2349,33 @@ pub(crate) async fn update_environment(
}
}

// Determine the PyTorch backend.
let torch_backend = torch_backend
.map(|mode| {
let source = if uv_auth::PyxTokenStore::from_settings()
.is_ok_and(|store| store.has_credentials())
{
TorchSource::Pyx
} else {
TorchSource::default()
};
TorchStrategy::from_mode(
mode,
source,
python_platform
.map(|t| t.platform())
.as_ref()
.unwrap_or(interpreter.platform())
.os(),
)
})
.transpose()?;

// Initialize the registry client.
let client = RegistryClientBuilder::new(client_builder, cache.clone())
.index_locations(index_locations.clone())
.index_strategy(*index_strategy)
.torch_backend(torch_backend.clone())
.markers(interpreter.markers())
.platform(interpreter.platform())
.build();
Expand Down
1 change: 1 addition & 0 deletions crates/uv/src/commands/project/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ pub(super) async fn do_sync(
prerelease: PrereleaseMode::default(),
resolution: ResolutionMode::default(),
sources,
torch_backend: None,
upgrade: Upgrade::default(),
};
script_extra_build_requires(
Expand Down
1 change: 1 addition & 0 deletions crates/uv/src/commands/project/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ pub(crate) async fn tree(
upgrade: _,
build_options: _,
sources: _,
torch_backend: _,
} = &settings;

let capabilities = IndexCapabilities::default();
Expand Down
8 changes: 7 additions & 1 deletion crates/uv/src/commands/tool/install.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use uv_python::{
use uv_requirements::{RequirementsSource, RequirementsSpecification};
use uv_settings::{PythonInstallMirrors, ResolverInstallerOptions, ToolOptions};
use uv_tool::InstalledTools;
use uv_warnings::warn_user;
use uv_warnings::{warn_user, warn_user_once};
use uv_workspace::WorkspaceCache;

use crate::commands::ExitStatus;
Expand Down Expand Up @@ -76,6 +76,12 @@ pub(crate) async fn install(
printer: Printer,
preview: Preview,
) -> Result<ExitStatus> {
if settings.resolver.torch_backend.is_some() {
warn_user_once!(
"The `--torch-backend` option is experimental and may change without warning."
);
}

let reporter = PythonDownloadReporter::single(printer);

let python_request = python.as_deref().map(PythonRequest::parse);
Expand Down
6 changes: 6 additions & 0 deletions crates/uv/src/commands/tool/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ pub(crate) async fn run(
.is_some_and(|ext| ext.eq_ignore_ascii_case("py") || ext.eq_ignore_ascii_case("pyw"))
}

if settings.resolver.torch_backend.is_some() {
warn_user_once!(
"The `--torch-backend` option is experimental and may change without warning."
);
}

// Read from the `.env` file, if necessary.
if !no_env_file {
for env_file_path in env_file.iter().rev().map(PathBuf::as_path) {
Expand Down
Loading
Loading