diff --git a/energon/cli/__init__.py b/energon/cli/__init__.py new file mode 100644 index 0000000..18423a0 --- /dev/null +++ b/energon/cli/__init__.py @@ -0,0 +1,17 @@ +import click +import typer +from energon.cli.service import service + +app = typer.Typer() + +@app.callback() +def callback(): + """ + Typer app, including Click subapp + """ + +typer_click_object = typer.main.get_command(app) +typer_click_object.add_command(service, "service") + +if __name__ == "__main__": + typer_click_object() \ No newline at end of file diff --git a/energon/cli/service.py b/energon/cli/service.py new file mode 100644 index 0000000..3735224 --- /dev/null +++ b/energon/cli/service.py @@ -0,0 +1,82 @@ +import click +import torch +import energon.server as server +from multiprocessing import Process + +@click.group() +def service(): + pass + + +@service.command() +@click.option("--model_name", default="bert_small", type=str) +@click.option("--model_type", default="bert", type=str) +@click.option("--max_batch_size", default=32, type=int) +@click.option("--tp_init_size", default=1, type=int) +@click.option("--pp_init_size", default=1, type=int) +@click.option("--host", default="127.0.0.1", type=str) +@click.option("--port", default=29400, type=int) +@click.option("--half", is_flag=True, show_default=True) +@click.option("--checkpoint", type=str) +@click.option("--server_host", default="127.0.0.1", type=str) +@click.option("--server_port", default=8005, type=int) +@click.option("--log_level", default="info", type=str) +@click.option("--backend", default="nccl", type=str) +def init(model_name, + model_type, + max_batch_size, + tp_init_size, + pp_init_size, + host, + port, + half, + checkpoint, + server_host, + server_port, + log_level, + backend): + + click.echo(f'*** Energon Init Configurations: *** \n' + f'Model Name: {model_name} \n' + f'Max Batch Size: {max_batch_size} \n' + f'Tensor Parallelism Size: {tp_init_size} \n' + f'Pipeline Parallelism Size: {pp_init_size} \n' + f'Communication Host: {host} \n' + f'Communication Port: {port} \n' + f'Is Half: {half} \n' + f'Checkpoint Path: {checkpoint} \n' + f'Worker Server Host: {server_host} \n' + f'Worker Server Port: {server_port} \n' + f'Unvicorn Log Level: {log_level} \n') + + if half: + dtype = torch.half + else: + dtype = torch.float + + world_size = tp_init_size * pp_init_size + num_worker = world_size - 1 + + engine_port = server_port + worker_port = server_port + 1 + worker_rank = 1 # start from 1 + + process_list = [] + for i in range(num_worker): + p = Process(target=server.launch_worker, + args=(host, port, tp_init_size, pp_init_size, "nccl", 1024, True, worker_rank+i, worker_rank+i, server_host, worker_port+i, log_level)) + p.start() + process_list.append(p) + + server.launch_engine(model_name, + model_type, + max_batch_size, + tp_init_size, + pp_init_size, + host, + port, + dtype, + checkpoint, + server_host, + engine_port, + log_level) \ No newline at end of file diff --git a/energon/engine/engine.py b/energon/engine/engine.py index c0a9253..f14cc0f 100644 --- a/energon/engine/engine.py +++ b/energon/engine/engine.py @@ -1,4 +1,3 @@ -import os import time import torch from torch.nn import Module @@ -68,9 +67,7 @@ def __init__(self, def _init_dist_rpc(self): r''' Based on global_context, init the rpc connection. 
- ''' - os.environ['MASTER_ADDR'] = self.host - os.environ['MASTER_PORT'] = f'{self.port}' + ''' launch_from_multiprocess(tp_size = self.tp_size, pp_size = self.pp_size, rank = self.rank, local_rank = self.rank, world_size = self.global_world_size, host = self.host, port = self.port) rpc_backend_options=rpc.TensorPipeRpcBackendOptions( num_worker_threads=16) diff --git a/energon/engine/server.py b/energon/engine/server.py index db490f2..2c0bdcc 100644 --- a/energon/engine/server.py +++ b/energon/engine/server.py @@ -1,5 +1,6 @@ import os import uvicorn +import argparse from fastapi import FastAPI import torch.distributed.rpc as rpc from energon.initialize import launch_from_multiprocess @@ -40,4 +41,16 @@ def launch_worker(host="127.0.0.1", port=8005, log_level="info"): global server config = uvicorn.Config(app, host=host, port=port, log_level=log_level) server = uvicorn.Server(config=config) - server.run() \ No newline at end of file + server.run() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="127.0.0.1", help="Iteration") + parser.add_argument("--port", type=int, default=8005, help="Port") + parser.add_argument("--log_level", default="info", type=str) + args = parser.parse_args() + launch_worker(args.host, args.port, args.log_level) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/energon/initialize.py b/energon/initialize.py index 599f09b..56ce2e6 100644 --- a/energon/initialize.py +++ b/energon/initialize.py @@ -118,7 +118,9 @@ def launch_from_multiprocess(tp_size: int = 1, here we provide the multiprocess launch. TODO: only support the single node condition now. """ - + os.environ['MASTER_ADDR'] = host + os.environ['MASTER_PORT'] = f'{port}' + launch(local_rank=local_rank, rank=rank, world_size=world_size, diff --git a/energon/model/__init__.py b/energon/model/__init__.py new file mode 100644 index 0000000..360ac1f --- /dev/null +++ b/energon/model/__init__.py @@ -0,0 +1,2 @@ +from .bert import * +from .gpt import * \ No newline at end of file diff --git a/energon/model/bert/__init__.py b/energon/model/bert/__init__.py new file mode 100644 index 0000000..e8fe553 --- /dev/null +++ b/energon/model/bert/__init__.py @@ -0,0 +1,4 @@ +from .bert import bert_small + + +__all__ = ['bert_small'] diff --git a/energon/model/bert/bert.py b/energon/model/bert/bert.py new file mode 100644 index 0000000..8a846a2 --- /dev/null +++ b/energon/model/bert/bert.py @@ -0,0 +1,323 @@ +import math +from typing import Callable + +import os +import torch +from torch import nn as nn, Tensor, dtype + +from energon.context import ParallelMode +from energon.core import global_context as gpc +from energon.logging import get_dist_logger +from energon.nn.layer.utils import divide, ACT2FN +from energon.nn import Linear1D_Col, Linear1D_Row, Classifier1D +from energon.nn import LayerNorm1D +from energon.nn import VocabParallelEmbedding1D +from energon.utils import get_current_device, is_using_pp + +__all__ = [ + 'BertEmbedding1D' + 'BertMLP1D', + 'BertSelfAttention1D', + 'BertTransformerLayer1D' +] + +from energon.utils.checkpointing import load_checkpoint + + +class BertEmbedding1D(nn.Module): + def __init__(self, + embedding_dim: int, # hidden_size + vocab_size: int, + max_position_embeddings: int, + num_tokentypes: int = 0, + padding_idx: int = 0, + layernorm_epsilon: float = 1e-5, + dtype: dtype = None) -> None: + super().__init__() + self.word_embeddings = VocabParallelEmbedding1D(vocab_size, embedding_dim, 
padding_idx=padding_idx, dtype=dtype, skip_tp=True) + self.position_embeddings = VocabParallelEmbedding1D(max_position_embeddings, embedding_dim, dtype=dtype, skip_tp=True) + if num_tokentypes > 0: + self.tokentype_embeddings = VocabParallelEmbedding1D(num_tokentypes, embedding_dim, dtype=dtype) + else: + self.tokentype_embeddings = None + + # self.LayerNorm = nn.LayerNorm(embedding_dim, eps=layernorm_epsilon, dtype=dtype) + self.LayerNorm = LayerNorm1D(embedding_dim, eps=layernorm_epsilon) + + def forward(self, input_ids, position_ids=None, tokentype_ids=None): + seq_length = input_ids.size(1) + # TODO: register_buffer in advance for position_ids to speedup + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0) + + x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids) + + if self.tokentype_embeddings is not None and tokentype_ids is not None: + x = x + self.tokentype_embeddings(tokentype_ids) + + x = self.LayerNorm(x) + + return x + + +class BertSelfAttention1D(nn.Module): + def __init__(self, + hidden_size: int, + num_heads: int, + bias: bool = True, + fuse_scale_mask_softmax: bool = False, + layernorm_epsilon: float = 1e-5, + dtype: dtype = None) -> None: + super().__init__() + if hidden_size % num_heads != 0: + raise ValueError( + f"The hidden size ({hidden_size}) is not a multiple of the number of attention ") + self.hidden_size = hidden_size + self.attention_head_size = divide(hidden_size, num_heads) + self.fuse_scale_mask_softmax = fuse_scale_mask_softmax + + self.query_key_value = Linear1D_Col(hidden_size, 3 * hidden_size, bias=bias, dtype=dtype) + + if fuse_scale_mask_softmax: + raise NotImplementedError + + self.dense = Linear1D_Row(hidden_size, hidden_size, bias=True, dtype=dtype, parallel_input=True) + # self.LayerNorm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon) + self.LayerNorm = LayerNorm1D(hidden_size, eps=layernorm_epsilon) + + + def forward(self, hidden_states, attention_mask=None): + attention_output = self.query_key_value(hidden_states) + all_head_size = attention_output.shape[-1] // 3 + num_attention_heads = divide(all_head_size, self.attention_head_size) # num_heads + + new_qkv_shape = attention_output.shape[:-1] + (num_attention_heads, 3*self.attention_head_size) + attention_output = attention_output.view(new_qkv_shape) + attention_output = attention_output.permute(0, 2, 1, 3) + q, k, v = torch.chunk(attention_output, 3, dim = -1) + + attention_output = torch.matmul(q, k.transpose(-1, -2)) + + if self.fuse_scale_mask_softmax: + raise NotImplementedError + else: + attention_output = attention_output / math.sqrt(self.attention_head_size) + if attention_mask is not None: + attention_output = attention_output + attention_mask + attention_output = nn.functional.softmax(attention_output, dim=-1) + + attention_output = torch.matmul(attention_output, v) + attention_output = attention_output.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = attention_output.size()[:-2] + (all_head_size,) + attention_output = attention_output.reshape(new_context_layer_shape) + + attention_output = self.dense(attention_output) + + hidden_states = self.LayerNorm(attention_output + hidden_states) + + return hidden_states + +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * + (1.0 + 0.044715 * x * x))) + +class BertMLP1D(nn.Module): + def __init__(self, + hidden_size: int, + mlp_ratio: float, + activation: Callable = 
gelu_impl, + layernorm_epsilon: float = 1e-5, + dtype: dtype = None, + bias: bool = True): + super().__init__() + intermediate_dim = int(hidden_size * mlp_ratio) + self.layer_0 = Linear1D_Col(hidden_size, intermediate_dim, bias=bias, dtype=dtype, gather_output=False) + self.activation = activation + self.layer_1 = Linear1D_Row(intermediate_dim, hidden_size, bias=bias,dtype=dtype, parallel_input=True) + # self.LayerNorm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon) + self.LayerNorm = LayerNorm1D(hidden_size, eps=layernorm_epsilon) + + def forward(self, input_tensor): + hidden_states = self.layer_0(input_tensor) + hidden_states = self.activation(hidden_states) + hidden_states = self.layer_1(hidden_states) + + hidden_states = self.LayerNorm(hidden_states+input_tensor) + return hidden_states + +class BertTransformerLayer1D(nn.Module): + def __init__(self, + hidden_size: int, + num_heads: int, + mlp_ratio: float, + activation: Callable = gelu_impl, + layernorm_epsilon: float = 1e-5, + dtype: dtype = None, + bias: bool = True, + fuse_scale_mask_softmax: bool = False): + + super().__init__() + + self.attention = BertSelfAttention1D(hidden_size, + num_heads, + bias, + fuse_scale_mask_softmax, + layernorm_epsilon, + dtype) + self.mlp = BertMLP1D(hidden_size, + mlp_ratio, + activation, + layernorm_epsilon, + dtype, + bias) + + def forward(self, hidden_states, attention_mask): + hidden_states = self.attention(hidden_states, attention_mask) + hidden_states = self.mlp(hidden_states) + + return hidden_states + + +class PipelineBert1D(nn.Module): + + def __init__(self, + vocab_size: int = 50304, + max_position_embeddings: int = 1024, + hidden_size: int = 768, + num_heads: int = 12, + depth: int = 12, + mlp_ratio: float = 4.0, + layernorm_epsilon: float = 1e-5, + activation: Callable = nn.functional.gelu, + padding_idx: int = 0, + dtype: dtype = None, + bias: bool = True, + fuse_scale_mask_softmax: bool = False, + first: bool = False, + last: bool = False, **kwargs): + super().__init__() + self.first = first + self.last = last + + if first: + self.embed = BertEmbedding1D(embedding_dim=hidden_size, + vocab_size=vocab_size, + max_position_embeddings=max_position_embeddings, + padding_idx=padding_idx, + layernorm_epsilon=layernorm_epsilon, + dtype=dtype) + self.blocks = nn.ModuleList() + self.pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if is_using_pp() else 0 + for id_ in range(depth): + self.blocks.register_module("blk_{}".format(id_ + self.pp_rank * depth), + BertTransformerLayer1D( + hidden_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + activation=activation, + layernorm_epsilon=layernorm_epsilon, + dtype=dtype, + bias=bias, + fuse_scale_mask_softmax=fuse_scale_mask_softmax, + ) + ) + # self.blocks = nn.ModuleList([ + # BertTransformerLayer1D( + # hidden_size=hidden_size, + # num_heads=num_heads, + # mlp_ratio=mlp_ratio, + # activation=activation, + # layernorm_epsilon=layernorm_epsilon, + # dtype=dtype, + # bias=bias, + # fuse_scale_mask_softmax=fuse_scale_mask_softmax, + # ) for _ in range(depth) + # ]) + + # if self.last: + # self.norm = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) + # self.head = GPTLMHead1D(dim=dim, vocab_size=vocab_size, dtype=dtype) # word_embeeding_weight=self.embed.word_embedding_weight not in the same process + + + def forward(self, hidden_states=None, input_ids=None, attention_mask=None): + if self.first: + hidden_states = self.embed(input_ids) + + for block in self.blocks: + hidden_states = block(hidden_states, 
attention_mask) + + if self.last: + hidden_states = hidden_states[:, 1, :] + + return hidden_states + + + +def partition_uniform(num_items, pipeline_parallel_size, num_chunks): + assert num_items % num_chunks == 0, \ + "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" + + logger = get_dist_logger() + parts = [[] for _ in range(pipeline_parallel_size)] # 4 + partition_items = num_items // num_chunks # 96 // 2 + for idx in range(num_chunks): + base_idx = idx * partition_items + chunk_size = partition_items // pipeline_parallel_size + left = pipeline_parallel_size - partition_items % pipeline_parallel_size + if chunk_size == 0: + logger.warning("Some nodes in Pipeline have no requests") + + for p in range(pipeline_parallel_size): + st = base_idx + base_idx += chunk_size + (p >= left) + parts[p].append((st, base_idx)) + + return parts + +def _create_bert_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **model_kwargs): + logger = get_dist_logger() + pipeline_size = 0 + pipeline_rank = 0 + if gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + else: + pipeline_size = 1 + pipeline_rank = 0 + + rank = gpc.get_global_rank() + + parts = partition_uniform(depth, pipeline_size, + num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions + models = [] + for start, end in parts: + model_kwargs['first'] = start == 0 + model_kwargs['last'] = end == depth + model_kwargs['depth'] = end - start + chunk = PipelineBert1D(**model_kwargs).to(get_current_device()) + models.append(chunk) + logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}') + + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + + numel = 0 + for _, param in model.named_parameters(recurse=True): + numel += param.numel() + + if "checkpoint" in model_kwargs.keys(): + if model_kwargs["checkpoint"] is True: + if gpc.get_global_rank() == 0: + assert "checkpoint_path" in model_kwargs.keys(), "You have to specify a file path to use checkpoint loading" + assert os.path.exists(model_kwargs["checkpoint_path"]), "Checkpoint file not found" + load_checkpoint(model_kwargs["checkpoint_path"], model, **model_kwargs) + + logger.info(f'Rank{rank}/{pipeline_rank} model size in FP16 = {numel * 2 / 1e9} GB') + return model + +def bert_small(**kwargs): + model_kwargs = dict(hidden_size=768, depth=12, num_heads=12, **kwargs) + return _create_bert_pipeline_model(**model_kwargs) diff --git a/energon/model/gpt/__init__.py b/energon/model/gpt/__init__.py new file mode 100644 index 0000000..604a98b --- /dev/null +++ b/energon/model/gpt/__init__.py @@ -0,0 +1,4 @@ +from .gpt import gpt2_small, gpt2_medium, gpt2_large, gpt2_xl, gpt2_8B, gpt3 + + +__all__ = ['gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt2_8B', 'gpt3'] diff --git a/energon/model/gpt/gpt.py b/energon/model/gpt/gpt.py new file mode 100644 index 0000000..69ae90c --- /dev/null +++ b/energon/model/gpt/gpt.py @@ -0,0 +1,470 @@ +import math +from typing import Callable +import os + +import torch +from torch import nn as nn, Tensor, dtype + +from energon.context import ParallelMode +from energon.core import global_context as gpc +from energon.logging import get_dist_logger +from energon.nn.layer.utils import divide, ACT2FN +from energon.nn import Linear1D_Col, Linear1D_Row, Classifier1D +from energon.nn import LayerNorm1D +from energon.nn import 
VocabParallelEmbedding1D +from energon.utils import get_current_device, is_using_pp +from energon.utils.checkpointing import load_checkpoint + +__all__ = [ + 'GPTEmbedding1D' + 'GPTMLP1D', + 'GPTSelfAttention1D', + 'GPTTransformerLayer1D' +] + + +class GPTEmbedding1D(nn.Module): + + def __init__(self, + embedding_dim: int, + vocab_size: int, + max_position_embeddings: int, + num_tokentypes: int = 0, + padding_idx: int = 0, + dtype: dtype = None) -> None: + super().__init__() + self.word_embeddings = VocabParallelEmbedding1D(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype, skip_tp=True) + self.position_embeddings = VocabParallelEmbedding1D(max_position_embeddings, embedding_dim, dtype=dtype, skip_tp=True) + if num_tokentypes > 0: + self.tokentype_embeddings = VocabParallelEmbedding1D(num_tokentypes, embedding_dim, dtype=dtype, skip_tp=True) + else: + self.tokentype_embeddings = None + + @property + def word_embedding_weight(self): + return self.word_embeddings.weight + + def forward(self, input_ids, position_ids=None, tokentype_ids=None): + # padding condition, not for variable length + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0) + x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids) + if self.tokentype_embeddings is not None and tokentype_ids is not None: + x = x + self.tokentype_embeddings(tokentype_ids) + + return x + + +class GPTSelfAttention1D(nn.Module): + + def __init__(self, + dim: int, + num_heads: int, + bias: bool = True, + fuse_scale_mask_softmax: bool = False, + dtype: dtype = None) -> None: + super().__init__() + self.fuse_scale_mask_softmax = fuse_scale_mask_softmax # TODO + self.attention_head_size = divide(dim, num_heads) + self.query_key_value = Linear1D_Col(dim, 3 * dim, bias=bias, dtype=dtype) + + if fuse_scale_mask_softmax: + from colossalai.kernel import FusedScaleMaskSoftmax + from colossalai.kernel.cuda_native.scaled_softmax import \ + AttnMaskType + self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True, + input_in_bf16=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + mask_func=None, + softmax_in_fp32=True, + scale=math.sqrt(self.attention_head_size)) + else: + self.softmax = nn.Softmax(dim=-1) + self.dense = Linear1D_Row(dim, dim, bias=True, dtype=dtype, parallel_input=True) + + def forward(self, x, attention_mask=None): + qkv = self.query_key_value(x) + + # print(f'qkv {qkv.shape}') + + all_head_size = qkv.shape[-1] // 3 + num_attention_heads = divide(all_head_size, self.attention_head_size) # num_heads + + new_qkv_shape = qkv.shape[:-1] + \ + (num_attention_heads, 3 * self.attention_head_size) + qkv = qkv.view(new_qkv_shape) + qkv = qkv.permute((0, 2, 1, 3)) + q, k, v = torch.chunk(qkv, 3, dim=-1) + # print(f'qkv {qkv.shape}') # 6 40 128 + + x = torch.matmul(q, k.transpose(-1, -2)) + + if self.fuse_scale_mask_softmax: + x = self.softmax(x, attention_mask) + else: + x = x / math.sqrt(self.attention_head_size) + # causal mask + q_len, k_len = q.size(-2), k.size(-2) + causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8, + device=get_current_device())).view(1, 1, q_len, k_len).bool() + x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device())) + if attention_mask is not None: + x = x + attention_mask + x = self.softmax(x) + + x = torch.matmul(x, v) + x = x.transpose(1, 2) + new_context_layer_shape = x.size()[:-2] + 
(all_head_size,) + x = x.reshape(new_context_layer_shape) + + x = self.dense(x) + + return x + + +class GPTMLP1D(nn.Module): + + def __init__(self, + dim: int, + mlp_ratio: float, + activation: Callable, + dtype: dtype = None, + bias: bool = True): + super().__init__() + intermediate_dim = int(dim * mlp_ratio) + self.dense_1 = Linear1D_Col(dim, intermediate_dim, bias=bias, dtype=dtype, gather_output=False) + self.activation = activation + self.dense_2 = Linear1D_Row(intermediate_dim, dim, bias=bias, dtype=dtype, parallel_input=True) + + def forward(self, x): + x = self.dense_1(x) + x = self.activation(x) + x = self.dense_2(x) + return x + + +class GPTBlock1D(nn.Module): + + def __init__(self, + dim: int, + num_heads: int, + mlp_ratio: float, + activation: Callable, + layernorm_epsilon: float = 1e-5, + dtype: dtype = None, + bias: bool = True, + apply_post_layernorm: bool = False, + fuse_scale_mask_softmax: bool = False): + super().__init__() + + self.apply_post_layernorm = apply_post_layernorm + # self.norm1 = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) + self.norm1 = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) + self.attn = GPTSelfAttention1D(dim=dim, + num_heads=num_heads, + bias=bias, + fuse_scale_mask_softmax=fuse_scale_mask_softmax, + dtype=dtype) + + # self.norm2 = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) + self.norm2 = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) + self.mlp = GPTMLP1D(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dtype=dtype, bias=bias) + + def forward(self, x, attention_mask=None): + if not self.apply_post_layernorm: + residual = x + x = self.norm1(x) + if self.apply_post_layernorm: + residual = x + x = residual + self.attn(x, attention_mask) + + if not self.apply_post_layernorm: + residual = x + x = self.norm2(x) + if self.apply_post_layernorm: + residual = x + x = residual + self.mlp(x) + + return x, attention_mask + + +class GPTLMHead1D(nn.Module): + + def __init__(self, + dim: int, + vocab_size: int, + word_embeding_weight: nn.Parameter = None, + bias: bool = False, + dtype: dtype = None) -> None: + super().__init__() + self.dense = Classifier1D(dim, vocab_size, word_embeding_weight, bias=bias, dtype=dtype) + + @property + def weight(self): + return self.dense.weight + + def forward(self, x): + x = self.dense(x) + return x + + +class GPT1D(nn.Module): + + def __init__(self, + vocab_size: int = 50304, + max_position_embeddings: int = 1024, + dim: int = 768, + num_heads: int = 12, + depth: int = 12, + mlp_ratio: float = 4.0, + layernorm_epsilon: float = 1e-5, + activation: Callable = nn.functional.gelu, + padding_idx: int = 0, + dtype: dtype = None, + bias: bool = True, + apply_post_layernorm: bool = False, + fuse_scale_mask_softmax: bool = False) -> None: + super().__init__() + self.embed = GPTEmbedding1D(embedding_dim=dim, + vocab_size=vocab_size, + max_position_embeddings=max_position_embeddings, + padding_idx=padding_idx, + dtype=dtype) + self.blocks = nn.ModuleList() + self.pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + for id_ in range(depth): + self.blocks.register_module("blk_{}".format(id_ + self.pp_rank * depth), + GPTBlock1D( + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + activation=activation, + layernorm_epsilon=layernorm_epsilon, + dtype=dtype, + bias=bias, + apply_post_layernorm=apply_post_layernorm, + fuse_scale_mask_softmax=fuse_scale_mask_softmax, + ) + ) + # self.blocks = nn.ModuleList([ + # GPTBlock1D( + # dim=dim, + # 
num_heads=num_heads, + # mlp_ratio=mlp_ratio, + # activation=activation, + # layernorm_epsilon=layernorm_epsilon, + # dtype=dtype, + # bias=bias, + # apply_post_layernorm=apply_post_layernorm, + # fuse_scale_mask_softmax=fuse_scale_mask_softmax, + # ) for _ in range(depth) + # ]) + # self.norm = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) + self.norm = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) + self.head = GPTLMHead1D(dim=dim, + vocab_size=vocab_size, + word_embeding_weight=self.embed.word_embedding_weight, + dtype=dtype) + + def forward(self, input_ids, attention_mask=None): + x = self.embed(input_ids) + + if attention_mask is not None: + batch_size = input_ids.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + for block in self.blocks: + x, attention_mask = block(x, attention_mask) + + x = self.head(self.norm(x)) + + return x + + +class PipelineGPT1D(nn.Module): + + def __init__(self, + vocab_size: int = 50257, + max_position_embeddings: int = 1024, + dim: int = 768, + num_heads: int = 12, + depth: int = 12, + mlp_ratio: float = 4.0, + layernorm_epsilon: float = 1e-5, + activation: Callable = nn.functional.gelu, + padding_idx: int = 0, + dtype: dtype = None, + bias: bool = True, + apply_post_layernorm: bool = False, + fuse_scale_mask_softmax: bool = False, + first: bool = False, + last: bool = False, **kwargs): + super().__init__() + self.first = first + self.last = last + if first: + self.embed = GPTEmbedding1D(embedding_dim=dim, + vocab_size=vocab_size, + max_position_embeddings=max_position_embeddings, + padding_idx=padding_idx, + dtype=dtype) + self.blocks = nn.ModuleList() + self.pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if is_using_pp() else 0 + for id_ in range(depth): + self.blocks.register_module("blk_{}".format(id_ + self.pp_rank * depth), + GPTBlock1D( + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + activation=activation, + layernorm_epsilon=layernorm_epsilon, + dtype=dtype, + bias=bias, + apply_post_layernorm=apply_post_layernorm, + fuse_scale_mask_softmax=fuse_scale_mask_softmax, + ) + ) + # self.blocks = nn.ModuleList([ + # GPTBlock1D( + # dim=dim, + # num_heads=num_heads, + # mlp_ratio=mlp_ratio, + # activation=activation, + # layernorm_epsilon=layernorm_epsilon, + # dtype=dtype, + # bias=bias, + # apply_post_layernorm=apply_post_layernorm, + # fuse_scale_mask_softmax=fuse_scale_mask_softmax, + # ) for _ in range(depth) + # ]) + if self.last: + # self.norm = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) + self.norm = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) + self.head = GPTLMHead1D(dim=dim, vocab_size=vocab_size, + dtype=dtype) # word_embeeding_weight=self.embed.word_embedding_weight not in the same process + + def forward(self, hidden_states=None, input_ids=None, attention_mask=None): + if self.first: + hidden_states = self.embed(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. 
+ # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # Adapted from huggingface + if attention_mask is not None: + if self.first: + batch_size = input_ids.shape[0] + else: + batch_size = hidden_states.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + for block in self.blocks: + hidden_states, attention_mask = block(hidden_states, attention_mask) + + if self.last: + hidden_states = self.head(self.norm(hidden_states)) + + return hidden_states + + +def partition_uniform(num_items, pipeline_parallel_size, num_chunks): + assert num_items % num_chunks == 0, \ + "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" + + logger = get_dist_logger() + parts = [[] for _ in range(pipeline_parallel_size)] # 4 + partition_items = num_items // num_chunks # 96 // 2 + for idx in range(num_chunks): + base_idx = idx * partition_items + chunk_size = partition_items // pipeline_parallel_size + left = pipeline_parallel_size - partition_items % pipeline_parallel_size + if chunk_size == 0: + logger.warning("Some nodes in Pipeline have no requests") + + for p in range(pipeline_parallel_size): + st = base_idx + base_idx += chunk_size + (p >= left) + parts[p].append((st, base_idx)) + + return parts + + +def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **model_kwargs): + logger = get_dist_logger() + pipeline_size = 0 + pipeline_rank = 0 + if gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + else: + pipeline_size = 1 + pipeline_rank = 0 + + rank = gpc.get_global_rank() + + parts = partition_uniform(depth, pipeline_size, + num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions + models = [] + for start, end in parts: + model_kwargs['first'] = start == 0 + model_kwargs['last'] = end == depth + model_kwargs['depth'] = end - start + chunk = PipelineGPT1D(**model_kwargs).to(get_current_device()) + models.append(chunk) + logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}') + + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + + numel = 0 + for _, param in model.named_parameters(recurse=True): + numel += param.numel() + if "checkpoint" in model_kwargs.keys(): + if model_kwargs["checkpoint"] is True: + if gpc.get_global_rank() == 0: + assert "checkpoint_path" in model_kwargs.keys(), "You have to specify a file path to use checkpoint loading" + print(model_kwargs["checkpoint_path"]) + assert os.path.exists(model_kwargs["checkpoint_path"]), "Checkpoint file not found" + load_checkpoint(model_kwargs["checkpoint_path"], model, **model_kwargs) + logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB') + return model + + +def gpt2_small(**kwargs): + model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs) + return _create_gpt_pipeline_model(**model_kwargs) + + +def gpt2_medium(**kwargs): + model_kwargs = dict(dim=1024, depth=24, num_heads=8, **kwargs) + return _create_gpt_pipeline_model(**model_kwargs) + + +def gpt2_large(**kwargs): + model_kwargs = dict(dim=1536, depth=36, num_heads=12, **kwargs) + return _create_gpt_pipeline_model(**model_kwargs) + + +def 
gpt2_xl(**kwargs): + model_kwargs = dict(dim=1600, depth=48, num_heads=16, **kwargs) + return _create_gpt_pipeline_model(**model_kwargs) + + +def gpt2_8B(**kwargs): + model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs) + return _create_gpt_pipeline_model(**model_kwargs) + + +def gpt3(**kwargs): + model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs) + return _create_gpt_pipeline_model(**model_kwargs) diff --git a/energon/server/__init__.py b/energon/server/__init__.py new file mode 100644 index 0000000..1249f10 --- /dev/null +++ b/energon/server/__init__.py @@ -0,0 +1,2 @@ +from .engine_server import launch_engine +from .worker_server import launch_worker \ No newline at end of file diff --git a/energon/server/engine_server.py b/energon/server/engine_server.py new file mode 100644 index 0000000..e688804 --- /dev/null +++ b/energon/server/engine_server.py @@ -0,0 +1,81 @@ +import os +import torch +import uvicorn +from fastapi import FastAPI +from fastapi import Response +import torch.distributed.rpc as rpc +from energon.engine import InferenceEngine +from energon.model import gpt2_small, gpt2_medium, gpt2_large, gpt2_xl, gpt2_8B, gpt3 +from energon.model import bert_small + +MODEL_CLASSES = { + "bert_small": bert_small, + "gpt2_small": gpt2_small, + "gpt2_medium": gpt2_medium, + "gpt2_large": gpt2_large, + "gpt2_xl": gpt2_xl, + "gpt2_8B": gpt2_8B, + "gpt3": gpt3 +} + +app = FastAPI() # 创建 api 对象 + + + +@app.get("/") # 根路由 +def root(): + return {"200"} + +@app.get("/run") +def run(): + # a string arguement to produce sample + input_ids = torch.randint(1, 10, (32, 40), dtype=torch.int64) + attention_mask = torch.randint(0, 1, (32, 1, 40, 40), dtype=torch.int64) + hidden_states = None + sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask) + + output = engine.run(sample) + output = output.to_here() + print(output) + return {"To return the string result."} + + +@app.get("/shutdown") +async def shutdown(): + engine.clear() + server.should_exit = True + server.force_exit = True + await server.shutdown() + + +def launch_engine(model_name, + model_type, + max_batch_size: int = 1, + tp_init_size: int = -1, + pp_init_size: int = -1, + host: str = "localhost", + port: int = 29500, + dtype = torch.float, + checkpoint = None, + server_host = "localhost", + server_port = 8005, + log_level = "info" + ): + + model_config = {'dtype': dtype} + global engine + engine = InferenceEngine(MODEL_CLASSES[model_name], + model_config, + model_type, + max_batch_size = max_batch_size, + tp_init_size = tp_init_size, + pp_init_size = pp_init_size, + host = host, + port = port, + dtype = dtype, + checkpoint = checkpoint) + + global server + config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level) + server = uvicorn.Server(config=config) + server.run() \ No newline at end of file diff --git a/energon/server/worker_server.py b/energon/server/worker_server.py new file mode 100644 index 0000000..08567d1 --- /dev/null +++ b/energon/server/worker_server.py @@ -0,0 +1,52 @@ +import uvicorn +from fastapi import FastAPI +import torch.distributed.rpc as rpc +from energon.initialize import launch_from_multiprocess + +app = FastAPI() # 创建 api 对象 + + + +@app.get("/") # 根路由 +def root(): + return {"200"} + +# @app.get("/start/{tp_size}") +# def init(tp_size: int, pp_size: int, backend: str, seed: int, verbose: bool, rank: int, local_rank: int, host: str, port: int): +# # 
http://127.0.0.1:8005/start/1?pp_size=1&backend=nccl&seed=1024&verbose=true&rank=0&local_rank=0&host=localhost&port=29500 +# # http://127.0.0.1:8005/start/1?pp_size=1&backend=nccl&seed=1024&verbose=true&rank=0&local_rank=0&host=localhost&port=29500 +# world_size = tp_size * pp_size + +# os.environ['MASTER_ADDR'] = host +# os.environ['MASTER_PORT'] = f'{port}' +# launch_from_multiprocess(tp_size, pp_size, backend, seed, verbose, rank, local_rank, world_size, host, port) +# WORKER_NAME = "wok{}" +# rpc_backend_options=rpc.TensorPipeRpcBackendOptions( +# num_worker_threads=16) +# rpc.init_rpc(WORKER_NAME.format(rank), rank=rank, world_size=world_size, rpc_backend_options=rpc_backend_options) +# rpc.shutdown() +# # print(f'{WORKER_NAME.format(rank)} Start!') +# return {f'{WORKER_NAME.format(rank)} Start!'} + +@app.get("/shutdown") +async def shutdown(): + rpc.shutdown() + server.should_exit = True + server.force_exit = True + await server.shutdown() + + +def launch_worker(host="127.0.0.1", port=29500, tp_init_size=1, pp_init_size=1, backend="nccl", seed=1024, verbose=True, rank=0, local_rank=0, + server_host="127.0.0.1", server_port=8005, log_level="info"): + + world_size = tp_init_size * pp_init_size + + launch_from_multiprocess(tp_init_size, pp_init_size, backend, seed, verbose, rank, local_rank, world_size, host, port) + WORKER_NAME = "wok{}" + rpc_backend_options=rpc.TensorPipeRpcBackendOptions(num_worker_threads=16) + rpc.init_rpc(WORKER_NAME.format(rank), rank=rank, world_size=world_size, rpc_backend_options=rpc_backend_options) + + global server + config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level) + server = uvicorn.Server(config=config) + server.run() \ No newline at end of file
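
# ---------------------------------------------------------------------------
# Reviewer note -- illustrative sketches only, not part of the patch.
#
# (1) How layers are mapped onto pipeline stages. partition_uniform below is
# copied from the patch (it is defined identically in energon/model/bert/bert.py
# and energon/model/gpt/gpt.py); only the dist-logger warning is dropped so the
# snippet runs standalone. _create_bert_pipeline_model / _create_gpt_pipeline_model
# then feed each rank's (start, end) slice back in as the first/last/depth kwargs.

def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
    # Split num_items layers into num_chunks contiguous chunks, then divide each
    # chunk across the pipeline stages; the (p >= left) term hands any remainder
    # to the later stages.
    assert num_items % num_chunks == 0, \
        "Layer length should be divided by the number of chunks, otherwise parameter method is recommended"
    parts = [[] for _ in range(pipeline_parallel_size)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        chunk_size = partition_items // pipeline_parallel_size
        left = pipeline_parallel_size - partition_items % pipeline_parallel_size
        for p in range(pipeline_parallel_size):
            st = base_idx
            base_idx += chunk_size + (p >= left)
            parts[p].append((st, base_idx))
    return parts


if __name__ == "__main__":
    # bert_small / gpt2_small (depth=12) on a 4-stage pipeline, one chunk each:
    print(partition_uniform(12, 4, 1))   # [[(0, 3)], [(3, 6)], [(6, 9)], [(9, 12)]]
    # Uneven split: the two extra layers land on the last two stages.
    print(partition_uniform(12, 5, 1))   # [[(0, 2)], [(2, 4)], [(4, 6)], [(6, 9)], [(9, 12)]]

# (2) Driving the service end to end. Assuming the defaults in
# energon/cli/service.py (engine server on 127.0.0.1:8005, workers on 8006, 8007, ...),
# a hypothetical client only needs the two GET routes defined in
# energon/server/engine_server.py once `service init` has spawned the worker
# processes and the engine server. `requests` is not a dependency of this patch;
# any HTTP client works.
#
#   import requests
#   requests.get("http://127.0.0.1:8005/run")        # pushes one random (32, 40) batch through the engine; returns a placeholder string
#   requests.get("http://127.0.0.1:8005/shutdown")   # asks the engine server to clear the engine and exit
# ---------------------------------------------------------------------------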