added batch norm #321

Open · wants to merge 4 commits into main

46 changes: 46 additions & 0 deletions main.py
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
import torch.distributed as dist

import os

# Single-node defaults; replace with your launcher's settings for multi-node runs.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"


class MyCNN(nn.Module):
    def __init__(self):
        super(MyCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=3)
        # Replace BatchNorm with SyncBatchNorm for distributed training
        self.bn1 = nn.SyncBatchNorm(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
        self.bn2 = nn.SyncBatchNorm(20)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(1440, 10)  # Assumes the flattened conv output has 1440 features

    def forward(self, x):
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 1440)  # Flatten for the fully-connected layer
        x = self.fc(x)
        return x


def main():
    # Initialize distributed training (replace with your specific initialization logic).
    # A single process is used here; take world_size/rank from the launcher for multi-GPU runs.
    dist.init_process_group(backend="nccl", world_size=1, rank=0)
    print('main')
    model = MyCNN()
    model.to('cuda')
    # Wrap the model with DistributedDataParallel (DDP) for distributed training
    model = nn.parallel.DistributedDataParallel(model)

    # Define your optimizer, loss function, and training loop (omitted for brevity)

    # ... training code ...

    dist.destroy_process_group()  # Clean up after training


if __name__ == "__main__":
    main()
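
Reviewer note: when a model already uses regular `nn.BatchNorm2d` layers, PyTorch can swap them in place with `torch.nn.SyncBatchNorm.convert_sync_batchnorm`. A minimal sketch of that path is below; the helper name `setup_ddp_model` and the use of `LOCAL_RANK` from a `torchrun`-style launcher are assumptions for illustration, not part of this PR.

import os

import torch
import torch.distributed as dist
import torch.nn as nn


def setup_ddp_model(model: nn.Module) -> nn.Module:
    # The launcher (e.g. torchrun) provides LOCAL_RANK and the rendezvous env vars.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    # Replace every BatchNorm*d with SyncBatchNorm so batch statistics are
    # synchronized across all processes in the default process group.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(f"cuda:{local_rank}")
    return nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])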
104 changes: 104 additions & 0 deletions src/liger_kernel/ops/SyncBatchNorm.py
@@ -0,0 +1,104 @@
import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import calculate_settings, ensure_contiguous


@triton.jit
def _batch_norm_forward_kernel(
    X_ptr,
    X_batch_stride,
    X_row_stride,
    Y_ptr,
    Y_batch_stride,
    Y_row_stride,
    # MEAN_ptr,
    # MEAN_row_stride,
    # VARIANCE_ptr,
    # VARIANCE_row_stride,
    # AXIS_ptr,
    # AXIS_row_stride,
    n_cols,
    eps,
    mean,
    axis,
    scale,
    offset,
    variance,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per (batch, row) pair; each program normalizes one row of n_cols features.
    batch_idx = tl.program_id(0)
    row_idx = tl.program_id(1)
    col_offsets = tl.arange(0, BLOCK_SIZE)

    mask = col_offsets < n_cols
    X_ptr += batch_idx * X_batch_stride + row_idx * X_row_stride
    Y_ptr += batch_idx * Y_batch_stride + row_idx * Y_row_stride
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)

    # Inference-style batch norm with a single (scalar) running mean/variance:
    # y = (x - mean) * rsqrt(var + eps) * scale + offset, folded into y = x * inv + res.
    inv = tl.rsqrt(tl.load(variance) + eps)

    if scale is not None:
        inv = inv * tl.load(scale)

    res = -tl.load(mean) * inv
    if offset is not None:
        res = res + tl.load(offset)

    tl.store(Y_ptr + col_offsets, X_row * inv + res, mask=mask)


def batch_norm_forward(X, axis, offset, scale, eps, mean, variance):
    batch, n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.zeros((batch, n_rows, n_cols), dtype=X.dtype, device=X.device)

    # Launch one program per (batch, row) pair.
    _batch_norm_forward_kernel[(batch, n_rows)](
        X,
        X.stride(0),
        X.stride(1),
        Y,
        Y.stride(0),
        Y.stride(1),
        n_cols,
        eps,
        mean,
        axis,
        scale,
        offset,
        variance,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return Y


class LigerBatchNormFunction(torch.autograd.Function):
    # Placeholder running statistics and affine parameters (inference mode only).
    # These are allocated on CUDA at import time and sized for the current tests.
    variance = torch.ones(32 * 8, device="cuda")
    scale = torch.ones(32 * 8, device="cuda")
    offset = torch.zeros(32 * 8, device="cuda")
    mean = torch.zeros(32 * 8, device="cuda")

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, axis, eps):
        X = batch_norm_forward(
            X,
            axis,
            LigerBatchNormFunction.offset,
            LigerBatchNormFunction.scale,
            eps,
            LigerBatchNormFunction.mean,
            LigerBatchNormFunction.variance,
        )

        return X
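
For reference, the computation the kernel folds into `inv` and `res` is ordinary inference-mode batch normalization. A plain-PyTorch sketch is below, assuming scalar running statistics as in `LigerBatchNormFunction`; the name `batch_norm_reference` is illustrative only.

import torch


def batch_norm_reference(x, mean, variance, scale, offset, eps):
    # y = (x - mean) / sqrt(variance + eps) * scale + offset
    inv = torch.rsqrt(variance + eps) * scale  # folded multiplicative term
    res = offset - mean * inv                  # folded additive term
    return x * inv + res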
2 changes: 2 additions & 0 deletions src/liger_kernel/transformers/functional.py
@@ -8,6 +8,7 @@
from liger_kernel.ops.jsd import LigerJSDFunction
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
from liger_kernel.ops.SyncBatchNorm import LigerBatchNormFunction
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
from liger_kernel.ops.rope import LigerRopeFunction
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
@@ -19,6 +20,7 @@
liger_rms_norm = LigerRMSNormFunction.apply
liger_rope = LigerRopeFunction.apply
liger_layer_norm = LigerLayerNormFunction.apply
liger_batch_norm = LigerBatchNormFunction.apply
liger_kl_div = LigerKLDivLossFunction.apply
liger_jsd = LigerJSDFunction.apply
liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
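
For context, the new `liger_batch_norm` alias is called as `(input, axis, eps)`, mirroring `LigerBatchNormFunction.forward`. A minimal usage sketch, with shapes matching the test added below:

import torch

from liger_kernel.transformers.functional import liger_batch_norm

x = torch.randn(2, 32, 32, device="cuda")
y = liger_batch_norm(x, -1, 1e-6)  # (input, axis, eps)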
62 changes: 62 additions & 0 deletions test/transformers/test_batched_layer_norm.py
@@ -0,0 +1,62 @@
import os

import keras
import numpy
import pytest
import torch

from liger_kernel.transformers.functional import liger_batch_norm

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"


@pytest.mark.parametrize(
    "hidden_size",
    [
        32,
    ],
)
@pytest.mark.parametrize(
    "batch_size, seq_len",
    [
        (2, 32),
        (8, 32),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
    ],
)
def test_liger_batched_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol):
    torch.manual_seed(0)

    x = torch.randn(
        batch_size, seq_len, hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    # Reference: Keras BatchNormalization in inference mode
    # (moving mean 0, moving variance 1, gamma 1, beta 0 by default).
    keras_ln = keras.layers.BatchNormalization(epsilon=1e-6)

    axis = -1
    eps = 1e-6

    liger_output = liger_batch_norm(x, axis, eps)

    x = x.detach().cpu().numpy()
    keras_output = keras_ln(x, training=False)

    assert torch.allclose(
        liger_output,
        torch.Tensor(numpy.array(keras_output)).to("cuda"),
        atol=atol,
        rtol=rtol,
    )
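
A possible follow-up, not part of this PR: the same reference values could come from `torch.nn.functional.batch_norm` in eval mode, which would drop the keras/numpy dependency from the test. A sketch under that assumption (the helper name `torch_batch_norm_reference` is hypothetical):

import torch
import torch.nn.functional as F


def torch_batch_norm_reference(x, eps=1e-6):
    # Inference-mode batch norm over the last axis with running mean 0 and variance 1,
    # matching the placeholder statistics in LigerBatchNormFunction.
    hidden = x.shape[-1]
    running_mean = torch.zeros(hidden, device=x.device)
    running_var = torch.ones(hidden, device=x.device)
    out = F.batch_norm(
        x.reshape(-1, hidden), running_mean, running_var,
        weight=None, bias=None, training=False, eps=eps,
    )
    return out.reshape(x.shape)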