Skip to content

Commit fac9b78

Browse files
committed
num items in batch
1 parent 4950162 commit fac9b78

File tree

5 files changed

+15
-15
lines changed

5 files changed

+15
-15
lines changed

src/liger_kernel/chunked_loss/cpo_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
1010

1111
@staticmethod
12-
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
12+
def preference_loss_fn(chosen_logps, rejected_logps, num_items_in_batch, beta=0.1):
1313
"""
1414
Paper: https://arxiv.org/pdf/2401.08417
1515
@@ -28,11 +28,11 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
2828
Args:
2929
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
3030
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
31-
full_target (torch.Tensor): Non chunked full target tensor
31+
num_items_in_batch (int): Number of items in the batch.
3232
beta (float): Weight for the CPO loss
3333
"""
3434
logits = beta * (chosen_logps - rejected_logps)
35-
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
35+
loss = F.logsigmoid(logits).sum() / (num_items_in_batch // 2)
3636
return loss
3737

3838
@staticmethod

src/liger_kernel/chunked_loss/dpo_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
1212
def preference_loss_fn(
1313
chosen_logps,
1414
rejected_logps,
15-
full_target,
15+
num_items_in_batch,
1616
ref_chosen_logps=None,
1717
ref_rejected_logps=None,
1818
beta=0.1,
@@ -34,7 +34,7 @@ def preference_loss_fn(
3434
Args:
3535
chosen_logps: Log probabilities of chosen tokens (batch_size,)
3636
rejected_logps: Log probabilities of rejected tokens (batch_size,)
37-
full_target: Non chunked full target tensor
37+
num_items_in_batch (int): Number of items in the batch.
3838
ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
3939
ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
4040
beta: Weight for the direct preference loss
@@ -49,7 +49,7 @@ def preference_loss_fn(
4949
rejected_logratios = rejected_logps - ref_rejected_logps
5050

5151
logits_diff = beta * (chosen_logratios - rejected_logratios)
52-
loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
52+
loss = -F.logsigmoid(logits_diff).sum() / (num_items_in_batch // 2)
5353
return loss
5454

5555
@staticmethod

src/liger_kernel/chunked_loss/fused_linear_preference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def _compute_loss(
387387
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
388388

389389
preference_loss_outputs = preference_loss_fn(
390-
chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
390+
chosen_logps, rejected_logps, full_target.shape[0], beta=beta, **loss_kwargs
391391
)
392392
if isinstance(preference_loss_outputs, tuple):
393393
preference_loss, *aux_outputs = preference_loss_outputs

src/liger_kernel/chunked_loss/orpo_loss.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
1010

1111
@staticmethod
12-
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
12+
def preference_loss_fn(chosen_logps, rejected_logps, num_items_in_batch, beta=0.1):
1313
"""
1414
Paper: https://arxiv.org/pdf/2403.07691
1515
@@ -28,21 +28,21 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
2828
Args:
2929
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
3030
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
31-
full_target (torch.Tensor): Non chunked full target tensor
31+
num_items_in_batch (int): Number of items in the batch.
3232
beta (float): Weight for the odds ratio loss.
3333
"""
3434
log_odds = (chosen_logps - rejected_logps) - (
3535
torch.log1p(-torch.exp(chosen_logps))
3636
- torch.log1p(-torch.exp(rejected_logps))
3737
)
3838
ratio = F.logsigmoid(log_odds)
39-
loss = beta * ratio.sum() / (full_target.shape[0] // 2)
39+
loss = beta * ratio.sum() / (num_items_in_batch // 2)
4040

4141
chosen_rewards = beta * chosen_logps
4242
rejected_rewards = beta * rejected_logps
4343

44-
log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2)
45-
log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2)
44+
log_odds_ratio = torch.sum(ratio) / (num_items_in_batch // 2)
45+
log_odds_chosen = torch.sum(log_odds) / (num_items_in_batch // 2)
4646

4747
return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
4848

src/liger_kernel/chunked_loss/simpo_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
1010

1111
@staticmethod
1212
def preference_loss_fn(
13-
chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
13+
chosen_logps, rejected_logps, num_items_in_batch, beta=0.1, gamma=0.5
1414
):
1515
"""
1616
Paper: https://arxiv.org/pdf/2405.14734
@@ -30,12 +30,12 @@ def preference_loss_fn(
3030
Args:
3131
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
3232
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
33-
full_target: Non chunked full target tensor
33+
num_items_in_batch (int): Number of items in the batch.
3434
beta (float): beta weight
3535
gamma (float): gemma margin term
3636
"""
3737
logits = beta * (chosen_logps - rejected_logps) - gamma
38-
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
38+
loss = F.logsigmoid(logits).sum() / (num_items_in_batch // 2)
3939
return loss
4040

4141
@staticmethod

0 commit comments

Comments
 (0)