Add StableMax integration to enable grokking and prevent Softmax Collapse #2761
Conversation
**Walkthrough:** A new integration named StableMax has been added to the project. This includes documentation, plugin code, argument definitions, and the core StableMax activation and loss function implementation.
Sequence Diagram(s):

```mermaid
sequenceDiagram
    participant User
    participant AxolotlConfig
    participant StableMaxPlugin
    participant PyTorch
    User->>AxolotlConfig: Enable stablemax in config
    AxolotlConfig->>StableMaxPlugin: pre_model_load(cfg)
    StableMaxPlugin->>StableMaxPlugin: Check cfg.stablemax
    alt If stablemax enabled
        StableMaxPlugin->>PyTorch: Patch cross_entropy with stablemax_cross_entropy
    end
    User->>PyTorch: Train/evaluate model (calls cross_entropy)
    PyTorch->>StableMaxPlugin: Use stablemax_cross_entropy
```
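The patching flow in the diagram can be sketched in plain Python. The names below mirror the diagram, but the bodies are simplified assumptions (a stand-in namespace replaces `torch.nn.functional` so the sketch runs without PyTorch), not the PR's actual code.

```python
import types

# Stand-in for torch.nn.functional so the sketch runs without PyTorch installed.
F = types.SimpleNamespace(cross_entropy=lambda logits, target: "original")

def stablemax_cross_entropy(logits, target):
    # Placeholder for the PR's numerically stable loss.
    return "stablemax"

class StableMaxPlugin:
    """Minimal sketch of the pre_model_load hook described in the diagram."""

    def pre_model_load(self, cfg):
        if cfg.get("stablemax", False):
            # Global patch: every later call to F.cross_entropy is redirected.
            F.cross_entropy = stablemax_cross_entropy
```

This is why the README's warning about global patching matters: after `pre_model_load` runs, every caller of the patched function sees the replacement, not just the training loop.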
Actionable comments posted: 3
🧹 Nitpick comments (8)
**src/axolotl/integrations/stablemax/ACKNOWLEDGEMENTS.md (1)**

Line 1: Populate acknowledgements or remove file. This file is a placeholder; add any external credits or acknowledgements relevant to StableMax, or remove it if none are needed.

**docs/custom_integrations.qmd (1)**

Lines 53-55: Ensure integration name consistency. The display entry uses "Stablemax" but elsewhere the plugin is named StableMax. Update the section name to "StableMax" to maintain consistent casing across docs.

**src/axolotl/integrations/stablemax/__init__.py (1)**

Line 3: Remove unused import. The `torch` import is not used in this file.

```diff
-import torch
```
🧰 Tools: 🪛 Ruff (0.11.9): 3-3: `torch` imported but unused (F401)
**src/axolotl/integrations/stablemax/README.md (2)**

Lines 16-18: Add language specification to code block. The fenced code block should specify the language for better rendering and syntax highlighting.

````diff
-```
+```text
 StableMax(x_i) = s(x_i) / sum_j s(x_j)
 ```
````

Line 19: Consider hyphenating compound adjective. For better readability, hyphenate "floating point" when used as a compound adjective.

```diff
-This prevents floating point absorption errors that can halt learning in grokking tasks.
+This prevents floating-point absorption errors that can halt learning in grokking tasks.
```

🧰 Tools: 🪛 LanguageTool: ~19: compound adjective should be hyphenated (EN_COMPOUND_ADJECTIVE_INTERNAL); 🪛 markdownlint-cli2 (0.17.2): 16: fenced code blocks should have a language specified (MD040)
**src/axolotl/integrations/stablemax/stablemax.py (3)**

Line 2: Remove unused import. The `torch.nn.functional` import is not used in this file.

```diff
-import torch.nn.functional as F
```

🧰 Tools: 🪛 Ruff (0.11.9): 2-2: `torch.nn.functional` imported but unused (F401)

Lines 32-37: Simplify control flow by removing unnecessary elif. After a return statement, the `elif` is unnecessary and can be simplified to `if`.

```diff
 if reduction == "mean":
     return loss.mean()
-elif reduction == "sum":
+if reduction == "sum":
     return loss.sum()
-else:
-    return loss
+return loss
```

🧰 Tools: 🪛 Pylint (3.3.7): [refactor] 32-37: Unnecessary "elif" after "return" (R1705)

Line 26 (also applies to line 30): Consider making epsilon configurable or using a larger value. The hardcoded epsilon 1e-12 might be too small for some numerical scenarios, especially with mixed precision training. Consider using a slightly larger value like 1e-8 or making it configurable.

```diff
-    log_probs = torch.log(probs + 1e-12)
+    log_probs = torch.log(probs + 1e-8)
```
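The epsilon floor matters because log(0) diverges to negative infinity; flooring the probability keeps the loss finite when a probability underflows to exactly zero. A pure-Python illustration (the eps values come from the review discussion; this is not the PR's code):

```python
import math

def safe_log(p, eps=1e-12):
    # Flooring the probability keeps log() finite when p underflows to 0.
    return math.log(p + eps)

print(safe_log(0.0))  # about -27.6 with the 1e-12 floor
```

One plausible reason for the reviewer's 1e-8 suggestion: in float16, values below roughly 6e-8 are not representable, so a 1e-12 floor would round to zero if the addition were ever performed in half precision.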
📜 Review details: Configuration used: CodeRabbit UI; Review profile: CHILL; Plan: Pro

📒 Files selected for processing (8):
- .gitignore (1 hunk)
- docs/custom_integrations.qmd (1 hunk)
- src/axolotl/integrations/stablemax/ACKNOWLEDGEMENTS.md (1 hunk)
- src/axolotl/integrations/stablemax/LICENSE (1 hunk)
- src/axolotl/integrations/stablemax/README.md (1 hunk)
- src/axolotl/integrations/stablemax/__init__.py (1 hunk)
- src/axolotl/integrations/stablemax/args.py (1 hunk)
- src/axolotl/integrations/stablemax/stablemax.py (1 hunk)

🧬 Code Graph Analysis (1): src/axolotl/integrations/stablemax/stablemax.py uses `loss` (lines 56-158) from src/axolotl/integrations/kd/topk_logprob/forward_kl.py.
⏰ Context from checks skipped due to timeout of 90000ms (8)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: preview
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.6.0)
🔇 Additional comments (4)
.gitignore (1)

Lines 194-196: macOS resource fork ignore rules. The patterns for `._*` and `.DS_Store` are correctly added to prevent macOS artifacts from being committed.

src/axolotl/integrations/stablemax/args.py (1)

Lines 1-10: Argument model is well-defined. `StableMaxArgs` correctly exposes a `stablemax` flag with a clear description, fitting the plugin's configuration pattern.

src/axolotl/integrations/stablemax/README.md (1)

Lines 1-40: Excellent documentation quality. The documentation provides clear explanations of the StableMax function, its mathematical foundation, usage instructions, and important compatibility notes. The references to the ICLR 2025 paper add credibility and context.
src/axolotl/integrations/stablemax/stablemax.py (1)

Lines 4-11: 🛠️ Refactor suggestion: Address potential numerical instability in the piecewise function. The current implementation has a potential numerical issue: when `x` approaches 1 from below, `1 / (1 - x)` can become extremely large, and when `x = 1` exactly, this results in division by zero. Consider adding a small epsilon or clamping to prevent numerical instability:

```diff
 def stablemax_fn(x):
     """
     Numerically stable alternative to softmax.
     s(x) = x + 1 if x >= 0, else 1 / (1 - x)
     StableMax(x_i) = s(x_i) / sum_j s(x_j)
     """
-    s = torch.where(x >= 0, x + 1, 1 / (1 - x))
+    # Clamp x to prevent division by zero when x approaches 1
+    x_clamped = torch.clamp(x, max=0.999)
+    s = torch.where(x >= 0, x + 1, 1 / (1 - x_clamped))
     return s / s.sum(dim=-1, keepdim=True)
```

Likely an incorrect or invalid review comment.
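The bot's own caveat looks right: the reciprocal branch is only selected for `x < 0`, where `1 - x > 1`, so the selected value stays in (0, 1) and no division by zero can occur at `x = 1` (which takes the `x + 1` branch). A quick pure-Python check of the piecewise rule (not the PR's code):

```python
def s(x):
    # Piecewise scaling from the review: x + 1 for x >= 0, else 1 / (1 - x).
    return x + 1.0 if x >= 0 else 1.0 / (1.0 - x)

# For negative inputs the denominator 1 - x exceeds 1, so the reciprocal
# branch that is actually selected is always finite and below 1.
values = [s(x) for x in (-100.0, -1.0, -0.5)]
```

At `x = 1` the function returns `2.0` via the linear branch, so the suggested clamp guards a branch that is never taken for those inputs.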
Codecov Report: Attention: Patch coverage is … 📢 Thoughts on this report? Let us know!
Actionable comments posted: 0
🧹 Nitpick comments (9)
**src/axolotl/integrations/stablemax/__init__.py (2)**

Line 3: Remove unused import. The `torch` import is not used in this file and should be removed to fix the linting error.

```diff
-import torch
```

Line 1: Add module docstring. Add a module-level docstring to address the pylint warning and provide clear documentation about this plugin's purpose.

```diff
+"""
+StableMax integration plugin for Axolotl.
+
+This plugin provides a numerically stable alternative to softmax cross-entropy
+by globally patching torch.nn.functional.cross_entropy when enabled.
+"""
+
 # StableMax integration entry point
```

🧰 Tools: 🪛 flake8: 3-3: `torch` imported but unused (F401); 🪛 pylint: 1-1: Missing module docstring; pre-commit hooks fixed trailing whitespace and import sorting.
**src/axolotl/integrations/stablemax/README.md (3)**

Line 21: Use hyphenated compound adjective. For better readability, hyphenate the compound adjective "floating point" when it modifies a noun.

```diff
-This prevents floating point absorption errors that can halt learning in grokking tasks.
+This prevents floating-point absorption errors that can halt learning in grokking tasks.
```

Lines 18-20: Specify language for code block. Add a language identifier to the fenced code block for better syntax highlighting.

````diff
-```
+```text
 StableMax(x_i) = s(x_i) / sum_j s(x_j)
 ```
````

Lines 3-6: Fix blockquote formatting. Remove blank lines inside blockquotes to fix the markdownlint MD028 warnings: join the three quoted paragraphs (the global-patching warning, the compatibility warning about Liger and CutCrossEntropy, and the note about pairing StableMax with the orthograd optimizer from Prieto et al.) into one continuous blockquote.
**src/axolotl/integrations/stablemax/stablemax.py (4)**

Line 1: Add module docstring. Add a module-level docstring to describe the StableMax implementation and its purpose.

```diff
+"""
+StableMax implementation: A numerically stable alternative to softmax.
+
+This module implements the StableMax function and its corresponding cross-entropy
+loss as described in "Grokking at the Edge of Numerical Stability" (Prieto et al., ICLR 2025).
+"""
+
 import torch
```

Lines 13-15: Rename parameter to avoid shadowing built-in. The parameter `input` shadows Python's built-in `input()` function. Consider a more descriptive name such as `logits`, and update the corresponding usage to `probs = stablemax_fn(logits)`.

Lines 23-24: Document unused legacy parameters. The `size_average` and `reduce` parameters are kept for PyTorch compatibility but not used; update the docstring to say they are deprecated and ignored.

Lines 82-87: Simplify conditional structure. Remove the unnecessary `elif` after `return` to improve readability.

```diff
 if reduction == "mean":
     return loss.mean() if loss.numel() > 0 else torch.tensor(0.0, device=input.device)
-elif reduction == "sum":
+if reduction == "sum":
     return loss.sum()
-else:
-    raise ValueError(f"Invalid reduction mode: {reduction}")
+
+raise ValueError(f"Invalid reduction mode: {reduction}")
```

🧰 Tools: 🪛 Pylint (3.3.7): [refactor] 75-87: Unnecessary "elif" after "return" (R1705); 🪛 GitHub Actions lint: 86-86: no-else-return.
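The reduction semantics the review keeps refactoring can be shown with a tiny list-based stand-in (mirroring PyTorch's `"mean"`, `"sum"`, and `"none"` modes; this is illustrative, not the PR's tensor code):

```python
def reduce_loss(per_token_losses, reduction="mean"):
    # Mirrors torch's reduction modes, with an empty-batch guard for "mean".
    if reduction == "mean":
        if not per_token_losses:
            return 0.0
        return sum(per_token_losses) / len(per_token_losses)
    if reduction == "sum":
        return sum(per_token_losses)
    if reduction == "none":
        return per_token_losses
    raise ValueError(f"Invalid reduction mode: {reduction}")
```

Written this way, every branch ends in `return` or `raise`, so no `elif`/`else` chain is needed, which is exactly what the R1705 lint asks for.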
📜 Review details: Configuration used: CodeRabbit UI; Review profile: CHILL; Plan: Pro

📒 Files selected for processing (4):
- src/axolotl/integrations/stablemax/LICENSE (1 hunk)
- src/axolotl/integrations/stablemax/README.md (1 hunk)
- src/axolotl/integrations/stablemax/__init__.py (1 hunk)
- src/axolotl/integrations/stablemax/stablemax.py (1 hunk)

🚧 Files skipped from review as they are similar to previous changes (1):
- src/axolotl/integrations/stablemax/LICENSE

🧬 Code Graph Analysis (2): stablemax.py uses `loss` (56-158) from src/axolotl/integrations/kd/topk_logprob/forward_kl.py; __init__.py uses `BasePlugin` (41-238) from src/axolotl/integrations/base.py, `stablemax_cross_entropy` (13-87) from stablemax.py, and `StableMaxArgs` (3-10) from args.py.
🔇 Additional comments (3)
src/axolotl/integrations/stablemax/__init__.py (1)

Lines 16-57: LGTM! Excellent conflict detection and documentation. The implementation properly addresses the global patching concerns with:
- Clear warnings in the docstring about global replacement behavior
- Comprehensive conflict detection for other cross-entropy patches (Liger, CutCrossEntropy)
- Helpful error messages when conflicts are detected

This ensures users are aware of the implications and prevents runtime conflicts.
src/axolotl/integrations/stablemax/stablemax.py (2)

Lines 4-11: LGTM! Correct StableMax implementation. The `stablemax_fn` correctly implements the piecewise function:
- s(x) = x + 1 for x >= 0
- s(x) = 1 / (1 - x) for x < 0
- Proper normalization by the sum along the last dimension

This matches the mathematical definition from the paper.
Lines 30-88: ❌ Incorrect review comment. Verify edge cases in cross-entropy implementation. The implementation handles complex scenarios but some edge cases need verification:
- One-hot targets with ignore_index: currently assumes ignore_index doesn't apply to one-hot targets, but this might not always be correct.
- Weight tensor shape assumptions: the code assumes `weight` is 1D in line 63 (`weight.unsqueeze(0)`), but PyTorch's cross_entropy accepts weights that match the target shape.
- Empty batch handling: when all targets are ignored, the `loss.numel() > 0` check prevents division by zero, which is good.

🏁 Script executed:

```shell
#!/bin/bash
# Search for how PyTorch handles these edge cases in cross_entropy
rg -A 5 -B 5 "one_hot.*ignore_index|ignore_index.*one_hot" --type py
```

Length of output: 70

Resolution: the flagged edge cases are fine. The implementation aligns with PyTorch's semantics:
- One-hot mode inherently doesn't support `ignore_index`, so assuming no mask there is correct.
- `weight` must be a 1D tensor of class-level weights (PyTorch's `cross_entropy` only accepts per-class weights), so `weight.unsqueeze(0)` is appropriate.
- The empty-batch guard (`loss.numel() > 0`) correctly avoids division by zero.

Likely an incorrect or invalid review comment.
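The per-class weighting convention at issue can be illustrated with a pure-Python sketch: cross-entropy class weights are one scalar per class (1D), and the weight for a sample is looked up by its target index. The helper name here is hypothetical, not the PR's code.

```python
def weighted_nll(log_probs, target_idx, class_weights):
    # log_probs: per-class log-probabilities for one sample.
    # class_weights: one weight per class (1D), matching PyTorch's convention
    # for the `weight` argument of cross-entropy.
    return -class_weights[target_idx] * log_probs[target_idx]
```

Because the weights are indexed by class, broadcasting them across a batch with `unsqueeze(0)` is the natural vectorized form of this lookup.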
Actionable comments posted: 0
🧹 Nitpick comments (4)
src/axolotl/integrations/stablemax/__init__.py (1)

Line 3: Remove unused import. The `torch` import is not used in this file and should be removed for cleaner code.

```diff
-import torch
```
src/axolotl/integrations/stablemax/README.md (2)

Lines 3-6: Fix markdown formatting in warning blockquotes. Replace the blank line between the quoted warnings with a `>` continuation line so the blockquote is continuous:

```diff
 > **⚠️ WARNING:** StableMax performs **global patching** of `torch.nn.functional.cross_entropy`, replacing it with `stablemax_cross_entropy` for ALL subsequent calls throughout the entire application. This affects not only your model training but also any other libraries, models, or code that use `torch.nn.functional.cross_entropy`.
-
+>
 > **⚠️ COMPATIBILITY:** Do not enable StableMax simultaneously with other cross-entropy patches such as **Liger** (`liger_cross_entropy`, `liger_fused_linear_cross_entropy`) or **CutCrossEntropy** (`cut_cross_entropy`). The system will detect and prevent such conflicts, but enabling multiple patches can lead to unpredictable runtime behavior.
```

Lines 18-21: Add language specification to the code block and fix hyphenation.

````diff
-```
+```text
 StableMax(x_i) = s(x_i) / sum_j s(x_j)
 ```
-This prevents floating point absorption errors that can halt learning in grokking tasks.
+This prevents floating-point absorption errors that can halt learning in grokking tasks.
````

src/axolotl/integrations/stablemax/stablemax.py (1)

Lines 15-102: Comprehensive cross-entropy implementation with excellent PyTorch compatibility. The function successfully addresses the previous review feedback by implementing:
- ✅ Complete API compatibility with `torch.nn.functional.cross_entropy`
- ✅ Robust target handling for both class indices and one-hot encoded targets
- ✅ Proper ignore_index support with valid masking
- ✅ Label smoothing implementation
- ✅ Class weighting for both target formats
- ✅ All reduction modes (`none`, `mean`, `sum`) with correct tensor sizing

The edge case handling is particularly well thought out, such as maintaining the original batch size for `reduction="none"` with ignored indices (lines 88-94). Minor style improvement for the elif chain:

```diff
-elif reduction == "mean":
+if reduction == "mean":
     return (
         loss.mean() if loss.numel() > 0 else torch.tensor(0.0, device=input.device)
     )
-elif reduction == "sum":
+if reduction == "sum":
     return loss.sum()
-else:
-    raise ValueError(f"Invalid reduction mode: {reduction}")
+
+raise ValueError(f"Invalid reduction mode: {reduction}")
```
📜 Review details: Configuration used: CodeRabbit UI; Review profile: CHILL; Plan: Pro

📒 Files selected for processing (4):
- src/axolotl/integrations/stablemax/README.md (1 hunk)
- src/axolotl/integrations/stablemax/__init__.py (1 hunk)
- src/axolotl/integrations/stablemax/args.py (1 hunk)
- src/axolotl/integrations/stablemax/stablemax.py (1 hunk)

🚧 Files skipped from review as they are similar to previous changes (1):
- src/axolotl/integrations/stablemax/args.py

🧬 Code Graph Analysis (2): __init__.py uses `BasePlugin` (41-238) from src/axolotl/integrations/base.py, `StableMaxArgs` (4-12) from args.py, and `stablemax_cross_entropy` (15-102) from stablemax.py; stablemax.py uses `loss` (56-158) from src/axolotl/integrations/kd/topk_logprob/forward_kl.py.
🔇 Additional comments (3)
src/axolotl/integrations/stablemax/__init__.py (1)

Lines 31-61: Excellent conflict detection and global patching documentation. The implementation properly addresses the concerns raised in previous reviews:
- Clear documentation of global patching behavior with appropriate warnings
- Comprehensive conflict detection for Liger and CutCrossEntropy integrations
- Descriptive error messages to guide users

This provides a safe and well-documented approach to the inherently risky global patching strategy.
src/axolotl/integrations/stablemax/README.md (1)

Lines 1-42: Comprehensive and well-structured documentation. The README provides excellent coverage of:
- Clear warnings about global patching behavior and compatibility issues
- Mathematical explanation of the StableMax function
- Proper usage instructions with configuration examples
- References to the academic paper and related resources

This documentation effectively addresses the integration's complexity and potential risks.
src/axolotl/integrations/stablemax/stablemax.py (1)

Lines 5-12: StableMax function implementation is mathematically correct. The piecewise function correctly implements the paper's specification:
- For x ≥ 0: s(x) = x + 1
- For x < 0: s(x) = 1 / (1 - x)
- Proper normalization by the sum

The `torch.where` approach is efficient and the normalization maintains probability distribution properties.
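The full normalized StableMax can be sketched in pure Python (a list-based stand-in for the tensor version; not the PR's code). Because s(x) is positive and strictly increasing, the normalized outputs form a valid probability distribution that preserves the ordering of the logits.

```python
def s(x):
    # Piecewise scaling: linear growth above 0, gentle decay below.
    return x + 1.0 if x >= 0 else 1.0 / (1.0 - x)

def stablemax(logits):
    # Normalize the scaled values so they sum to 1, like softmax does.
    scaled = [s(x) for x in logits]
    total = sum(scaled)
    return [v / total for v in scaled]
```

Unlike softmax, no exponential is involved, so very large positive logits grow only linearly and cannot overflow the way exp() can.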
Can you run …
I'm a bit busy and my cat died. If you want StableMax, here it is.

btw I think this will address the problem of the model getting stuck repeating the same token over and over.

also if you use stablemax you don't need to use weight_decay
@ehartford I tried to lint/clean up this PR, but maintainers don't have access to modify this PR. Can you update access please? Unfortunately we can't merge this when we can't run basic CI against it. Thanks!
sorry for the trouble.
This PR adds a new integration for StableMax, a numerically stable alternative to softmax that prevents Softmax Collapse (SC) during training. StableMax enables grokking (sudden generalization after prolonged overfitting) in settings where it would otherwise be prevented by floating-point errors.
Motivation and Context
This change implements the StableMax activation function from "Grokking at the Edge of Numerical Stability" (Prieto et al., ICLR 2025). The paper demonstrates that:
1. Softmax Collapse prevents grokking: when models achieve very low training loss, extreme logit values cause floating-point absorption errors in the softmax function. This zeroes out gradients and halts learning before generalization can occur.
2. Naive Loss Minimization (NLM): after overfitting, gradients align with a direction that scales logits without changing predictions, eventually leading to numerical instability.
3. StableMax enables grokking without regularization: by replacing the exponential function with a gentler scaling function, StableMax maintains numerical stability and allows models to continue learning through the grokking phase.
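The absorption error in point 1 is easy to reproduce in plain Python: once a float is large enough, adding 1 no longer changes it, which is exactly how small gradient contributions get zeroed out next to extreme logits.

```python
# Float64 carries about 15-16 significant digits; past 2**53 the spacing
# between representable values exceeds 1, so adding 1 is absorbed.
big = 1e16
print(big + 1 == big)  # True: the +1 is lost to rounding
```

Softmax makes this worse by exponentiating the logits first, so absorption kicks in at much smaller logit values than this raw-float demo suggests.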
This integration is particularly relevant for:
Paper: https://arxiv.org/abs/2501.04697
Note: This may also help with repetitive token generation as a side effect of improved numerical stability.
StableMax is intended to be used in combination with the orthograd optimizer (my implementation is available here: https://github.com/cognitivecomputations/dolphinflow-optimizer) - in order to fully implement the solution described in Prieto et al.
Summary by CodeRabbit
- New Features
- Documentation
- Chores