
Allow Custom Classes to register a handler for .to operations  #51994

@narendasan

Description


🚀 Feature

We would like to be able to register a handler for .to on a torchbind custom class, in the same way we register one with .def_pickle, so that we can define custom behavior for moving an instance of the class from one device to another (or raise an error if such a move is impossible). Ideally this .to handler would be called either when the user invokes .to directly on an instance of the class, or recursively when the user calls .to on the module that owns the instance, similar to how tensors owned by modules are handled today. A sketch of what such a registration could look like follows below.
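
For illustration, a registration hook analogous to .def_pickle might look something like the sketch below. Note that def_to does not exist today; it, the EngineStub class, and the device-handling logic are hypothetical placeholders showing the shape of the API we have in mind, while the rest (torch::class_, torch::init, c10::Device, TORCH_CHECK) is existing torchbind machinery.

#include <torch/custom_class.h>

namespace {

// Stand-in for a class like TRTEngine that is tied to a particular device.
struct EngineStub : torch::CustomClassHolder {
  c10::Device device_{c10::DeviceType::CUDA, 0};
};

static auto registration =
    torch::class_<EngineStub>("trt_sketch", "EngineStub")
        .def(torch::init<>())
        // Hypothetical hook: called when .to() is invoked on the instance
        // directly, or recursively when .to() is called on an owning module.
        .def_to([](const c10::intrusive_ptr<EngineStub>& self,
                   const c10::Device& target) -> c10::intrusive_ptr<EngineStub> {
          TORCH_CHECK(target.is_cuda(), "TRT engines can only live on CUDA devices");
          auto moved = c10::make_intrusive<EngineStub>();
          moved->device_ = target;  // a real engine would be rebuilt/deserialized here
          return moved;
        });

} // namespace

Returning a new instance (rather than mutating in place) mirrors how Tensor.to behaves, but either convention could work.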

Motivation

In TRTorch we store a custom class that manages a TensorRT engine as an attribute of a ScriptModule. However, TRT engines are device-specific once initialized, so we would like to give users the ability to move these engines between devices using the standard PyTorch convention.

Pitch

Ideally, we would like something like this to be possible:

import torch
import trtorch

# MyModel is any user-defined torch.nn.Module; start on device 0
model = MyModel()
ts_model = torch.jit.script(model).to("cuda:0")
trt_model = trtorch.compile(ts_model, {...})  # or: trt_model = torch._C._jit_to_backend("tensorrt", ts_model, ...)

# Move the module (its internal tensors and any attributes that have a .to registration) to device 1
trt_model.to("cuda:1")

Alternatives

We could write an independent .to method that works when invoked directly on an instance, but when that instance is owned by a module, I am not sure what the process would be for users to dig out a reference to the attribute in order to call the method on it.
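
As a sketch of that alternative under today's API, the method can be registered with the ordinary torchbind .def() call; the EngineHolder class and the device-string handling below are illustrative only:

#include <torch/custom_class.h>
#include <string>

namespace {

struct EngineHolder : torch::CustomClassHolder {
  std::string device_ = "cuda:0";
};

static auto registration =
    torch::class_<EngineHolder>("trt_sketch", "EngineHolder")
        .def(torch::init<>())
        // Plain method: callable on the instance from TorchScript or Python,
        // but NOT invoked automatically by Module.to().
        .def("to", [](const c10::intrusive_ptr<EngineHolder>& self,
                      const std::string& device) {
          TORCH_CHECK(device.rfind("cuda", 0) == 0,
                      "TRT engines can only be moved to CUDA devices");
          self->device_ = device;  // a real engine would be rebuilt on the new device here
        });

} // namespace

The gap is exactly the one described above: when the instance lives as an attribute of a ScriptModule, the user still has to fish the attribute out to call this method, and Module.to() will not do it for them.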

There might also be a better way to store the custom class so that this would not affect all attributes. I am not too familiar with the common use cases for attributes.

Additional context

This is our current custom class: https://github.com/NVIDIA/TRTorch/blob/master/core/runtime/TRTEngine.cpp

and how we register it as an attribute in the module: https://github.com/NVIDIA/TRTorch/blob/6442fce997e1506d859fab789527fe1e282f683f/core/compiler.cpp#L57

Some additional context for a tangentially related TRTorch feature can be found in pytorch/TensorRT#311.

It was pointed out on the PyTorch Slack that this is where the .to call on a module is resolved; presumably this would need to be modified to also call .to on attributes of ScriptModules:
https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/module.cpp#L78-L102

cc @gmagogsfm


Labels: feature, module: custom-operators, oncall: jit, weeks
