Description
🐛 Describe the bug
#369 found that CrossEntropyLoss wasn't applied in post-grad-acc-fix versions of transformers. Although #375 fixed that issue, it didn't update the revert functions used by the convergence tests.
Currently, the convergence test, test_mini_models_with_logits, compares two models that are both using LigerCrossEntropyLoss in every case except the first. In other words, the test results may be false positives in the second and later test cases.
The current revert functions restore modules by calling importlib.reload(module_name). We can fix the issue by carefully checking the transformers version and adding all patched modules to the reload list. We should also strengthen the monkey_patch unit tests with an additional revert-and-compare pass to ensure the correctness of the convergence test results.
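As a minimal sketch of the reload-based revert mechanism described above (using the standard-library json module as a stand-in for a patched transformers module, and a hypothetical fake_dumps in place of a Liger kernel):

```python
import importlib
import json


def fake_dumps(obj, **kwargs):
    # Hypothetical stand-in for a Liger-patched function.
    return "PATCHED"


# Apply the monkey patch, as the convergence tests do with Liger kernels.
json.dumps = fake_dumps
assert json.dumps({"a": 1}) == "PATCHED"

# Revert by reloading the module: reload re-executes the module body and
# restores every attribute it defines at import time. Crucially, this only
# reverts the modules that are actually reloaded -- any patched module left
# off the reload list silently keeps the patched function, which is the
# false-positive scenario described above.
importlib.reload(json)
assert json.dumps({"a": 1}) == '{"a": 1}'
```

A revert-and-compare pass in the unit tests would repeat the patch/reload cycle and assert that the original function is back after each revert.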
Reproduce
Add a print statement in LigerCrossEntropyLossFunction and run
python3 -m pytest test/convergence/test_mini_models_with_logits.py -v -rP
Versions
none