Commit 1f640a5

Trying out logprobs and top logprobs for testing rather than logits. (#745)
## Summary

Trying out logprobs for the convergence tests, as mentioned in #742. The logprob-based comparison passes for the models where the logits-based test was failing. Also tried setting a 1e-1 tolerance for Qwen (previously 1), and the test passed.

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 0bd8292 commit 1f640a5
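
The `get_logprobs` and `get_topk` helpers imported by the tests below live in `test/utils.py` and are not part of this diff. A minimal sketch of what they plausibly do, assuming `log_softmax` over the vocab dimension and `torch.topk` with a fixed `k` (the names match the diff; the signatures and the value of `k` are assumptions, not the repo's actual code):

```python
# Hypothetical sketch of the test.utils helpers used in the diffs below.
# Assumptions: log_softmax over the last (vocab) dim and a fixed top-k;
# the repo's real signatures and k value may differ.
import torch


def get_logprobs(logits: torch.Tensor) -> torch.Tensor:
    # Normalize logits into log-probabilities; invariant to per-position shifts.
    return torch.log_softmax(logits.float(), dim=-1)


def get_topk(logprobs: torch.Tensor, k: int = 20):
    # torch.topk returns a named tuple (values, indices).
    return torch.topk(logprobs, k=k, dim=-1)
```

`get_topk` returning the full `torch.topk` result would explain why the tests store `topk_logprobs.values`: the comparison is on the top-k log-probabilities themselves, not on the token indices.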

File tree

8 files changed, +121 -80 lines

dev/modal/benchmarks.py

8 additions, 8 deletions

@@ -16,8 +16,8 @@
 
 @app.function(gpu="H100", image=repo, timeout=60 * 45)
 def liger_benchmarks():
-    import subprocess
     import os
+    import subprocess
 
     subprocess.run(
         ["uv pip install -e '.[dev]' --system"],
@@ -30,7 +30,7 @@ def liger_benchmarks():
     file_path = Path(REMOTE_ROOT_PATH) / "benchmark" / "data" / "all_benchmark_data.csv"
     print(f"Checking if file exists at: {file_path}")
     print(f"File exists: {os.path.exists(file_path)}")
-    
+
     if not os.path.exists(file_path):
         print("Listing directory contents:")
         data_dir = file_path.parent
@@ -53,21 +53,21 @@ def main():
         # Run the benchmarks and get the data
         print("Starting benchmark run...")
         benchmark_data = liger_benchmarks.remote()
-        
+
         if not benchmark_data:
             raise ValueError("No data received from remote function")
-        
+
         # Save the data locally
         local_data_path = ROOT_PATH / "benchmark" / "data" / "all_benchmark_data.csv"
         print(f"Attempting to save data to: {local_data_path}")
-        
+
         local_data_path.parent.mkdir(parents=True, exist_ok=True)
-        
+
         with open(local_data_path, "wb") as f:
             f.write(benchmark_data)
-        
+
         print(f"Successfully saved {len(benchmark_data)} bytes to: {local_data_path}")
-        
+
     except Exception as e:
         print(f"Error occurred: {str(e)}")
         raise

test/convergence/bf16/test_mini_models.py

20 additions, 18 deletions

@@ -38,6 +38,8 @@
 from test.utils import DEFAULT_DATASET_PATH
 from test.utils import MiniModelConfig
 from test.utils import assert_verbose_allclose
+from test.utils import get_logprobs
+from test.utils import get_topk
 from test.utils import revert_liger_kernel_to_gemma
 from test.utils import revert_liger_kernel_to_gemma2
 from test.utils import revert_liger_kernel_to_gemma3_text
@@ -851,17 +853,17 @@ def run_mini_model(
         eval_output = model(**eval_batch)
         print(f"Eval Loss: {eval_output.loss.item()}")
         loss_list.append(eval_output.loss.item())
-
+    topk_logprobs = get_topk(get_logprobs(eval_output.logits))
     MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
     return {
         "loss": loss_list,
-        "logits": eval_output.logits,
+        "topk_logprobs": topk_logprobs.values,
         "model": model,
     }
 
 
 @pytest.mark.parametrize(
-    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
+    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
     [
         pytest.param(
             "mini_llama3",
@@ -884,7 +886,7 @@ def run_mini_model(
             1e-3,
             1e-2,
             1e-1,
-            1e-2,
+            1e-1,
             1e-2,
             1e-2,
             marks=[
@@ -902,7 +904,7 @@ def run_mini_model(
             torch.bfloat16,
             1e-3,
             1e-2,
-            1, # 1e-1
+            1e-1, # 1e-1
             1e-1, # 1e-2
             1e-2,
             1e-2,
@@ -972,7 +974,7 @@ def run_mini_model(
             torch.bfloat16,
             1e-3,
             1e-2,
-            1, # 1e-1
+            1e-1, # 1e-1
             1e-1, # 1e-2
             1e-2,
             1e-2,
@@ -1111,8 +1113,8 @@ def run_mini_model(
             torch.bfloat16,
             1e-3,
             1e-2,
-            1e-1,
             1e-2,
+            1e-1,
             1e-2,
             1e-2,
             marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
@@ -1124,8 +1126,8 @@ def run_mini_model(
             torch.bfloat16,
             1e-3,
             1e-2,
-            1e-1,
             1e-2,
+            1e-1,
             1e-2,
             1e-2,
             marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
@@ -1153,8 +1155,8 @@ def run_mini_model(
             torch.bfloat16,
             1e-3,
             1e-2,
-            1e-1,
-            1e-2,
+            3e-1,
+            4e-1,
             1e-2,
             1e-2,
             marks=[
@@ -1174,8 +1176,8 @@ def test_mini_model(
     dtype,
     loss_atol,
     loss_rtol,
-    logits_atol,
-    logits_rtol,
+    logprobs_atol,
+    logprobs_rtol,
     param_atol,
     param_rtol,
 ):
@@ -1193,13 +1195,13 @@ def test_mini_model(
         rtol=loss_rtol,
     )
 
-    # Compare the logits from evaluation step
-    if expected_output["logits"] is not None and actual_output["logits"] is not None:
+    # Compare the topk logprobs from evaluation step
+    if expected_output["topk_logprobs"] is not None and actual_output["topk_logprobs"] is not None:
         assert_verbose_allclose(
-            expected_output["logits"],
-            actual_output["logits"],
-            atol=logits_atol,
-            rtol=logits_rtol,
+            expected_output["topk_logprobs"],
+            actual_output["topk_logprobs"],
+            atol=logprobs_atol,
+            rtol=logprobs_rtol,
         )
 
     # Compare the params from the last step
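
Why top-k log-probs can pass where raw logits fail (a standalone demo, not code from this PR): logits are unnormalized, so two kernels can induce the same predicted distribution while differing by a per-position constant; `log_softmax` cancels that constant, and top-k restricts the check to the high-probability tokens where rounding noise is smallest relative to the values.

```python
# Standalone demo (not from this PR): the same predicted distribution can
# fail a raw-logits allclose check but pass a top-k logprobs check.
import torch

torch.manual_seed(0)
logits_ref = torch.randn(2, 8, 128)
# Simulate a kernel that matches up to small noise plus a constant shift
# per position (softmax is invariant to the shift).
logits_alt = logits_ref + 1e-3 * torch.randn_like(logits_ref) + 5.0

# The raw-logits comparison fails badly because of the constant shift.
assert not torch.allclose(logits_ref, logits_alt, atol=1e-1)

# Log-probabilities cancel the shift; top-k keeps only the likely tokens.
top_ref = torch.log_softmax(logits_ref, dim=-1).topk(20, dim=-1).values
top_alt = torch.log_softmax(logits_alt, dim=-1).topk(20, dim=-1).values
assert torch.allclose(top_ref, top_alt, atol=1e-1)
```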

test/convergence/bf16/test_mini_models_multimodal.py

16 additions, 10 deletions

@@ -20,6 +20,8 @@
 from test.utils import UNTOKENIZED_DATASET_PATH
 from test.utils import MiniModelConfig
 from test.utils import assert_verbose_allclose
+from test.utils import get_logprobs
+from test.utils import get_topk
 from test.utils import is_torchvision_available
 from test.utils import load_image_processing_config
 from test.utils import load_processor_config
@@ -764,13 +766,17 @@ def run_mini_model_multimodal(
 
         print(f"Step {i}, Loss: {output.loss.item()}")
         loss_list.append(output.loss.item())
-
+    topk_logprobs = get_topk(get_logprobs(output.logits))
     MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
-    return {"loss": loss_list, "logits": output.logits, "model": model}
+    return {
+        "loss": loss_list,
+        "topk_logprobs": topk_logprobs.values,
+        "model": model,
+    }
 
 
 @pytest.mark.parametrize(
-    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
+    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
     [
         pytest.param(
             "mini_qwen2_vl",
@@ -917,8 +923,8 @@ def test_mini_model_multimodal(
     dtype,
     loss_atol,
     loss_rtol,
-    logits_atol,
-    logits_rtol,
+    logprobs_atol,
+    logprobs_rtol,
     param_atol,
     param_rtol,
 ):
@@ -937,12 +943,12 @@ def test_mini_model_multimodal(
         rtol=loss_rtol,
     )
 
-    # Compare the logits from the last step
+    # Compare the topk logprobs from evaluation step
     assert_verbose_allclose(
-        expected_output["logits"],
-        actual_output["logits"],
-        atol=logits_atol,
-        rtol=logits_rtol,
+        expected_output["topk_logprobs"],
+        actual_output["topk_logprobs"],
+        atol=logprobs_atol,
+        rtol=logprobs_rtol,
     )
 
     # Compare the params from the last step

test/convergence/bf16/test_mini_models_with_logits.py

18 additions, 11 deletions

@@ -38,6 +38,8 @@
 from test.utils import DEFAULT_DATASET_PATH
 from test.utils import MiniModelConfig
 from test.utils import assert_verbose_allclose
+from test.utils import get_logprobs
+from test.utils import get_topk
 from test.utils import revert_liger_kernel_to_gemma
 from test.utils import revert_liger_kernel_to_gemma2
 from test.utils import revert_liger_kernel_to_gemma3_text
@@ -842,12 +844,17 @@ def run_mini_model(
         print(f"Step {i}, Loss: {output.loss.item()}")
         loss_list.append(output.loss.item())
 
+    topk_logprobs = get_topk(get_logprobs(output.logits))
     MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
-    return {"loss": loss_list, "logits": output.logits, "model": model}
+    return {
+        "loss": loss_list,
+        "topk_logprobs": topk_logprobs.values,
+        "model": model,
+    }
 
 
 @pytest.mark.parametrize(
-    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
+    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
     [
         pytest.param(
             "mini_llama3",
@@ -1058,8 +1065,8 @@ def run_mini_model(
             torch.bfloat16,
             1e-3,
             1e-2,
-            1e-1,
             1e-2,
+            1e-1,
             1e-2,
             1e-2,
             marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
@@ -1071,8 +1078,8 @@ def run_mini_model(
             torch.bfloat16,
             1e-3,
             1e-2,
-            1e-1,
             1e-2,
+            1e-1,
             1e-2,
             1e-2,
             marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
@@ -1159,8 +1166,8 @@ def test_mini_model(
     dtype,
     loss_atol,
     loss_rtol,
-    logits_atol,
-    logits_rtol,
+    logprobs_atol,
+    logprobs_rtol,
     param_atol,
     param_rtol,
 ):
@@ -1180,12 +1187,12 @@ def test_mini_model(
 
     # No logits are materialized
    # import pdb; pdb.set_trace()
-    # Compare the logits from the last step
+    # Compare the topk logprobs from evaluation step
     assert_verbose_allclose(
-        expected_output["logits"],
-        actual_output["logits"],
-        atol=logits_atol,
-        rtol=logits_rtol,
+        expected_output["topk_logprobs"],
+        actual_output["topk_logprobs"],
+        atol=logprobs_atol,
+        rtol=logprobs_rtol,
    )
 
     # Compare the params from the last step

test/convergence/fp32/test_mini_models.py

14 additions, 12 deletions

@@ -38,6 +38,8 @@
 from test.utils import DEFAULT_DATASET_PATH
 from test.utils import MiniModelConfig
 from test.utils import assert_verbose_allclose
+from test.utils import get_logprobs
+from test.utils import get_topk
 from test.utils import revert_liger_kernel_to_gemma
 from test.utils import revert_liger_kernel_to_gemma2
 from test.utils import revert_liger_kernel_to_gemma3_text
@@ -849,17 +851,17 @@ def run_mini_model(
         eval_output = model(**eval_batch)
         print(f"Eval Loss: {eval_output.loss.item()}")
         loss_list.append(eval_output.loss.item())
-
+    topk_logprobs = get_topk(get_logprobs(eval_output.logits))
     MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
     return {
         "loss": loss_list,
-        "logits": eval_output.logits,
+        "topk_logprobs": topk_logprobs.values,
         "model": model,
     }
 
 
 @pytest.mark.parametrize(
-    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
+    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
     [
         ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
         pytest.param(
@@ -1013,7 +1015,7 @@ def run_mini_model(
         # TODO: mixtral is flaky so disable the test for now
         # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
         # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
-        ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
+        ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-2, 5e-3, 1e-5),
         ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
         ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
         pytest.param(
@@ -1041,8 +1043,8 @@ def test_mini_model(
     dtype,
     loss_atol,
     loss_rtol,
-    logits_atol,
-    logits_rtol,
+    logprobs_atol,
+    logprobs_rtol,
     param_atol,
     param_rtol,
 ):
@@ -1060,13 +1062,13 @@ def test_mini_model(
         rtol=loss_rtol,
    )
 
-    # Compare the logits from evaluation step
-    if expected_output["logits"] is not None and actual_output["logits"] is not None:
+    # Compare the topk logprobs from evaluation step
+    if expected_output["topk_logprobs"] is not None and actual_output["topk_logprobs"] is not None:
         assert_verbose_allclose(
-            expected_output["logits"],
-            actual_output["logits"],
-            atol=logits_atol,
-            rtol=logits_rtol,
+            expected_output["topk_logprobs"],
+            actual_output["topk_logprobs"],
+            atol=logprobs_atol,
+            rtol=logprobs_rtol,
         )
 
     # Compare the params from the last step
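
For reading the tolerance columns tuned throughout this PR: assuming `assert_verbose_allclose` follows the usual `torch.allclose` rule (an assumption; its implementation is not shown here), a comparison passes when `|expected - actual| <= atol + rtol * |actual|`, so for log-probs of magnitude around 5 an `rtol` of 1e-1 dominates an `atol` of 1e-2.

```python
# Illustration (not from the repo) of what an atol/rtol pair means for the
# tolerance columns tuned in this PR, using torch.allclose semantics:
# pass when |expected - actual| <= atol + rtol * |actual|.
import torch

expected = torch.tensor([-2.000, -5.000])
actual = torch.tensor([-2.008, -5.045])

# With atol=1e-2 and rtol=1e-1:
#   |-2.000 - (-2.008)| = 0.008 <= 0.01 + 0.1 * 2.008
#   |-5.000 - (-5.045)| = 0.045 <= 0.01 + 0.1 * 5.045
assert torch.allclose(expected, actual, atol=1e-2, rtol=1e-1)
```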
