diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index b97db7c61..1da1643a6 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -715,33 +715,33 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 -kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,20.53521728515625,19.908370971679688,20.64090919494629,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:12:39,0.5.2 -kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,37.11443328857422,37.1072883605957,37.12157440185547,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:12:39,0.5.2 -kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,76.66329956054688,76.66329956054688,76.66329956054688,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:12:39,0.5.2 -kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,172.98681640625,172.98681640625,172.98681640625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:12:39,0.5.2 -kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,303.8100280761719,303.8100280761719,303.8100280761719,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:12:39,0.5.2 -kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,8.131551742553711,8.122809410095215,8.135846138000488,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:12:58,0.5.2 -kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,17.13974380493164,13.816153526306152,19.821325302124023,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:12:58,0.5.2 -kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,32.35935974121094,28.95905876159668,32.686336517333984,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:12:58,0.5.2 -kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,59.52479934692383,59.52479934692383,59.52479934692383,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:12:58,0.5.2 -kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,129.97698974609375,129.97698974609375,129.97698974609375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:12:58,0.5.2 -kto_loss,liger,full,speed,ms,B,Batch Size (B),2,19.929119110107422,19.372455596923828,20.696868896484375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:13:19,0.5.2 -kto_loss,liger,full,speed,ms,B,Batch Size (B),4,38.00328063964844,37.99269104003906,38.01387023925781,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:13:19,0.5.2 -kto_loss,liger,full,speed,ms,B,Batch Size (B),8,75.40016174316406,75.40016174316406,75.40016174316406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:13:19,0.5.2 -kto_loss,liger,full,speed,ms,B,Batch Size (B),16,148.1293487548828,148.1293487548828,148.1293487548828,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:13:19,0.5.2 -kto_loss,liger,full,speed,ms,B,Batch Size (B),32,297.30621337890625,297.30621337890625,297.30621337890625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:13:19,0.5.2 -kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,13.964447975158691,13.904864311218262,13.98806381225586,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:13:38,0.5.2 -kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,25.264448165893555,25.253772735595703,25.288410186767578,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:13:38,0.5.2 -kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,48.33251190185547,48.33251190185547,48.33251190185547,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:13:38,0.5.2 -kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,94.19913482666016,94.19913482666016,94.19913482666016,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:13:38,0.5.2 -kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,186.9466552734375,186.9466552734375,186.9466552734375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:13:38,0.5.2 -kto_loss,liger,full,memory,MB,B,Batch Size (B),2,3543.0029296875,3543.0029296875,3543.0029296875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:14:01,0.5.2 -kto_loss,liger,full,memory,MB,B,Batch Size (B),4,4306.2548828125,4306.2548828125,4306.2548828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:14:01,0.5.2 -kto_loss,liger,full,memory,MB,B,Batch Size (B),8,4318.2744140625,4318.2744140625,4318.2744140625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:14:01,0.5.2 -kto_loss,liger,full,memory,MB,B,Batch Size (B),16,4342.3134765625,4342.3134765625,4342.3134765625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:14:01,0.5.2 -kto_loss,liger,full,memory,MB,B,Batch Size (B),32,4390.3916015625,4390.3916015625,4390.3916015625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:14:01,0.5.2 -kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,5051.99560546875,5051.99560546875,5051.99560546875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:14:22,0.5.2 -kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6809.51220703125,6809.51220703125,6809.51220703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:14:22,0.5.2 -kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,10321.544921875,10321.544921875,10321.544921875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:14:22,0.5.2 -kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,17351.611328125,17351.611328125,17351.611328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:14:22,0.5.2 -kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,31411.744140625,31411.744140625,31411.744140625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-20 07:14:22,0.5.2 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,4.223455905914307,4.183884620666504,4.23199987411499,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:13:41,0.5.2 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,8.400144577026367,8.322336196899414,8.437881469726562,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:13:41,0.5.2 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,16.128929138183594,16.08905601501465,16.155744552612305,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:13:41,0.5.2 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,32.88691329956055,32.87322235107422,32.9140625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:13:41,0.5.2 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,64.56556701660156,64.56556701660156,64.56556701660156,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:13:41,0.5.2 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,3.885007858276367,3.872652769088745,3.898591995239258,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:13:55,0.5.2 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,7.364704132080078,7.353612899780273,7.40869140625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:13:55,0.5.2 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,14.353952407836914,14.347923278808594,14.392550468444824,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:13:55,0.5.2 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,28.60825538635254,28.41696548461914,28.709611892700195,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:13:55,0.5.2 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,58.51824188232422,58.51824188232422,58.51824188232422,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:13:55,0.5.2 +kto_loss,liger,full,speed,ms,B,Batch Size (B),2,4.638735771179199,4.555334091186523,4.652671813964844,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:10,0.5.2 +kto_loss,liger,full,speed,ms,B,Batch Size (B),4,8.78553581237793,8.72441577911377,8.79651165008545,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:10,0.5.2 +kto_loss,liger,full,speed,ms,B,Batch Size (B),8,16.917583465576172,16.87615966796875,16.953439712524414,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:10,0.5.2 +kto_loss,liger,full,speed,ms,B,Batch Size (B),16,33.683902740478516,33.61384201049805,33.753963470458984,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:10,0.5.2 +kto_loss,liger,full,speed,ms,B,Batch Size (B),32,66.50137329101562,66.50137329101562,66.50137329101562,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:10,0.5.2 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,6.3572797775268555,6.338803291320801,6.376518726348877,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:25,0.5.2 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,11.686159133911133,11.640652656555176,11.694182395935059,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:25,0.5.2 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,22.35055923461914,22.22890281677246,22.556838989257812,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:25,0.5.2 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,44.56269073486328,44.557830810546875,44.56754684448242,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:25,0.5.2 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,88.43180847167969,88.43180847167969,88.43180847167969,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:25,0.5.2 +kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2585.24609375,2585.24609375,2585.24609375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:40,0.5.2 +kto_loss,liger,full,memory,MB,B,Batch Size (B),4,3348.49609375,3348.49609375,3348.49609375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:40,0.5.2 +kto_loss,liger,full,memory,MB,B,Batch Size (B),8,3360.51171875,3360.51171875,3360.51171875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:40,0.5.2 +kto_loss,liger,full,memory,MB,B,Batch Size (B),16,3384.54296875,3384.54296875,3384.54296875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:40,0.5.2 +kto_loss,liger,full,memory,MB,B,Batch Size (B),32,3432.60546875,3432.60546875,3432.60546875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:40,0.5.2 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,4343.2470703125,4343.2470703125,4343.2470703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:55,0.5.2 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.2646484375,6099.2646484375,6099.2646484375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:55,0.5.2 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9614.296875,9614.296875,9614.296875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:55,0.5.2 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.36328125,16643.36328125,16643.36328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:55,0.5.2 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.49609375,30703.49609375,30703.49609375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 NVL,2024-12-21 02:14:55,0.5.2 diff --git a/benchmark/scripts/benchmark_kto_loss.py b/benchmark/scripts/benchmark_kto_loss.py index 9454fab50..f4fa2dc57 100644 --- a/benchmark/scripts/benchmark_kto_loss.py +++ b/benchmark/scripts/benchmark_kto_loss.py @@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -class TorchKTOLoss(torch.nn.Module): +class TorchLMHeadKTO(torch.nn.Module): def __init__( self, H: int, @@ -29,11 +29,8 @@ def __init__( ref_bias: bool = False, ignore_index: int = -100, beta: float = 0.1, - policy_KL_logps: torch.FloatTensor = None, - ref_KL_logps: torch.FloatTensor = None, ): from test.chunked_loss.test_kto_loss import HFKTOLoss - super().__init__() self.lin = torch.nn.Linear( in_features=H, out_features=V, bias=bias, dtype=dtype @@ -41,16 +38,14 @@ def __init__( self.ref_lin = torch.nn.Linear( in_features=H, out_features=V, bias=ref_bias, dtype=dtype ) - self.kto_loss = HFKTOLoss( + self.KTO_loss = HFKTOLoss( ignore_index=ignore_index, beta=beta, use_ref_model=True, - policy_KL_logps=policy_KL_logps, - ref_KL_logps=ref_KL_logps, ).get_batch_loss_metrics - def forward(self, x, ref_x, y, preference_labels): - return self.kto_loss( + def forward(self, x, ref_x, y, preference_labels, kl=None): + return self.KTO_loss( weight=self.lin.weight, _input=x, target=y, @@ -59,10 +54,11 @@ def forward(self, x, ref_x, y, preference_labels): ref_weight=self.ref_lin.weight, ref_bias=self.ref_lin.bias, preference_labels=preference_labels, - )[0] + kl=kl, + ) -class LigerKTOLoss(torch.nn.Module): +class LigerLMHeadKTO(torch.nn.Module): def __init__( self, H: int, @@ -72,8 +68,6 @@ def __init__( ref_bias: bool = False, ignore_index: int = -100, beta: float = 0.1, - policy_KL_logps: torch.FloatTensor = None, - ref_KL_logps: torch.FloatTensor = None, ): super().__init__() self.lin = torch.nn.Linear( @@ -82,25 +76,24 @@ def __init__( self.ref_lin = torch.nn.Linear( in_features=H, out_features=V, bias=ref_bias, dtype=dtype ) - self.kto_loss = LigerFusedLinearKTOLoss( + self.KTO_loss = LigerFusedLinearKTOLoss( ignore_index=ignore_index, beta=beta, use_ref_model=True, - policy_KL_logps=policy_KL_logps, - ref_KL_logps=ref_KL_logps, ) - def forward(self, x, ref_x, y, preference_labels): - return self.kto_loss( + def forward(self, x, ref_x, y, preference_labels, kl=None): + return self.KTO_loss( _input=x, lin_weight=self.lin.weight, target=y, - bias=self.lin.bias, preference_labels=preference_labels, + bias=self.lin.bias, ref_input=ref_x, ref_weight=self.ref_lin.weight, ref_bias=self.ref_lin.bias, - )[0] + kl=kl, + ) def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -114,7 +107,7 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO ignore_index = input.extra_benchmark_config["ignore_index"] provider = input.kernel_provider - torch_kto_loss = TorchKTOLoss( + torch_kto_loss = TorchLMHeadKTO( H=H, V=V, dtype=dtype, @@ -124,7 +117,7 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO beta=beta, ).to(device) - liger_kto_loss = LigerKTOLoss( + liger_kto_loss = LigerLMHeadKTO( H=H, V=V, dtype=dtype, @@ -187,7 +180,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu provider = input.kernel_provider mode = input.kernel_operation_mode - torch_kto_loss = TorchKTOLoss( + torch_kto_loss = TorchLMHeadKTO( H=H, V=V, dtype=dtype, @@ -195,7 +188,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu ignore_index=ignore_index, bias=bias, ).to(device) - liger_kto_loss = LigerKTOLoss( + liger_kto_loss = LigerLMHeadKTO( H=H, V=V, dtype=dtype, diff --git a/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py b/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py index 50a900600..fe37230a6 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py @@ -20,14 +20,12 @@ def forward( _input, weight, target, + preference_labels, bias=None, loss_fn=None, chunk_size=1, ignore_index=-100, - alpha=1.0, - beta=0.1, compiled=True, - preference_labels=None, use_ref_model=False, ref_input=None, ref_weight=None, @@ -74,91 +72,48 @@ def forward( # Loss to be accumulated loss_acc = torch.zeros((), device=_input.device) - # Metrics to be recorded - policy_chosen_logps = [] - policy_rejected_logps = [] - policy_chosen_logits_mean = torch.zeros((), device=_input.device) - policy_rejected_logits_mean = torch.zeros((), device=_input.device) - policy_nll_loss = torch.zeros((), device=_input.device) - aggregated_aux_outputs = [] # aggregated aux outputs from all chunks - compute_loss = partial( LigerFusedLinearUnpairedPreferenceBase._compute_loss, preference_loss_fn=loss_fn, - ignore_index=ignore_index, - alpha=alpha, - beta=beta, full_target=target, + ignore_index=ignore_index, use_ref_model=use_ref_model, ref_weight=ref_weight, ref_bias=ref_bias, - preference_labels=preference_labels, **loss_kwargs, ) def fused_fwd_bwd( - input_chunk, target_chunk, ref_input_chunk, preference_labels_chunk + input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk ): """ Fused forward and backward pass for a chunk of input and target. """ - if bias is not None: - return torch.func.grad_and_value( - compute_loss, argnums=(0, 1, 3), has_aux=True - )( - input_chunk, - weight, - target_chunk, - bias, - ref_input_chunk=ref_input_chunk, - preference_labels=preference_labels_chunk, - ) - else: - return torch.func.grad_and_value( - compute_loss, argnums=(0, 1), has_aux=True - )( - input_chunk, - weight, - target_chunk, - ref_input_chunk=ref_input_chunk, - preference_labels=preference_labels_chunk, - ) + argnums = (0, 1, 4) if bias is not None else (0, 1) + return torch.func.grad_and_value( + compute_loss, argnums=argnums, has_aux=False + )( + input_chunk, + weight, + target_chunk, + preference_labels_chunk, + bias, + ref_input_chunk=ref_input_chunk, + ) def accumulate_chunk( input_chunk, target_chunk, - ref_input_chunk=None, preference_labels_chunk=None, + ref_input_chunk=None, ): + (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), ( + chunk_loss + ) = fused_fwd_bwd( + input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk + ) if bias is not None: - (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( - chunk_loss, - ( - chunk_chosen_logps, - chunk_rejected_logps, - chunk_chosen_logits_mean, - chunk_rejected_logits_mean, - chunk_policy_nll_loss, - *aux_outputs, - ), - ) = fused_fwd_bwd( - input_chunk, target_chunk, ref_input_chunk, preference_labels_chunk - ) - grad_bias.add_(chunk_grad_bias) # accumulate bias gradient - else: - (chunk_grad_input, chunk_grad_weight), ( - chunk_loss, - ( - chunk_chosen_logps, - chunk_rejected_logps, - chunk_chosen_logits_mean, - chunk_rejected_logits_mean, - chunk_policy_nll_loss, - *aux_outputs, - ), - ) = fused_fwd_bwd( - input_chunk, target_chunk, ref_input_chunk, preference_labels_chunk - ) + grad_bias.add_(chunk_grad_bias[0]) # accumulate bias gradient # Accumulate gradients grad_weight.add_(chunk_grad_weight) @@ -167,31 +122,6 @@ def accumulate_chunk( # Accumulate loss loss_acc.add_(chunk_loss) - # Accumulate metrics - policy_chosen_logps.append(chunk_chosen_logps) - policy_rejected_logps.append(chunk_rejected_logps) - policy_chosen_logits_mean.add_(chunk_chosen_logits_mean) - policy_rejected_logits_mean.add_(chunk_rejected_logits_mean) - policy_nll_loss.add_(chunk_policy_nll_loss) - - # aux_outputs - # Initialize storage for aux_outputs - if len(aggregated_aux_outputs) == 0: - for aux in aux_outputs: - if aux.ndim == 0: - aggregated_aux_outputs.append( - torch.zeros((), device=aux.device) - ) - else: - aggregated_aux_outputs.append([]) - - # Process each aux_output - for i, aux in enumerate(aux_outputs): - if aux.ndim == 0: - aggregated_aux_outputs[i].add_(aux) - else: - aggregated_aux_outputs[i].append(aux) - if compiled: fused_fwd_bwd = torch.compile(fused_fwd_bwd) @@ -232,30 +162,15 @@ def accumulate_chunk( # accumulate loss, gradients, and metrics accumulate_chunk( - input_chunk, target_chunk, ref_input_chunk, preference_labels_chunk + input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk ) - policy_chosen_logps = torch.cat(policy_chosen_logps, dim=0) - policy_rejected_logps = torch.cat(policy_rejected_logps, dim=0) - - # Aggregate aux outputs lists into tensors - for i, aux in enumerate(aggregated_aux_outputs): - if isinstance(aux, list): - aggregated_aux_outputs[i] = torch.cat(aux, dim=0) - ctx.save_for_backward( torch.cat(grad_inputs, dim=0), grad_weight, grad_bias, ) - return_vars = ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits_mean, - policy_rejected_logits_mean, - policy_nll_loss, - ) - return loss_acc, (*return_vars, *aggregated_aux_outputs) + return loss_acc @staticmethod def backward(ctx, *grad_output): @@ -267,7 +182,7 @@ def backward(ctx, *grad_output): grad_weight = grad_weight * grad_output[0][0] grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None - return grad_input, grad_weight, None, grad_bias, None, None, None + return grad_input, grad_weight, None, None, grad_bias @staticmethod def chunk_forward( @@ -276,53 +191,36 @@ def chunk_forward( target_chunk, bias=None, ignore_index=-100, - preference_labels=None, ): - # Data is not ordered, so we need to use preference_labels to separate chosen and rejected logits_chunk = input_chunk @ weight.t() if bias is not None: logits_chunk = logits_chunk + bias log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) - loss_mask = target_chunk != ignore_index - label_chunk = torch.where(loss_mask, target_chunk, 0) + loss_mask_chunk = target_chunk != ignore_index + label_chunk = torch.where(loss_mask_chunk, target_chunk, 0) - per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( + per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( -1 ) - average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + average_log_prob_chunk = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1) - chosen_logps = average_log_prob[preference_labels == 1] - rejected_logps = average_log_prob[preference_labels == 0] - - chosen_logits = logits_chunk[preference_labels == 1] - rejected_logits = logits_chunk[preference_labels == 0] - policy_nll_loss = torch.tensor(0.0, device=chosen_logps.device) - - return ( - chosen_logps, - rejected_logps, - chosen_logits, - rejected_logits, - policy_nll_loss, - ) + return average_log_prob_chunk @staticmethod def _compute_loss( input_chunk, weight, target_chunk, + preference_labels_chunk, bias=None, preference_loss_fn=None, full_target=None, ignore_index=-100, - alpha=1.0, - beta=0.1, use_ref_model=False, ref_input_chunk=None, ref_weight=None, ref_bias=None, - preference_labels=None, **loss_kwargs, ): """ @@ -335,76 +233,32 @@ def _compute_loss( bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). ignore_index (int): Index to ignore for loss computation. - alpha (float): Weight for the NLL loss. - beta (float): Weight for the preference loss. use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). loss_kwargs (dict): Additional arguments for the loss function. """ - ( - chosen_logps, - rejected_logps, - chosen_logits, - rejected_logits, - policy_nll_loss, - ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward( + average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward( input_chunk, weight, target_chunk, bias=bias, ignore_index=ignore_index, - preference_labels=preference_labels, - ) - - num_chosen = (preference_labels == 1).sum() - num_rejected = (preference_labels == 0).sum() - - chosen_logits_mean = ( - chosen_logits.sum() / (num_chosen * input_chunk.shape[1] * weight.shape[0]) - if num_chosen > 0 - else torch.tensor(0.0, device=chosen_logits.device) - ) - rejected_logits_mean = ( - rejected_logits.sum() - / (num_rejected * input_chunk.shape[1] * weight.shape[0]) - if num_rejected > 0 - else torch.tensor(0.0, device=rejected_logits.device) ) if use_ref_model: with torch.no_grad(): - ( - ref_chosen_logps, - ref_rejected_logps, - ref_chosen_logits, - ref_rejected_logits, - policy_nll_loss, - ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward( + ref_average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward( ref_input_chunk, ref_weight, target_chunk, ref_bias, ignore_index=ignore_index, - preference_labels=preference_labels, ) - loss_kwargs["ref_chosen_logps"] = ref_chosen_logps - loss_kwargs["ref_rejected_logps"] = ref_rejected_logps + loss_kwargs["ref_average_log_prob_chunk"] = ref_average_log_prob_chunk - preference_loss_outputs = preference_loss_fn( - chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs - ) - if isinstance(preference_loss_outputs, tuple): - preference_loss, *aux_outputs = preference_loss_outputs - else: - preference_loss, aux_outputs = preference_loss_outputs, [] - - loss = 0 - preference_loss - return_vars = ( - chosen_logps, - rejected_logps, - chosen_logits_mean, - rejected_logits_mean, - policy_nll_loss, + preference_loss_chunk = preference_loss_fn( + average_log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs ) - return loss, (*return_vars, *aux_outputs) + + return preference_loss_chunk diff --git a/src/liger_kernel/chunked_loss/kto_loss.py b/src/liger_kernel/chunked_loss/kto_loss.py index 19f478179..a8cb9115c 100644 --- a/src/liger_kernel/chunked_loss/kto_loss.py +++ b/src/liger_kernel/chunked_loss/kto_loss.py @@ -10,14 +10,12 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase): @staticmethod def preference_loss_fn( - chosen_logps, - rejected_logps, + average_log_prob_chunk, + preference_labels_chunk, full_target, - ref_chosen_logps=None, - ref_rejected_logps=None, + ref_average_log_prob_chunk=None, beta=0.1, - policy_KL_logps=None, - ref_KL_logps=None, + kl=None, ): """ Implements the Kahneman-Tversky Optimization (KTO) loss function. @@ -63,38 +61,14 @@ def preference_loss_fn( - chosen_rewards: Reward signals for chosen responses (detached) - rejected_rewards: Reward signals for rejected responses (detached) """ - if ref_chosen_logps is None: - ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device) - if ref_rejected_logps is None: - ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device) + logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk + multiplier_chunk = torch.where(preference_labels_chunk, 1, -1) + if kl is not None: + losses = -(1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)) + else: + losses = -(1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)) - chosen_logratios = chosen_logps - ref_chosen_logps - rejected_logratios = rejected_logps - ref_rejected_logps - - if policy_KL_logps is None: - policy_KL_logps = torch.tensor(0.0, device=chosen_logps.device) - if ref_KL_logps is None: - ref_KL_logps = torch.tensor(0.0, device=chosen_logps.device) - - kl = (policy_KL_logps - ref_KL_logps).mean().clamp(min=0).detach() - losses = [] - - if chosen_logps.numel() > 0: - losses.append(1 - F.sigmoid(beta * (chosen_logratios - kl))) - if rejected_logps.numel() > 0: - losses.append(1 - F.sigmoid(beta * (kl - rejected_logratios))) - - losses = torch.cat(losses, dim=0) - - chosen_rewards = beta * chosen_logratios.detach() - rejected_rewards = beta * rejected_logratios.detach() - - return ( - # We don't divide by 2 because KTO Loss doesn't need pairwise examples - losses.sum() / (full_target.shape[0]), - chosen_rewards, - rejected_rewards, - ) + return losses.sum() / (full_target.shape[0]) @staticmethod def forward( @@ -102,23 +76,23 @@ def forward( _input, weight, target, + preference_labels, bias=None, - preference_labels=None, ref_input=None, ref_weight=None, ref_bias=None, + kl=None, ignore_index=-100, beta=0.1, compiled=True, use_ref_model=True, - policy_KL_logps=None, - ref_KL_logps=None, ): return LigerFusedLinearUnpairedPreferenceBase.forward( ctx=ctx, _input=_input, weight=weight, target=target, + preference_labels=preference_labels, bias=bias, loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn, ignore_index=ignore_index, @@ -128,14 +102,12 @@ def forward( ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias, - preference_labels=preference_labels, - policy_KL_logps=policy_KL_logps, - ref_KL_logps=ref_KL_logps, + kl=kl, ) @staticmethod def backward(ctx, *grad_output): - grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:4] + grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5] return ( *grads, None, @@ -162,8 +134,6 @@ def __init__( beta: float = 0.1, compiled: bool = True, use_ref_model: bool = False, - policy_KL_logps: torch.FloatTensor = None, - ref_KL_logps: torch.FloatTensor = None, ): """ Args: @@ -179,8 +149,6 @@ def __init__( self.beta = beta self.compiled = compiled self.use_ref_model = use_ref_model - self.policy_KL_logps = policy_KL_logps - self.ref_KL_logps = ref_KL_logps def forward( self, @@ -192,20 +160,20 @@ def forward( ref_input=None, ref_weight=None, ref_bias=None, + kl=None, ): return LigerFusedLinearKTOFunction.apply( _input, lin_weight, target, - bias, preference_labels, + bias, ref_input, ref_weight, ref_bias, + kl, self.ignore_index, self.beta, self.compiled, self.use_ref_model, - self.policy_KL_logps, - self.ref_KL_logps, ) diff --git a/test/chunked_loss/test_kto_loss.py b/test/chunked_loss/test_kto_loss.py index bcaac03ff..84def23df 100644 --- a/test/chunked_loss/test_kto_loss.py +++ b/test/chunked_loss/test_kto_loss.py @@ -27,22 +27,14 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, use_ref_model: bool = True, - policy_KL_logps: torch.FloatTensor = None, - ref_KL_logps: torch.FloatTensor = None, ): super().__init__( beta=beta, ignore_index=ignore_index, use_ref_model=use_ref_model, - policy_KL_logps=policy_KL_logps, - ref_KL_logps=ref_KL_logps, unpaired=True, compute_nll_loss=False, ) - # KL logps need to be passed into the Loss class since it requires a full model forward pass - # See paper https://arxiv.org/abs/2402.01306 (4.1. Derivation) - self.policy_KL_logps = policy_KL_logps - self.ref_KL_logps = ref_KL_logps def alignment_loss( self, @@ -50,6 +42,7 @@ def alignment_loss( policy_rejected_logps: torch.FloatTensor, ref_chosen_logps: torch.FloatTensor, ref_rejected_logps: torch.FloatTensor, + kl: torch.FloatTensor = None, ): """Compute KTO loss for a batch of policy log probabilities. Args: @@ -60,24 +53,34 @@ def alignment_loss( Returns: The losses tensor contains the KTO loss for each example in the batch. """ - if self.policy_KL_logps is None: - self.policy_KL_logps = torch.zeros(1).to(device) - - if self.ref_KL_logps is None: - self.ref_KL_logps = torch.zeros(1).to(device) - - kl = (self.policy_KL_logps - self.ref_KL_logps).mean().clamp(min=0).detach() - - chosen_logratios = policy_chosen_logps - ref_chosen_logps - chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) - chosen_rewards = self.beta * chosen_logratios.detach() - - rejected_logratios = policy_rejected_logps - ref_rejected_logps - rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) - rejected_rewards = self.beta * rejected_logratios.detach() + if kl is None: + kl = torch.zeros(1).to(policy_chosen_logps.device) + + + # Chosen losses + if policy_chosen_logps.shape[0] != 0 or ref_chosen_logps.shape[0] != 0: + chosen_logratios = policy_chosen_logps - ref_chosen_logps + # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306) + chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) + + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + chosen_losses = torch.Tensor([]).to(policy_chosen_logps.device) + + # Rejected losses + if policy_rejected_logps.shape[0] != 0 or ref_rejected_logps.shape[0] != 0: + rejected_logratios = policy_rejected_logps - ref_rejected_logps + rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + rejected_losses = torch.Tensor([]).to(policy_rejected_logps.device) + + losses = torch.cat( + (chosen_losses, rejected_losses), + 0, + ) - losses = torch.cat((chosen_losses, rejected_losses), 0) - return losses, chosen_rewards, rejected_rewards + return losses class TorchLMHeadKTO(torch.nn.Module): @@ -90,8 +93,6 @@ def __init__( ref_bias: bool = False, ignore_index: int = -100, beta: float = 0.1, - policy_KL_logps: torch.FloatTensor = None, - ref_KL_logps: torch.FloatTensor = None, ): super().__init__() self.lin = torch.nn.Linear( @@ -104,11 +105,9 @@ def __init__( ignore_index=ignore_index, beta=beta, use_ref_model=True, - policy_KL_logps=policy_KL_logps, - ref_KL_logps=ref_KL_logps, ).get_batch_loss_metrics - def forward(self, x, ref_x, y, preference_labels): + def forward(self, x, ref_x, y, preference_labels, kl=None): return self.KTO_loss( weight=self.lin.weight, _input=x, @@ -118,6 +117,7 @@ def forward(self, x, ref_x, y, preference_labels): ref_weight=self.ref_lin.weight, ref_bias=self.ref_lin.bias, preference_labels=preference_labels, + kl=kl, ) @@ -131,8 +131,6 @@ def __init__( ref_bias: bool = False, ignore_index: int = -100, beta: float = 0.1, - policy_KL_logps: torch.FloatTensor = None, - ref_KL_logps: torch.FloatTensor = None, ): super().__init__() self.lin = torch.nn.Linear( @@ -145,11 +143,9 @@ def __init__( ignore_index=ignore_index, beta=beta, use_ref_model=True, - policy_KL_logps=policy_KL_logps, - ref_KL_logps=ref_KL_logps, ) - def forward(self, x, ref_x, y, preference_labels): + def forward(self, x, ref_x, y, preference_labels, kl=None): return self.KTO_loss( _input=x, lin_weight=self.lin.weight, @@ -159,6 +155,7 @@ def forward(self, x, ref_x, y, preference_labels): ref_input=ref_x, ref_weight=self.ref_lin.weight, ref_bias=self.ref_lin.bias, + kl=kl, ) @@ -185,7 +182,7 @@ def test_correctness( # Preference labels shape: [B] # Create binary preference labels (0 or 1) for each sequence in the batch # Used to indicate preferred sequences (1) vs non-preferred sequences (0) - preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device) + preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device, requires_grad=False) torch_lm_head_KTO = TorchLMHeadKTO( H=H, @@ -245,21 +242,18 @@ def test_correctness( indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index - loss1, aggregated_aux_outputs1 = torch_lm_head_KTO( + loss1 = torch_lm_head_KTO( x=input1, ref_x=ref_input, y=target, preference_labels=preference_labels ) - loss2, aggregated_aux_outputs2 = liger_lm_head_KTO( + loss2 = liger_lm_head_KTO( x=input2, ref_x=ref_input, y=target, preference_labels=preference_labels ) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) - assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) - loss1.backward() loss2.backward() - # Passed assert_verbose_allclose(input1, input2, atol=atol, rtol=rtol) assert_verbose_allclose( torch_lm_head_KTO.lin.weight, liger_lm_head_KTO.lin.weight, atol=atol, rtol=rtol @@ -343,22 +337,22 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_bias1 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None - loss1, aggregated_aux_outputs1 = LigerFusedLinearKTOFunction.apply( + loss1 = LigerFusedLinearKTOFunction.apply( input1, weight1, target, - bias1, preference_labels, + bias1, ref_input, ref_weight1, ref_bias1, ) - loss2, aggregated_aux_outputs2 = liger_fused_linear_kto( + loss2 = liger_fused_linear_kto( input2, weight2, target, - bias2, preference_labels, + bias2, ref_input, ref_weight2, ref_bias2, diff --git a/test/utils.py b/test/utils.py index 13e817129..153f20100 100644 --- a/test/utils.py +++ b/test/utils.py @@ -402,7 +402,7 @@ def get_batch_logps( def get_ref_logps( self, - _input: torch.FloatTensor, + ref_input: torch.FloatTensor, ref_weight: torch.FloatTensor, target: torch.LongTensor, ref_bias: torch.FloatTensor, @@ -411,23 +411,24 @@ def get_ref_logps( ): """Compute the log probabilities of the given labels under the given reference model.""" - ref_logits = _input @ ref_weight.t() - if ref_bias is not None: - ref_logits = ref_logits + ref_bias - ref_all_logps = self.get_batch_logps( - ref_logits, target, average_log_prob=average_log_prob - ) - - if self.unpaired and preference_labels is not None: - # Split based on preference labels - return ref_all_logps[preference_labels], ref_all_logps[~preference_labels] - else: - # Original paired behavior - split in half - return ( - ref_all_logps[: _input.shape[0] // 2], - ref_all_logps[_input.shape[0] // 2 :], + with torch.no_grad(): + ref_logits = ref_input @ ref_weight.t() + if ref_bias is not None: + ref_logits = ref_logits + ref_bias + ref_all_logps = self.get_batch_logps( + ref_logits, target, average_log_prob=average_log_prob ) + if self.unpaired and preference_labels is not None: + # Split based on preference labels + return ref_all_logps[preference_labels], ref_all_logps[~preference_labels] + else: + # Original paired behavior - split in half + return ( + ref_all_logps[: ref_input.shape[0] // 2], + ref_all_logps[ref_input.shape[0] // 2 :], + ) + def concatenated_forward( self, _input: torch.FloatTensor, @@ -506,9 +507,9 @@ def get_batch_loss_metrics( ref_bias: torch.FloatTensor = None, average_log_prob: bool = True, preference_labels: torch.Tensor = None, + **loss_kwargs, ): """Compute the loss metrics for the given batch of inputs for train or test.""" - forward_output = self.concatenated_forward( _input, weight, target, bias, average_log_prob, preference_labels ) @@ -520,7 +521,6 @@ def get_batch_loss_metrics( policy_nll_loss, ) = forward_output[:5] - loss_kwargs = {} if self.use_ref_model: ref_chosen_logps, ref_rejected_logps = self.get_ref_logps( ref_input, @@ -540,20 +540,20 @@ def get_batch_loss_metrics( else: losses, aggregated_aux_outputs = alignment_loss_outputs, [] - if self.compute_nll_loss: - # full loss - loss = policy_nll_loss * self.alpha - losses.mean() + + loss = policy_nll_loss * self.alpha - losses.mean() + + if not self.unpaired: + return_vars = ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits.detach().mean(), + policy_rejected_logits.detach().mean(), + policy_nll_loss, + ) + return loss, (*return_vars, *aggregated_aux_outputs) else: - # only alignment loss - loss = -losses.mean() - return_vars = ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits.detach().mean(), - policy_rejected_logits.detach().mean(), - policy_nll_loss, - ) - return loss, (*return_vars, *aggregated_aux_outputs) + return loss class HFDistillationLoss: