Description
For array creation functions, device support will be needed, unless we intend to only support operations on the default device. Otherwise, what will happen when a function creates a new array (e.g. creating the output array with `empty()` before filling it with the results of some computation) is that the new array will be placed on the default device, and an exception will be raised if an input array is on a non-default device.
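For example, a pattern like the following (with `xp` as a hypothetical array-API namespace) would break for arrays living on a non-default device if creation functions have no notion of devices:

```python
def square(x):
    # empty() has no way of knowing where x lives, so `out` is allocated on
    # the default device; filling it from a non-default-device x then fails
    # (or triggers an implicit copy).
    out = xp.empty(x.shape, dtype=x.dtype)
    out[...] = x * x
    return out
```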
We discussed this in the Aug 27th call, and the preference was to do something PyTorch-like, perhaps a simplified version to start with (we may not need the context manager part), as the most robust option. Summary of some points that were made:
- TensorFlow has an issue where its `.shape` attribute is also a tensor, and that interacts badly with its context manager approach to specifying devices - because metadata like `.shape` typically should live on the host, not on an accelerator.
- PyTorch uses a mix of a default device, a context manager, and `device=` keywords (see the sketch after this list).
- JAX also has a context manager-like approach; it has a global default that can be set, and then `pmap`s can be decorated to override that. The difference with other libraries that use a context is that JAX is fairly (too) liberal about implicit device copies.
- It'd be best for operations where data is not all on the same device to raise an exception. Implicit device transfers make it very hard to get a good performance story.
- Propagating device assignments through operations is important.
- Control over where operations get executed is important; trying to be fully implicit doesn't scale to situations with multiple GPUs.
- It may not make sense to add syntax for device support for libraries that only support a single device (i.e., CPU).
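For reference, a minimal sketch of how the three PyTorch mechanisms mentioned above combine (assuming a machine with at least two CUDA devices):

```python
import torch

# Default device: CUDA allocations go to the current device (cuda:0 unless changed)
x = torch.empty(2, 3, device="cuda")

# Context manager: temporarily change the current CUDA device
with torch.cuda.device(1):
    y = torch.empty(2, 3, device="cuda")   # allocated on cuda:1

# Explicit device= keyword wins regardless of the current device
z = torch.empty(2, 3, device="cuda:0")
```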
Links to the relevant docs for each library:
- PyTorch: https://pytorch.org/docs/stable/notes/cuda.html
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/device
- CuPy: https://docs.cupy.dev/en/stable/tutorial/basic.html#current-device
- JAX: https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
- MXNet: https://mxnet.apache.org/versions/1.6/api/python/docs/api/mxnet/context/index.html
Next step should be to write up a proposal for something PyTorch-like.
Activity
szha commented on Sep 14, 2020
+1 to this. We will need to define an operation to explicitly move the data around among different devices, along with a canonical way of specifying the target device.
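Purely as an illustration of what that could look like (neither the `xp` namespace nor the `to_device` name is settled by this comment):

```python
x = xp.asarray([1.0, 2.0, 3.0], device="gpu:0")  # create on an explicit device
x_host = x.to_device("cpu")                       # explicit cross-device copy
```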
oleksandr-pavlyk commented on Nov 6, 2020
Goals
It is not uncommon nowadays to get a laptop that has a CPU, an integrated GPU and a discrete GPU. The device selection specification should allow selecting among the subsets of supported devices, while ensuring that each computational device can be unambiguously referenced.
Moreover, the specification should allow selecting among several graphics cards, or more generally accelerators, of the same kind (say, among multiple GPUs from the same vendor).
Statements about offloading
It is worth noting that device really refers to a tuple of (hardware, driver), i.e. an NVIDIA card can be programmed using either CUDA or OpenCL, an AMD card using either ROCm or OpenCL, an Intel card using either Level-Zero or OpenCL, etc. It may not be very important for a Python user to be able to select among available drivers dynamically. It is reasonable for an array implementation to select a driver for the device at module initialization time.
To work with a device, the associated runtime keeps a structure, a `context`, that records the state of the device (c.f. what-is-cuda-context, open-cl-context) as well as information needed for synchronization between the host and the device. An array library may silently create a context for a device. CUDA has a context created by the CUDA runtime/user for each CUDA device (c.f. cuCtxCreate), while in OpenCL/SYCL one can in principle create a context for a subset of devices from the same SYCL platform. In SYCL, a context also registers the asynchronous exception handler, which an array library implementation may set, for example via `context(const vector_class<device> &deviceList, async_handler asyncHandler = {})`.
The runtime facilitates the task of imposing order on the sequence of tasks to be offloaded for asynchronous execution on a device. CUDA provides the `stream`, and SYCL provides the `queue`. A CUDA stream executes a single operation sequence on a CUDA device (c.f. torch_cuda_stream). Programming CUDA requires understanding of the computation graph, and multiple concurrent streams may be used to execute segments with no data dependency. A SYCL queue allows the user to specify task dependencies at scheduling time, and the SYCL runtime then executes tasks honoring those dependencies [see section 4.9 of the 2020 SYCL provisional spec].

Ultimately, the user's selection of a device must allow the array library to locate a queue/stream to submit tasks to, and for optimal performance streams/queues should be reused once created. Reusing a SYCL queue has the additional benefit of keeping track of data dependencies while executing concurrently.
A Python user wishing to offload a computation to a GPU should be aware of streams/queues, but for most users these will be created once per device by the array implementation and reused throughout the session. Power users may want to explore the use of multiple streams, or define dependencies between SYCL tasks; a PyTorch-based sketch of the stream case follows below.
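As a concrete illustration of the power-user case, a minimal sketch using PyTorch's existing stream API (assumes a CUDA-capable build):

```python
import torch

s = torch.cuda.Stream()              # an extra stream on the current CUDA device
with torch.cuda.stream(s):           # kernels launched here are enqueued on s
    a = torch.randn(1024, 1024, device="cuda")
    b = a @ a                        # may overlap with work on the default stream
# Order subsequent default-stream work after everything enqueued on s
torch.cuda.current_stream().wait_stream(s)
```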
Toward the proposal
An array library targeting devices needs to know the device an array was allocated on. In the case of SYCL and OpenCL it also needs to know the context to which the memory was bound, for example to be able to free the memory.
In the case of OpenCL, one can use the function `clGetMemObjectInfo` to get the OpenCL context that a `cl_mem` memory object was created with. However, when using SYCL 2020's USM pointers, the context cannot be easily looked up, so the array library must either store it with each array object instance, or store it in a global structure associated with the device, with all operations involving the device using that context, as is the case with PyTorch.
Thus, to address a device in an array library, a user needs to specify the backend (the software layer used to work with devices, e.g. OpenCL, CUDA, SYCL) in which the library stores backend-specific objects (context, queue/stream, etc.) associated with supported devices, as well as an identifier of the device within that backend.
A library may choose to query its backend for all addressable devices, store them in an array, and refer to devices by their position therein. Such positions should be deterministic across Python sessions.
Considering an array library that uses the SYCL runtime as its backend, devices can be further differentiated based on their kind: GPU devices, Accelerator devices, CPU devices, Custom devices. It would be appropriate for the array library to keep a separate array of addressable devices for each kind.
Thus the device specification emerges as a triple ('backend', 'device_kind', relative_id). Since backends vary between array library implementations, an array library must be able to provide a device when some, but not all, elements of the tuple are omitted. An underspecified tuple is then understood as a device filter (c.f. the Filter Selector SYCL extension proposal), and the array library chooses the most appropriate device to use at its discretion.
Portable code will then only use (device_kind, relative_id); a sketch of possible filter strings follows below.
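A sketch of what such filter-style device strings could look like (the `xp` namespace and the exact spellings are assumptions, in the spirit of the SYCL filter-selector syntax):

```python
x = xp.empty((2, 3), device="opencl:gpu:0")  # fully specified: backend, kind, id
y = xp.empty((2, 3), device="gpu:0")         # backend omitted: library picks one
z = xp.empty((2, 3), device="gpu")           # kind only: library picks a GPU
```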
The array object should carry, or should be able to figure out, the device it was allocated on (e.g. by implementing `array_instance.device`), and a user should also be able to specify the device a new array is to be allocated on. This calls for a `device=` keyword for every function that may create a new array, with some sensible default value (infer `device=` from inputs, or use a default/current device settable via a context manager or an explicit function).

Computations on device arrays must be submitted to a stream/queue. Implementation kernels must be able to access the data, so the queue must use the same runtime `context` that was used when the array was allocated. Should this not be the case, the implementation should either raise an exception (preferred) or invoke a copy via the host (see the small example below).

The user should be able to specify the stream/queue where kernels are to be submitted, as is the case with native libraries (c.f. `cublasSetStream` in cuBLAS, `executor` in SYCL-BLAS, the execution model in oneMKL, `dnnl::sycl_interop::make_stream` in oneDNN). Ideally this should be seldom needed, so an optional keyword seems heavyweight, but could be a solution, with defaults driven by the device array inputs and the default queue settings of the library, controllable via a context and/or an explicit function.
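A small example of the behaviour described above (with `xp` as a stand-in namespace and two GPUs available):

```python
a = xp.ones((3,), device="gpu:0")
b = xp.ones((3,), device="gpu:1")

c = a + a   # fine: result allocated on gpu:0, reusing a's queue/context
d = a + b   # devices/contexts differ: raise an exception (preferred) rather
            # than silently copying an operand via the host
```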
rgommers commented on Nov 30, 2020
I agree, this is what libraries currently do so there's no need for something more complex.
Stream management seems like something that is done at the implementation level, and does not show up in a Python level API.
I'm not sure this is right. It's of course very helpful to understand what's happening under the hood; it can allow users to write more performant code. But it's not strictly necessary, and the vast majority of Python end users will in fact not understand this while happily writing ML code to run on a GPU, for example.
Note that there typically are ways to control this from Python (e.g., for PyTorch, `torch.cuda.synchronize`, `torch.cuda.stream`, etc.).

This doesn't follow from what you wrote before I think. An array library needs to deal with this (e.g., as you wrote, "store it with each array object instance"), but that should be an implementation detail.
Device IDs should already be deterministic, right? E.g. writing `'gpu:0'` (or `'cuda:0'`, depending on which library is used) should give you the GPU with the actual device ID of `0` that `nvidia-smi` shows you as well, and that you can also control with `CUDA_VISIBLE_DEVICES`.

Agreed it needs kind + ID. No need for it to be relative though, I'd think.
Agree with all of this.
All true but out of scope I'd say, it's per-library and may be implementation-specific.
Even though existing libraries do offer APIs for stream control, they're not often used, and it's not clear that we can point at anything being best practice / the right thing to adopt. I'd say we should put this out of scope. Mixing multiple libraries + stream control is immature; I think we also had a discussion around the array interchange protocol here (which doesn't contain stream info).
rgommers commented on Nov 30, 2020
Some thoughts on API
Syntax:
- `device=` keyword for creation functions
- `device=` keyword should take a string representation as well as an instance of a `device` object
- `device` object itself should take the same string representation in its constructor
- `device` object should also provide a string attribute, to give a portable representation which is again a valid device-specifying string (I propose `device.str`). TBD: also provide `.kind` and `.index` separately, or do not rely on any other properties for `device` instances?
- `'device_kind:id'`, with `':id'` optional (e.g. doesn't apply to `'cpu'`). All lower-case, with kind strings `'cpu'`, `'gpu'` (note, better than `'cuda'`, and applies to AMD GPUs too), `'tpu'`. No other strings needed at this time (?).
- `.to(device)` method
- `.device` property, which returns a `device` object instance

(A usage sketch follows at the end of this comment.)

Behaviour:

- `device` passing for portability.
- `device` object instances are only meant to be consumed by the library that produced them - the string attribute can be used for portability between libraries.

Out of scope:
To add or not add a context manager:
A context manager for controlling the default device is present in all libraries except NumPy. Concerns are
The main upside would probably be that since most libraries have context managers now, it'd be nice for their users to get a context manager in this API standard - easier to migrate already written code.
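Putting the syntax points above together, a usage sketch (the `xp` namespace and the exact `device` constructor spelling are assumptions, not part of the proposal):

```python
d = xp.device("gpu:0")                 # construct from the canonical string form
assert d.str == "gpu:0"                # .str gives the portable string form

x = xp.ones((2, 2), device=d)          # device= accepts a device object ...
y = xp.zeros((2, 2), device="gpu:0")   # ... or the equivalent string

x_cpu = x.to(xp.device("cpu"))         # explicit transfer via .to(device)
assert y.device.str == "gpu:0"         # .device returns a device object instance
```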
rgommers commented on Nov 30, 2020
I looked at SYCL some more, since I was only very vaguely aware of what it is. @oleksandr-pavlyk please correct me if I'm wrong, but as far as I can tell it's orthogonal to anything one would want a Python user or array-consuming library author to know about. They need to know about the actual hardware they can use to execute their code, but (beyond install headaches) won't really care whether it's CUDA/OpenCL/ROCm/SYCL/whatever under the hood.
oleksandr-pavlyk commented on Nov 30, 2020
With SYCL, one writes a kernel once, compiles it with a SYCL compiler to an IR, and can then submit it to different queues targeting different devices (e.g. CPU, GPU, FPGA).
This example constructs a Python extension, compiled with Intel's DPCPP compiler, to compute column-wise sums of an array. Running it on CPU or GPU is then just a matter of changing the queue the work is submitted to.
An array-consuming library author need not be aware of this, I thought, just as they need not be aware of which array implementation is powering the application.
agarwal-ashish commented on Dec 3, 2020
Some thoughts and clarifications:
TensorFlow's `ndarray.shape` returning an ndarray is a behavior that will be rolled back. `Tensor.shape`'s behavior is to return a `TensorShape` object, which can represent incomplete shapes as well, and that will carry over to ndarray too.
It is not clear why device needs to be part of the array creation APIs. Context managers allow mutating global state representing the current device, which the runtime can then use for these calls. Device-setting code would only execute when the context manager is entered/exited. Passing the device per call can be an unnecessary cost. It may also force the current device to be known, which can be hard for generic library code and may require querying the current device from somewhere. (A TensorFlow sketch of the context-manager style follows at the end of this comment.)
Also, I am not sure we should enforce constraints on where inputs and outputs can be placed for an operation. Such constraints can make it harder to write portable library code where you don't control the inputs and may have to start by copying all inputs to the same device. The TensorFlow runtime is allowed to copy inputs to the correct device if needed. Also, there are policies on hard/soft device placement which allow TensorFlow to override user-specified placement in cases where the placement is infeasible or sub-optimal. One can further imagine a dynamic placement scenario in cases of async execution.
In addition, outputs may sometimes need to reside on a different device than the inputs. Examples often involve operations on metadata (shape, size), which typically resides on the host.
Device placement is generally a "policy" and I think we should leave it as a framework detail instead of having it in the API specification. I am not opposed to reserving a device property in the ndarray API, but I don't think we should put constraints on how the device placement should be done.
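For illustration, the context-manager style referred to above, using TensorFlow's existing `tf.device` scope (a minimal sketch):

```python
import tensorflow as tf

with tf.device("/GPU:0"):
    x = tf.zeros((2, 3))   # placed on GPU:0 via the ambient device scope,
                           # without a device argument on the creation call
```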
rgommers commented on Dec 3, 2020
That's good to know. In that case I'll remove the note on that, no point in mentioning it if it's being phased out.
The "mutating global state" part points at the exact problem with context managers. Having global state generally makes it harder to write correct code. For the person writing that code it may be fine to keep all of that in their head, but it affects any library call that gets invoked. That is probably still fine in single-device situations (e.g. switching between the CPU and one GPU), but beyond that it gets tricky.
The consensus of our conversation in September was that a context manager isn't always enough, and that the PyTorch model was more powerful. That still left open whether we should also add a context manager though.
Re cost - do you mean cost in verbosity? Passing through a keyword shouldn't have significant performance cost.
I think the typical pattern would be to either use the default, or obtain it from the local context. E.g.
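A minimal sketch of that pattern, with `xp` as a stand-in array namespace:

```python
def some_function(x):
    # Allocate the output wherever the input lives; library code never needs
    # to spell out the concrete device.
    out = xp.empty(x.shape, dtype=x.dtype, device=x.device)
    out[...] = 2 * x
    return out
```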
And only in more complex situations would the actual device need to be known explicitly.
That is a good question, should it be enforced or just recommended? Having device transfers be explicit is usually better (implicit transfers can make for hard to track down performance issues), but perhaps not always.
Interesting, I'm not familiar with this hard/soft distinction, will look at the TF docs.
That should not be a problem if shape and size aren't arrays, but either custom objects or tuples/ints?
That may be a good idea. Would be great to discuss in more detail later today.
array.to_device #171