Skip to content

gpu offload host code generation #142097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 28 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::Arc;
use std::{io, iter, slice};

use object::read::archive::ArchiveFile;
use rustc_abi::{Align, Size};
use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule, ThinShared};
use rustc_codegen_ssa::back::symbol_export;
use rustc_codegen_ssa::back::write::{CodegenContext, FatLtoInput};
Expand All @@ -22,14 +23,18 @@ use rustc_middle::middle::exported_symbols::{SymbolExportInfo, SymbolExportLevel
use rustc_session::config::{self, CrateType, Lto};
use tracing::{debug, info};

use llvm::Linkage::*;

use crate::back::write::{
self, CodegenDiagnosticsStage, DiagnosticHandlers, bitcode_section_name, save_temp_bitcode,
};
use crate::builder::{SBuilder, UNNAMED};
use crate::errors::{
DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro,
};
use crate::common::AsCCharPtr;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, build_string};
use crate::llvm::{self, build_string, Linkage};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};

/// We keep track of the computed LTO cache keys from the previous
Expand All @@ -39,11 +44,11 @@ const THIN_LTO_KEYS_INCR_COMP_FILE_NAME: &str = "thin-lto-past-keys.bin";
fn crate_type_allows_lto(crate_type: CrateType) -> bool {
match crate_type {
CrateType::Executable
| CrateType::Dylib
| CrateType::Staticlib
| CrateType::Cdylib
| CrateType::ProcMacro
| CrateType::Sdylib => true,
| CrateType::Dylib
| CrateType::Staticlib
| CrateType::Cdylib
| CrateType::ProcMacro
| CrateType::Sdylib => true,
CrateType::Rlib => false,
}
}
Expand Down Expand Up @@ -113,7 +118,7 @@ fn prepare_lto(
cgcx.prof.generic_activity("LLVM_lto_generate_symbols_below_threshold");
symbols_below_threshold
.extend(exported_symbols[&cnum].iter().filter_map(symbol_filter));
}
}

let archive_data = unsafe {
Mmap::map(std::fs::File::open(&path).expect("couldn't open rlib"))
Expand All @@ -127,7 +132,7 @@ fn prepare_lto(
std::str::from_utf8(c.name()).ok().map(|name| (name.trim(), c))
})
})
.filter(|&(name, _)| looks_like_rust_object_file(name));
.filter(|&(name, _)| looks_like_rust_object_file(name));
for (name, child) in obj_files {
info!("adding bitcode from {}", name);
match get_bitcode_slice_from_object_data(
Expand Down Expand Up @@ -295,7 +300,7 @@ fn fat_lto(
let cost = unsafe { llvm::LLVMRustModuleCost(module.module_llvm.llmod()) };
(cost, i)
})
.max();
.max();

// If we found a costliest module, we're good to go. Otherwise all our
// inputs were serialized which could happen in the case, for example, that
Expand Down Expand Up @@ -506,7 +511,7 @@ fn thin_lto(
symbols_below_threshold.as_ptr(),
symbols_below_threshold.len(),
)
.ok_or_else(|| write::llvm_err(dcx, LlvmError::PrepareThinLtoContext))?;
.ok_or_else(|| write::llvm_err(dcx, LlvmError::PrepareThinLtoContext))?;

let data = ThinData(data);

Expand Down Expand Up @@ -628,6 +633,7 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
llvm::set_rust_rules(true);
}


pub(crate) fn run_pass_manager(
cgcx: &CodegenContext<LlvmCodegenBackend>,
dcx: DiagCtxtHandle<'_>,
Expand All @@ -653,6 +659,7 @@ pub(crate) fn run_pass_manager(
// We then run the llvm_optimize function a second time, to optimize the code which we generated
// in the enzyme differentiation pass.
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
let enable_gpu = config.offload.contains(&config::Offload::Enable);
let stage = if thin {
write::AutodiffStage::PreAD
} else {
Expand All @@ -667,6 +674,13 @@ pub(crate) fn run_pass_manager(
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}

if cfg!(llvm_enzyme) && enable_gpu && !thin {
dbg!(&enable_gpu);
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);
}

if cfg!(llvm_enzyme) && enable_ad && !thin {
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
Expand Down Expand Up @@ -935,7 +949,7 @@ impl ThinLTOKeysMap {
.expect("Invalid ThinLTO module key");
(module_name_to_str(name).to_string(), key)
})
.collect();
.collect();
Self { keys }
}
}
Expand Down
179 changes: 119 additions & 60 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Deref;
use std::{iter, ptr};

pub(crate) mod autodiff;
pub(crate) mod gpu_offload;

use libc::{c_char, c_uint, size_t};
use rustc_abi as abi;
Expand Down Expand Up @@ -88,6 +89,7 @@ impl<'a, 'll> SBuilder<'a, 'll> {
};
call
}

}

impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
Expand Down Expand Up @@ -118,6 +120,63 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
}
bx
}

pub(crate) fn my_alloca2(&mut self, ty: &'ll Type, align: Align, name: &str) -> &'ll Value {
let val = unsafe {
let alloca = llvm::LLVMBuildAlloca(self.llbuilder, ty, UNNAMED);
llvm::LLVMSetAlignment(alloca, align.bytes() as c_uint);
// Cast to default addrspace if necessary
llvm::LLVMBuildPointerCast(self.llbuilder, alloca, self.cx.type_ptr(), UNNAMED)
};
if name != "" {
let name = std::ffi::CString::new(name).unwrap();
unsafe {llvm::set_value_name(val, &name.as_bytes())};
}
val
}

pub(crate) fn inbounds_gep(
&mut self,
ty: &'ll Type,
ptr: &'ll Value,
indices: &[&'ll Value],
) -> &'ll Value {
unsafe {
llvm::LLVMBuildGEPWithNoWrapFlags(
self.llbuilder,
ty,
ptr,
indices.as_ptr(),
indices.len() as c_uint,
UNNAMED,
GEPNoWrapFlags::InBounds,
)
}
}

pub(crate) fn store(
&mut self,
val: &'ll Value,
ptr: &'ll Value,
align: Align,
) -> &'ll Value {
debug!("Store {:?} -> {:?}", val, ptr);
assert_eq!(self.cx.type_kind(self.cx.val_ty(ptr)), TypeKind::Pointer);
unsafe {
let store = llvm::LLVMBuildStore(self.llbuilder, val, ptr);
llvm::LLVMSetAlignment(store, align.bytes() as c_uint);
store
}
}

pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value {
unsafe {
let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED);
llvm::LLVMSetAlignment(load, align.bytes() as c_uint);
load
}
}

}

/// Empty string, to be used where LLVM expects an instruction name, indicating
Expand Down Expand Up @@ -1261,7 +1320,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
unsafe {
llvm::LLVMBuildCleanupRet(self.llbuilder, funclet.cleanuppad(), unwind)
.expect("LLVM does not have support for cleanupret");
}
}
}

fn catch_pad(&mut self, parent: &'ll Value, args: &[&'ll Value]) -> Funclet<'ll> {
Expand Down Expand Up @@ -1631,14 +1690,14 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
debug!(
"type mismatch in function call of {:?}. \
Expected {:?} for param {}, got {:?}; injecting bitcast",
llfn, expected_ty, i, actual_ty
llfn, expected_ty, i, actual_ty
);
self.bitcast(actual_val, expected_ty)
} else {
actual_val
}
})
.collect();
.collect();

Cow::Owned(casted_args)
}
Expand Down Expand Up @@ -1791,48 +1850,48 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
let is_indirect_call = unsafe { llvm::LLVMRustIsNonGVFunctionPointerTy(llfn) };
if self.tcx.sess.is_sanitizer_cfi_enabled()
&& let Some(fn_abi) = fn_abi
&& is_indirect_call
{
if let Some(fn_attrs) = fn_attrs
&& fn_attrs.no_sanitize.contains(SanitizerSet::CFI)
&& is_indirect_call
{
return;
}
if let Some(fn_attrs) = fn_attrs
&& fn_attrs.no_sanitize.contains(SanitizerSet::CFI)
{
return;
}

let mut options = cfi::TypeIdOptions::empty();
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
options.insert(cfi::TypeIdOptions::GENERALIZE_POINTERS);
}
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
options.insert(cfi::TypeIdOptions::NORMALIZE_INTEGERS);
}
let mut options = cfi::TypeIdOptions::empty();
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
options.insert(cfi::TypeIdOptions::GENERALIZE_POINTERS);
}
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
options.insert(cfi::TypeIdOptions::NORMALIZE_INTEGERS);
}

let typeid = if let Some(instance) = instance {
cfi::typeid_for_instance(self.tcx, instance, options)
} else {
cfi::typeid_for_fnabi(self.tcx, fn_abi, options)
};
let typeid_metadata = self.cx.typeid_metadata(typeid).unwrap();
let dbg_loc = self.get_dbg_loc();

// Test whether the function pointer is associated with the type identifier.
let cond = self.type_test(llfn, typeid_metadata);
let bb_pass = self.append_sibling_block("type_test.pass");
let bb_fail = self.append_sibling_block("type_test.fail");
self.cond_br(cond, bb_pass, bb_fail);

self.switch_to_block(bb_fail);
if let Some(dbg_loc) = dbg_loc {
self.set_dbg_loc(dbg_loc);
}
self.abort();
self.unreachable();
let typeid = if let Some(instance) = instance {
cfi::typeid_for_instance(self.tcx, instance, options)
} else {
cfi::typeid_for_fnabi(self.tcx, fn_abi, options)
};
let typeid_metadata = self.cx.typeid_metadata(typeid).unwrap();
let dbg_loc = self.get_dbg_loc();

// Test whether the function pointer is associated with the type identifier.
let cond = self.type_test(llfn, typeid_metadata);
let bb_pass = self.append_sibling_block("type_test.pass");
let bb_fail = self.append_sibling_block("type_test.fail");
self.cond_br(cond, bb_pass, bb_fail);

self.switch_to_block(bb_fail);
if let Some(dbg_loc) = dbg_loc {
self.set_dbg_loc(dbg_loc);
}
self.abort();
self.unreachable();

self.switch_to_block(bb_pass);
if let Some(dbg_loc) = dbg_loc {
self.set_dbg_loc(dbg_loc);
self.switch_to_block(bb_pass);
if let Some(dbg_loc) = dbg_loc {
self.set_dbg_loc(dbg_loc);
}
}
}
}

// Emits KCFI operand bundles.
Expand All @@ -1847,31 +1906,31 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
let kcfi_bundle = if self.tcx.sess.is_sanitizer_kcfi_enabled()
&& let Some(fn_abi) = fn_abi
&& is_indirect_call
{
if let Some(fn_attrs) = fn_attrs
&& fn_attrs.no_sanitize.contains(SanitizerSet::KCFI)
{
return None;
}
if let Some(fn_attrs) = fn_attrs
&& fn_attrs.no_sanitize.contains(SanitizerSet::KCFI)
{
return None;
}

let mut options = kcfi::TypeIdOptions::empty();
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
options.insert(kcfi::TypeIdOptions::GENERALIZE_POINTERS);
}
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
options.insert(kcfi::TypeIdOptions::NORMALIZE_INTEGERS);
}
let mut options = kcfi::TypeIdOptions::empty();
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
options.insert(kcfi::TypeIdOptions::GENERALIZE_POINTERS);
}
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
options.insert(kcfi::TypeIdOptions::NORMALIZE_INTEGERS);
}

let kcfi_typeid = if let Some(instance) = instance {
kcfi::typeid_for_instance(self.tcx, instance, options)
let kcfi_typeid = if let Some(instance) = instance {
kcfi::typeid_for_instance(self.tcx, instance, options)
} else {
kcfi::typeid_for_fnabi(self.tcx, fn_abi, options)
};

Some(llvm::OperandBundleBox::new("kcfi", &[self.const_u32(kcfi_typeid)]))
} else {
kcfi::typeid_for_fnabi(self.tcx, fn_abi, options)
None
};

Some(llvm::OperandBundleBox::new("kcfi", &[self.const_u32(kcfi_typeid)]))
} else {
None
};
kcfi_bundle
}

Expand Down
Loading