This repository was archived by the owner on Oct 16, 2023. It is now read-only.

modify offload manager and add linear example #103

Merged
merged 1 commit on Aug 8, 2022
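The heart of the change: module copies now run on a dedicated offload stream, and a CUDA event is recorded so the compute stream can wait for a prefetched layer to arrive before using it. A minimal sketch of that stream/event pattern, with hypothetical layer and device names and assuming two visible CUDA devices (not code from this PR):

import torch
import torch.nn as nn

compute_stream = torch.cuda.Stream()
offload_stream = torch.cuda.Stream()

layer = nn.Linear(1024, 1024).to("cuda:1")  # parked on the free device

with torch.cuda.stream(offload_stream):
    layer.to("cuda:0", non_blocking=True)  # prefetch on the side stream
    ready = offload_stream.record_event()  # marks when the copy has finished

with torch.cuda.stream(compute_stream):
    compute_stream.wait_event(ready)  # gate compute on the finished copy
    x = torch.randn(8, 1024, device="cuda:0")
    y = layer(x)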
47 changes: 44 additions & 3 deletions energonai/nemesis/nemesis_manager.py
@@ -1,3 +1,10 @@
"""
------------------------------------------
Classes gpu_info and Nemesis_manager
Mainly used for peer-memory offloading
------------------------------------------
"""

import sys
import torch
import pynvml
@@ -6,6 +13,9 @@


class gpu_info:
"""
Class used to monitor the status of a single GPU device.
"""

def __init__(self, device_id: int):
self._handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
@@ -51,22 +61,35 @@ def __init__(self):
self._gpu_info = {"cuda:{}".format(i): gpu_info(i) for i in range(self._gpu_num)}
self._module_list = list()
self.offload_dict = dict()
self.event_dict = dict()
self.offload_flags = None
self.prefetch_dict = dict()
self.compute_device_dict = dict()
self.module_size = dict()
self.layer_num = -1
self.offload_interval = -1
self.prefetch_layer = 3 # how many layers ahead we prefetch an offloaded layer
self.free_device = None
self._model = None
# Two separate CUDA streams, one for computing and one for offloading, so the two can overlap
self.offload_stream = torch.cuda.Stream()
self.compute_stream = torch.cuda.Stream()

def register_model(self, model_):
self._model = model_

def set_free_device(self, free_device):
"""
Call this function to assign the free device that layers are offloaded to.
"""
self.free_device = free_device

def set_model_info(self, layer_num, offload_interval):
"""
:param layer_num: the number of layers in the model
:param offload_interval: offload one layer out of every offload_interval layers
This function should be called in your model's initialization function.
"""
assert layer_num % offload_interval == 0
self.layer_num = layer_num
self.offload_interval = offload_interval
@@ -80,7 +103,8 @@ def calculate_module_size(self, module_: torch.nn.Module):
return res_size

def move_module(self, module_: torch.nn.Module, target_device):
module_ = module_.to(target_device, non_blocking=True)
with torch.cuda.stream(self.offload_stream):
module_ = module_.to(target_device, non_blocking=True)

def generate_offload_dict(self):
assert self.layer_num != -1 and self.offload_interval != -1, 'please set layer num and offload interval first'
@@ -96,18 +120,27 @@ def offload_module(self, module_):
Ne_manager.move_module(module_, free_device)

def apply_hook(self):
"""
This function registers PyTorch forward pre-hooks to implement offloading and prefetching.
Call this function before inference if you want to enable offloading.
"""
for i in range(len(self._module_list)):
if i % self.offload_interval == 0:
self.offload_dict[id(self._module_list[i])].append(self._module_list[i - 1])
if i % self.offload_interval == self.offload_interval - 3:
self.prefetch_dict[id(self._module_list[i])].append(self._module_list[i + 2])
if self.offload_interval == 2:
if i % self.offload_interval == 0:
self.prefetch_dict[id(self._module_list[i])].append(self._module_list[i + 1])
else:
if i % self.offload_interval == self.offload_interval - self.prefetch_layer:
self.prefetch_dict[id(self._module_list[i])].append(self._module_list[i + self.prefetch_layer - 1])
if len(self.prefetch_dict[id(self._module_list[i])]) + len(self.offload_dict[id(self._module_list[i])]) > 0:
self._module_list[i].register_forward_pre_hook(basic_hook)

def register_module(self, module_: torch.nn.Module, device: str):
self._gpu_info[device].gpu_register_module(module_)
self.offload_dict[id(module_)] = list()
self.prefetch_dict[id(module_)] = list()
self.event_dict[id(module_)] = None
self._module_list.append(module_)
self.compute_device_dict[id(module_)] = device
self.module_size[id(module_)] = self.calculate_module_size(module_)
@@ -134,6 +167,11 @@ def print_status(self):


def basic_hook(module: torch.nn.Module, input_):
"""
The pre-hook function passed to PyTorch's register_forward_pre_hook.
It launches the offloading and prefetching work on the offload stream
so that the copies overlap with computation.
"""
for tg in Ne_manager.offload_dict[id(module)]:
cur_device = next(tg.parameters()).device
if Ne_manager.compute_device_dict[id(tg)] == "{}:{}".format(cur_device.type, cur_device.index):
@@ -143,4 +181,7 @@ def basic_hook(module: torch.nn.Module, input_):
Ne_manager.move_module(tg, free_device)
for tg in Ne_manager.prefetch_dict[id(module)]:
Ne_manager.move_module(tg, Ne_manager.compute_device_dict[id(tg)])
with torch.cuda.stream(Ne_manager.offload_stream):
evt_2 = Ne_manager.offload_stream.record_event()
Ne_manager.event_dict[id(tg)] = evt_2
return
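A note on the CPU offload path the manager also supports: .to(..., non_blocking=True) only overlaps with compute when the host side uses pinned (page-locked) memory; with ordinary pageable tensors the copy degrades to a synchronous one. A minimal sketch of pinning a module's parameters first (hypothetical layer, not part of this PR):

import torch
import torch.nn as nn

layer = nn.Linear(512, 512)
for p in layer.parameters():
    p.data = p.data.pin_memory()  # page-locked host storage enables async copies

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    layer.to("cuda:0", non_blocking=True)  # now a truly asynchronous host-to-device copy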
90 changes: 90 additions & 0 deletions examples/linear/linear.py
@@ -0,0 +1,90 @@
import time

from energonai.nemesis.nemesis_manager import Ne_manager
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

compute_device = 'cuda:0' # manually set which device to compute on
offload_flag = True # whether or not to activate offloading

def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True


class single_linear(nn.Module):
def __init__(self, input_dim: int, output_dim: int, bias=False):
super().__init__()
self.weight = torch.empty(output_dim, input_dim)
nn.init.normal_(self.weight)
self.weight = nn.Parameter(self.weight.to(compute_device))
if bias:
self.bias = torch.empty(output_dim)
nn.init.normal_(self.bias)
self.bias = nn.Parameter(self.bias.to(compute_device))
else:
self.bias = None

def forward(self, input_):
output = F.linear(input_, self.weight, self.bias)
return output


class nv_layers(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, layer_num: int):
super().__init__()
self.module_list = list()
for i in range(layer_num):
if i == 0:
temp_layer = single_linear(input_dim, hidden_dim, True)
elif i == layer_num - 1:
temp_layer = single_linear(hidden_dim, output_dim, True)
else:
temp_layer = single_linear(hidden_dim, hidden_dim, True)
Ne_manager.register_module(temp_layer, compute_device)
if Ne_manager.offload_flags[i] and offload_flag:
Ne_manager.offload_module(temp_layer)
self.module_list.append(temp_layer)

def print_device(self):
cnt__ = 0
print("=" * 50)
for mod in self.module_list:
print("layer {} device: ".format(cnt__))
cnt__ += 1
print(next(mod.parameters()).data.device)
print("=" * 50)

def forward(self, input_):
output = input_
for layer_ in self.module_list:
if Ne_manager.event_dict[id(layer_)] is not None:
Ne_manager.compute_stream.wait_event(Ne_manager.event_dict[id(layer_)])
Ne_manager.event_dict[id(layer_)] = None
output = layer_(output)
return output


if __name__ == "__main__":
setup_seed(42)
Ne_manager.set_model_info(12, 6) # register model info
Ne_manager.set_free_device("cuda:1")
# Ne_manager.set_free_device("cpu") # modify here if you want to use cpu as offloading target
model_ = nv_layers(200, 150000, 10, 12)
if offload_flag:
Ne_manager.apply_hook() # call this to activate offloading hooks
input_ = torch.randn((20, 200)).to("cuda:0")
print("init done")
with torch.inference_mode():
for i in range(5):
out_ = model_(input_)
start_ = time.time()
with torch.inference_mode():
for i in range(20):
out_ = model_(input_)
print(time.time() - start_)
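One caveat about the timing loop above: CUDA kernel launches are asynchronous, so time.time() can return before the queued forward passes have actually finished on the GPU. A stricter measurement synchronizes around the loop (a sketch using the same names as the example):

torch.cuda.synchronize()
start_ = time.time()
with torch.inference_mode():
    for i in range(20):
        out_ = model_(input_)
torch.cuda.synchronize()  # wait for all queued GPU work before reading the clock
print(time.time() - start_)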