Add StableMax integration to enable grokking and prevent Softmax Collapse #2761

Open · wants to merge 9 commits into `main`

Conversation

ehartford (Collaborator) commented Jun 5, 2025

This PR adds a new integration for StableMax, a numerically stable alternative to softmax that prevents Softmax Collapse (SC) during training. StableMax enables grokking (sudden generalization after prolonged overfitting) in settings where it would otherwise be prevented by floating point errors.

Motivation and Context

This change implements the StableMax activation function from "Grokking at the Edge of Numerical Stability" (Prieto et al., ICLR 2025). The paper demonstrates that:

  1. Softmax Collapse prevents grokking: When models achieve very low training loss, extreme logit values cause floating point absorption errors in the softmax function. This zeroes out gradients and halts learning before generalization can occur.

  2. Naive Loss Minimization (NLM): After overfitting, gradients align with a direction that scales logits without changing predictions, eventually leading to numerical instability.

  3. StableMax enables grokking without regularization: By replacing the exponential function with a gentler scaling function, StableMax maintains numerical stability and allows models to continue learning through the grokking phase.
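The absorption effect behind point 1 can be reproduced in a few lines of plain Python (an illustrative sketch in float64, where the effect appears at larger magnitudes than in the float32/bf16 regimes used in training; this is not code from the PR):

```python
import math

# Absorption: once a float is large enough, adding a small term changes nothing.
assert 1e16 + 1.0 == 1e16  # the 1.0 is absorbed

# Softmax (with the usual max-subtraction trick) on extreme logits:
logits = [0.0, 800.0]
m = max(logits)
exps = [math.exp(x - m) for x in logits]  # exp(-800.0) underflows to exactly 0.0
probs = [e / sum(exps) for e in exps]
assert probs == [0.0, 1.0]

# The cross-entropy loss on the predicted class is then exactly -log(1) = 0,
# so its gradient vanishes and learning halts: Softmax Collapse.
```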

This integration is particularly relevant for:

  • Training on algorithmic tasks (modular arithmetic, sparse parity)
  • Scenarios where delayed generalization is observed
  • Cases where models overfit but fail to generalize without heavy regularization

Paper: https://arxiv.org/abs/2501.04697
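For reference, StableMax replaces the exponential with a piecewise scaling function s. A minimal pure-Python sketch (the PR's actual implementation operates on torch tensors; the piecewise form below is the one quoted in the review comments):

```python
def s(x: float) -> float:
    # Grows only linearly for x >= 0 and decays hyperbolically (never
    # reaching zero) for x < 0, unlike the exponential in softmax.
    return x + 1.0 if x >= 0 else 1.0 / (1.0 - x)

def stablemax(xs: list[float]) -> list[float]:
    # StableMax(x_i) = s(x_i) / sum_j s(x_j)
    vals = [s(x) for x in xs]
    total = sum(vals)
    return [v / total for v in vals]

# Large logits cannot overflow, and small logits keep strictly positive mass.
probs = stablemax([800.0, 0.0, -800.0])
assert all(p > 0.0 for p in probs)
assert abs(sum(probs) - 1.0) < 1e-9
```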

Note: This may also help with repetitive token generation as a side effect of improved numerical stability.

StableMax is intended to be used in combination with the orthograd optimizer (my implementation is available at https://github.com/cognitivecomputations/dolphinflow-optimizer) in order to fully implement the solution described in Prieto et al.

Summary by CodeRabbit

  • New Features

    • Introduced StableMax integration as a numerically stable alternative to softmax for classification tasks.
    • Added configuration options to enable StableMax within the application.
    • Integrated StableMax cross-entropy loss, replacing standard softmax loss when enabled.
  • Documentation

    • Added README, LICENSE, and ACKNOWLEDGEMENTS for the StableMax integration.
    • Updated integration documentation to include StableMax.
  • Chores

    • Updated .gitignore to exclude macOS system files.

coderabbitai bot commented Jun 5, 2025

Walkthrough

A new integration named StableMax has been added to the project. This includes documentation, plugin code, argument definitions, and the core StableMax activation and loss function implementation. The .gitignore was updated to exclude macOS system files. Documentation and configuration were updated to reference the new StableMax integration.

Changes

| File(s) | Change Summary |
| --- | --- |
| `.gitignore` | Added ignore patterns for macOS system files (`._*`, `.DS_Store`). |
| `docs/custom_integrations.qmd` | Added "Stablemax" to the list of integration sections. |
| `src/axolotl/integrations/stablemax/ACKNOWLEDGEMENTS.md`, `LICENSE` | Added placeholder acknowledgements and MIT license files for the StableMax integration directory. |
| `src/axolotl/integrations/stablemax/README.md` | Added documentation introducing StableMax, its usage, and references. |
| `src/axolotl/integrations/stablemax/__init__.py` | Added the StableMax integration plugin class with configuration checks and patching of PyTorch's cross-entropy function. |
| `src/axolotl/integrations/stablemax/args.py` | Introduced the `StableMaxArgs` Pydantic model for integration configuration. |
| `src/axolotl/integrations/stablemax/stablemax.py` | Implemented the `stablemax_fn` and `stablemax_cross_entropy` functions for numerically stable activation and loss. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant AxolotlConfig
    participant StableMaxPlugin
    participant PyTorch

    User->>AxolotlConfig: Enable stablemax in config
    AxolotlConfig->>StableMaxPlugin: pre_model_load(cfg)
    StableMaxPlugin->>StableMaxPlugin: Check cfg.stablemax
    alt If stablemax enabled
        StableMaxPlugin->>PyTorch: Patch cross_entropy with stablemax_cross_entropy
    end
    User->>PyTorch: Train/evaluate model (calls cross_entropy)
    PyTorch->>StableMaxPlugin: Use stablemax_cross_entropy
```
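The diagram boils down to a conditional monkey-patch applied in `pre_model_load`. A dependency-free sketch of that mechanism (hypothetical names; the real plugin reassigns `torch.nn.functional.cross_entropy`):

```python
from types import SimpleNamespace

def standard_cross_entropy(logits, target):
    return "standard"

def stablemax_cross_entropy(logits, target):
    return "stablemax"

# Stand-in for the torch.nn.functional module.
fake_functional = SimpleNamespace(cross_entropy=standard_cross_entropy)

def pre_model_load(cfg: dict) -> None:
    # Mirrors the plugin: patch only when the stablemax flag is set.
    if cfg.get("stablemax"):
        fake_functional.cross_entropy = stablemax_cross_entropy

pre_model_load({"stablemax": True})
# Every later lookup through the module attribute now resolves to the
# replacement, which is why the patch is global to the whole process.
assert fake_functional.cross_entropy(None, None) == "stablemax"
```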

Poem

In a warren of code, new functions arise,
StableMax hops in, with numerically wise ties.
Softmax collapse, now out of sight,
Mac files ignored, the repo feels light.
With plugins and docs, the future looks bright—
Rabbits rejoice, for the math is just right!
🐇✨


coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (8)
src/axolotl/integrations/stablemax/ACKNOWLEDGEMENTS.md (1)

1-1: Populate acknowledgements or remove file
This file is a placeholder—add any external credits or acknowledgements relevant to StableMax, or remove it if none are needed.

docs/custom_integrations.qmd (1)

53-55: Ensure integration name consistency
The display entry uses "Stablemax" but elsewhere the plugin is named StableMax. Update the section name to "StableMax" to maintain consistent casing across docs.

src/axolotl/integrations/stablemax/__init__.py (1)

3-3: Remove unused import.

The torch import is not used in this file.

-import torch
🧰 Tools
🪛 Ruff (0.11.9)

3-3: torch imported but unused

(F401)

src/axolotl/integrations/stablemax/README.md (2)

16-18: Add language specification to code block.

The fenced code block should specify the language for better rendering and syntax highlighting.

-```
+```text
 StableMax(x_i) = s(x_i) / sum_j s(x_j)


---

`19-19`: **Consider hyphenating compound adjective.**

For better readability, consider hyphenating "floating-point" when used as a compound adjective.


```diff
-This prevents floating point absorption errors that can halt learning in grokking tasks.
+This prevents floating-point absorption errors that can halt learning in grokking tasks.
```
🧰 Tools
🪛 LanguageTool

[uncategorized] ~19-~19: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ...s(x_i) / sum_j s(x_j) ``` This prevents floating point absorption errors that can halt learnin...

(EN_COMPOUND_ADJECTIVE_INTERNAL)

src/axolotl/integrations/stablemax/stablemax.py (3)

2-2: Remove unused import.

The torch.nn.functional import is not used in this file.

-import torch.nn.functional as F
🧰 Tools
🪛 Ruff (0.11.9)

2-2: torch.nn.functional imported but unused

Remove unused import: torch.nn.functional

(F401)


32-37: Simplify control flow by removing unnecessary elif.

After a return statement, the elif is unnecessary and can be simplified to if.

     if reduction == "mean":
         return loss.mean()
-    elif reduction == "sum":
+    if reduction == "sum":
         return loss.sum()
-    else:
-        return loss
+    return loss
🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 32-37: Unnecessary "elif" after "return", remove the leading "el" from "elif"

(R1705)


26-26: Consider making epsilon configurable or using a larger value.

The hardcoded epsilon 1e-12 might be too small for some numerical scenarios, especially with mixed precision training. Consider using a slightly larger value like 1e-8 or making it configurable.

-        log_probs = torch.log(probs + 1e-12)
+        log_probs = torch.log(probs + 1e-8)

Also applies to: 30-30
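The reviewer's point is easy to see in plain Python: the epsilon floor bounds the worst-case loss, and a larger floor leaves more headroom for low-precision dtypes (illustrative values, not the PR's code):

```python
import math

p = 0.0  # a probability that underflowed to exactly zero
# math.log(0.0) raises ValueError (log of zero is -inf), so a floor is needed:
assert round(-math.log(p + 1e-12), 2) == 27.63  # worst-case loss with eps=1e-12
assert round(-math.log(p + 1e-8), 2) == 18.42   # worst-case loss with eps=1e-8

# float16's smallest positive subnormal is about 6e-8, so an epsilon of 1e-12
# would vanish entirely if the addition were carried out in half precision.
```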

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cb03c76 and 038c766.

📒 Files selected for processing (8)
  • .gitignore (1 hunks)
  • docs/custom_integrations.qmd (1 hunks)
  • src/axolotl/integrations/stablemax/ACKNOWLEDGEMENTS.md (1 hunks)
  • src/axolotl/integrations/stablemax/LICENSE (1 hunks)
  • src/axolotl/integrations/stablemax/README.md (1 hunks)
  • src/axolotl/integrations/stablemax/__init__.py (1 hunks)
  • src/axolotl/integrations/stablemax/args.py (1 hunks)
  • src/axolotl/integrations/stablemax/stablemax.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/integrations/stablemax/stablemax.py (1)
src/axolotl/integrations/kd/topk_logprob/forward_kl.py (1)
  • loss (56-158)
🪛 Ruff (0.11.9)
src/axolotl/integrations/stablemax/__init__.py

3-3: torch imported but unused

(F401)

src/axolotl/integrations/stablemax/stablemax.py

2-2: torch.nn.functional imported but unused

Remove unused import: torch.nn.functional

(F401)

🪛 LanguageTool
src/axolotl/integrations/stablemax/README.md

[uncategorized] ~19-~19: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ...s(x_i) / sum_j s(x_j) ``` This prevents floating point absorption errors that can halt learnin...

(EN_COMPOUND_ADJECTIVE_INTERNAL)

🪛 markdownlint-cli2 (0.17.2)
src/axolotl/integrations/stablemax/README.md

4-4: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


16-16: Fenced code blocks should have a language specified
null

(MD040, fenced-code-language)

🪛 Pylint (3.3.7)
src/axolotl/integrations/stablemax/stablemax.py

[refactor] 32-37: Unnecessary "elif" after "return", remove the leading "el" from "elif"

(R1705)

⏰ Context from checks skipped due to timeout of 90000ms (8)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: preview
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.6.0)
🔇 Additional comments (4)
.gitignore (1)

194-196: macOS resource fork ignore rules
The patterns for ._* and .DS_Store are correctly added to prevent macOS artifacts from being committed.

src/axolotl/integrations/stablemax/args.py (1)

1-10: Argument model is well-defined
StableMaxArgs correctly exposes a stablemax flag with a clear description, fitting the plugin’s configuration pattern.

src/axolotl/integrations/stablemax/README.md (1)

1-40: Excellent documentation quality.

The documentation provides clear explanations of the StableMax function, its mathematical foundation, usage instructions, and important compatibility notes. The references to the ICLR 2025 paper add credibility and context.

🧰 Tools
🪛 LanguageTool

[uncategorized] ~19-~19: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ...s(x_i) / sum_j s(x_j) ``` This prevents floating point absorption errors that can halt learnin...

(EN_COMPOUND_ADJECTIVE_INTERNAL)

🪛 markdownlint-cli2 (0.17.2)

4-4: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


16-16: Fenced code blocks should have a language specified
null

(MD040, fenced-code-language)

src/axolotl/integrations/stablemax/stablemax.py (1)

4-11: 🛠️ Refactor suggestion

Address potential numerical instability in the piecewise function.

The current implementation has a potential numerical issue: when x approaches 1 from below, 1 / (1 - x) can become extremely large, and when x = 1 exactly, this results in division by zero.

Consider adding a small epsilon or clamping to prevent numerical instability:

 def stablemax_fn(x):
     """
     Numerically stable alternative to softmax.
     s(x) = x + 1 if x >= 0, else 1 / (1 - x)
     StableMax(x_i) = s(x_i) / sum_j s(x_j)
     """
-    s = torch.where(x >= 0, x + 1, 1 / (1 - x))
+    # Clamp x to prevent division by zero when x approaches 1
+    x_clamped = torch.clamp(x, max=0.999)
+    s = torch.where(x >= 0, x + 1, 1 / (1 - x_clamped))
     return s / s.sum(dim=-1, keepdim=True)

Likely an incorrect or invalid review comment.

codecov bot commented Jun 5, 2025

Codecov Report

Attention: Patch coverage is 0% with 66 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| `src/axolotl/integrations/stablemax/stablemax.py` | 0.00% | 41 Missing ⚠️ |
| `src/axolotl/integrations/stablemax/__init__.py` | 0.00% | 22 Missing ⚠️ |
| `src/axolotl/integrations/stablemax/args.py` | 0.00% | 3 Missing ⚠️ |


coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (9)
src/axolotl/integrations/stablemax/__init__.py (2)

3-3: Remove unused import.

The torch import is not used in this file and should be removed to fix the linting error.

-import torch
🧰 Tools
🪛 Ruff (0.11.9)

3-3: torch imported but unused

(F401)

🪛 GitHub Actions: lint

[error] 3-3: flake8: 'torch' imported but unused (F401)


1-1: Add module docstring.

Add a module-level docstring to address the pylint warning and provide clear documentation about this plugin's purpose.

+"""
+StableMax integration plugin for Axolotl.
+
+This plugin provides a numerically stable alternative to softmax cross-entropy
+by globally patching torch.nn.functional.cross_entropy when enabled.
+"""
+
 # StableMax integration entry point
🧰 Tools
🪛 GitHub Actions: lint

[error] 1-1: pylint: Missing module docstring (missing-module-docstring)


[error] 1-1: pre-commit hook 'trailing-whitespace' failed and fixed trailing whitespace issues


[error] 1-1: pre-commit hook 'isort' fixed import sorting issues

src/axolotl/integrations/stablemax/README.md (3)

21-21: Use hyphenated compound adjective.

For better readability, hyphenate the compound adjective "floating-point" when it modifies a noun.

-This prevents floating point absorption errors that can halt learning in grokking tasks.
+This prevents floating-point absorption errors that can halt learning in grokking tasks.
🧰 Tools
🪛 LanguageTool

[uncategorized] ~21-~21: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ...s(x_i) / sum_j s(x_j) ``` This prevents floating point absorption errors that can halt learnin...

(EN_COMPOUND_ADJECTIVE_INTERNAL)


18-20: Specify language for code block.

Add a language identifier to the fenced code block for better syntax highlighting.

-```
+```text
 StableMax(x_i) = s(x_i) / sum_j s(x_j)


---

`3-6`: **Fix blockquote formatting.**

Remove blank lines inside blockquotes to fix markdownlint warnings.


```diff
 > **⚠️ WARNING:** StableMax performs **global patching** of `torch.nn.functional.cross_entropy`, replacing it with `stablemax_cross_entropy` for ALL subsequent calls throughout the entire application. This affects not only your model training but also any other libraries, models, or code that use `torch.nn.functional.cross_entropy`.
-
 > **⚠️ COMPATIBILITY:** Do not enable StableMax simultaneously with other cross-entropy patches such as **Liger** (`liger_cross_entropy`, `liger_fused_linear_cross_entropy`) or **CutCrossEntropy** (`cut_cross_entropy`). The system will detect and prevent such conflicts, but enabling multiple patches can lead to unpredictable runtime behavior.
-
 > **Note:** StableMax is intended to be used in combination with the orthograd optimizer ([implementation here](https://github.com/cognitivecomputations/dolphinflow-optimizer)) to fully implement the solution described in Prieto et al.
```
🧰 Tools
🪛 markdownlint-cli2 (0.17.2)

4-4: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


6-6: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)

src/axolotl/integrations/stablemax/stablemax.py (4)

1-1: Add module docstring.

Add a module-level docstring to describe the StableMax implementation and its purpose.

+"""
+StableMax implementation: A numerically stable alternative to softmax.
+
+This module implements the StableMax function and its corresponding cross-entropy
+loss as described in "Grokking at the Edge of Numerical Stability" (Prieto et al., ICLR 2025).
+"""
+
 import torch
🧰 Tools
🪛 GitHub Actions: lint

[error] 1-1: pylint: Missing module docstring (missing-module-docstring)


13-15: Rename parameter to avoid shadowing built-in.

The parameter input shadows Python's built-in input() function. Consider using a more descriptive name.

-def stablemax_cross_entropy(input, target, weight=None, ignore_index=-100, 
+def stablemax_cross_entropy(logits, target, weight=None, ignore_index=-100, 
                            size_average=None, reduce=None, reduction="mean", 
                            label_smoothing=0.0):

And update the corresponding usage:

-    probs = stablemax_fn(input)
+    probs = stablemax_fn(logits)

23-24: Document unused legacy parameters.

The size_average and reduce parameters are kept for PyTorch compatibility but not used. Update the docstring to clarify this.

-        size_average: deprecated (kept for compatibility)
-        reduce: deprecated (kept for compatibility)
+        size_average: deprecated, ignored (kept for PyTorch compatibility)
+        reduce: deprecated, ignored (kept for PyTorch compatibility)

82-87: Simplify conditional structure.

Remove unnecessary elif statements after return to improve code readability.

     elif reduction == "mean":
         return loss.mean() if loss.numel() > 0 else torch.tensor(0.0, device=input.device)
-    elif reduction == "sum":
+    if reduction == "sum":
         return loss.sum()
-    else:
-        raise ValueError(f"Invalid reduction mode: {reduction}")
+    
+    raise ValueError(f"Invalid reduction mode: {reduction}")
🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 75-87: Unnecessary "elif" after "return", remove the leading "el" from "elif"

(R1705)

🪛 GitHub Actions: lint

[error] 86-86: pylint: Unnecessary 'elif' after 'return', remove the leading 'el' from 'elif' (no-else-return)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 038c766 and a769fb8.

📒 Files selected for processing (4)
  • src/axolotl/integrations/stablemax/LICENSE (1 hunks)
  • src/axolotl/integrations/stablemax/README.md (1 hunks)
  • src/axolotl/integrations/stablemax/__init__.py (1 hunks)
  • src/axolotl/integrations/stablemax/stablemax.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/integrations/stablemax/LICENSE
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/integrations/stablemax/stablemax.py (1)
src/axolotl/integrations/kd/topk_logprob/forward_kl.py (1)
  • loss (56-158)
src/axolotl/integrations/stablemax/__init__.py (3)
src/axolotl/integrations/base.py (1)
  • BasePlugin (41-238)
src/axolotl/integrations/stablemax/stablemax.py (1)
  • stablemax_cross_entropy (13-87)
src/axolotl/integrations/stablemax/args.py (1)
  • StableMaxArgs (3-10)
🪛 Pylint (3.3.7)
src/axolotl/integrations/stablemax/stablemax.py

[refactor] 75-87: Unnecessary "elif" after "return", remove the leading "el" from "elif"

(R1705)

🪛 GitHub Actions: lint
src/axolotl/integrations/stablemax/stablemax.py

[error] 1-1: pylint: Missing module docstring (missing-module-docstring)


[error] 16-16: pylint: Redefining built-in 'input' (redefined-builtin)


[error] 86-86: pylint: Unnecessary 'elif' after 'return', remove the leading 'el' from 'elif' (no-else-return)


[error] 20-20: pylint: Unused argument 'size_average' (unused-argument)


[error] 21-21: pylint: Unused argument 'reduce' (unused-argument)

src/axolotl/integrations/stablemax/__init__.py

[error] 3-3: flake8: 'torch' imported but unused (F401)


[error] 1-1: pylint: Missing module docstring (missing-module-docstring)


[error] 1-1: pre-commit hook 'trailing-whitespace' failed and fixed trailing whitespace issues


[error] 1-1: pre-commit hook 'isort' fixed import sorting issues

🪛 Ruff (0.11.9)
src/axolotl/integrations/stablemax/__init__.py

3-3: torch imported but unused

(F401)

🪛 LanguageTool
src/axolotl/integrations/stablemax/README.md

[uncategorized] ~21-~21: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ...s(x_i) / sum_j s(x_j) ``` This prevents floating point absorption errors that can halt learnin...

(EN_COMPOUND_ADJECTIVE_INTERNAL)

🪛 markdownlint-cli2 (0.17.2)
src/axolotl/integrations/stablemax/README.md

4-4: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


6-6: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


18-18: Fenced code blocks should have a language specified
null

(MD040, fenced-code-language)

⏰ Context from checks skipped due to timeout of 90000ms (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
🔇 Additional comments (3)
src/axolotl/integrations/stablemax/__init__.py (1)

16-57: LGTM! Excellent conflict detection and documentation.

The implementation properly addresses the global patching concerns with:

  • Clear warnings in the docstring about global replacement behavior
  • Comprehensive conflict detection for other cross-entropy patches (Liger, CutCrossEntropy)
  • Helpful error messages when conflicts are detected

This ensures users are aware of the implications and prevents runtime conflicts.

src/axolotl/integrations/stablemax/stablemax.py (2)

4-11: LGTM! Correct StableMax implementation.

The stablemax_fn correctly implements the piecewise function:

  • s(x) = x + 1 for x >= 0
  • s(x) = 1 / (1 - x) for x < 0
  • Proper normalization by sum along the last dimension

This matches the mathematical definition from the paper.


30-88:

❌ Incorrect review comment

Verify edge cases in cross-entropy implementation.

The implementation handles complex scenarios but some edge cases need verification:

  1. One-hot targets with ignore_index: Currently assumes ignore_index doesn't apply to one-hot targets, but this might not always be correct.
  2. Weight tensor shape assumptions: The code assumes weight is 1D in line 63 (weight.unsqueeze(0)), but PyTorch's cross_entropy accepts weights that match the target shape.
  3. Empty batch handling: When all targets are ignored, loss.numel() > 0 check prevents division by zero, which is good.

Please verify these edge cases work correctly:


🏁 Script executed:

#!/bin/bash
# Search for how PyTorch handles these edge cases in cross_entropy
rg -A 5 -B 5 "one_hot.*ignore_index|ignore_index.*one_hot" --type py

Length of output: 70


Ignore unnecessary edge-case checks for cross-entropy
The implementation aligns with PyTorch’s semantics:

  • One-hot mode inherently doesn’t support ignore_index, so assuming no mask there is correct.
  • weight must be a 1D tensor of class-level weights (PyTorch’s cross_entropy only accepts per-class weights), so weight.unsqueeze(0) is appropriate.
  • The empty-batch guard (loss.numel() > 0) correctly avoids division by zero.

Likely an incorrect or invalid review comment.

🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 75-87: Unnecessary "elif" after "return", remove the leading "el" from "elif"

(R1705)

🪛 GitHub Actions: lint

[error] 86-86: pylint: Unnecessary 'elif' after 'return', remove the leading 'el' from 'elif' (no-else-return)

coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (4)
src/axolotl/integrations/stablemax/__init__.py (1)

3-3: Remove unused import.

The torch import is not used in this file and should be removed for cleaner code.

-import torch
🧰 Tools
🪛 Ruff (0.11.9)

3-3: torch imported but unused

(F401)

src/axolotl/integrations/stablemax/README.md (2)

3-6: Fix markdown formatting in warning blockquotes.

The blockquotes have formatting issues that affect readability.

-> **⚠️ WARNING:** StableMax performs **global patching** of `torch.nn.functional.cross_entropy`, replacing it with `stablemax_cross_entropy` for ALL subsequent calls throughout the entire application. This affects not only your model training but also any other libraries, models, or code that use `torch.nn.functional.cross_entropy`.
-
-> **⚠️ COMPATIBILITY:** Do not enable StableMax simultaneously with other cross-entropy patches such as **Liger** (`liger_cross_entropy`, `liger_fused_linear_cross_entropy`) or **CutCrossEntropy** (`cut_cross_entropy`). The system will detect and prevent such conflicts, but enabling multiple patches can lead to unpredictable runtime behavior.
+> **⚠️ WARNING:** StableMax performs **global patching** of `torch.nn.functional.cross_entropy`, replacing it with `stablemax_cross_entropy` for ALL subsequent calls throughout the entire application. This affects not only your model training but also any other libraries, models, or code that use `torch.nn.functional.cross_entropy`.
+>
+> **⚠️ COMPATIBILITY:** Do not enable StableMax simultaneously with other cross-entropy patches such as **Liger** (`liger_cross_entropy`, `liger_fused_linear_cross_entropy`) or **CutCrossEntropy** (`cut_cross_entropy`). The system will detect and prevent such conflicts, but enabling multiple patches can lead to unpredictable runtime behavior.
🧰 Tools
🪛 markdownlint-cli2 (0.17.2)

4-4: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


6-6: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


18-21: Add language specification to code block and fix hyphenation.

Minor formatting improvements for better documentation quality.

-```
+```text
 StableMax(x_i) = s(x_i) / sum_j s(x_j)

-This prevents floating point absorption errors that can halt learning in grokking tasks.
+This prevents floating-point absorption errors that can halt learning in grokking tasks.


src/axolotl/integrations/stablemax/stablemax.py (1)

`15-102`: **Comprehensive cross-entropy implementation with excellent PyTorch compatibility.**

The function successfully addresses the previous review feedback by implementing:

✅ **Complete API compatibility** with `torch.nn.functional.cross_entropy`  
✅ **Robust target handling** for both class indices and one-hot encoded targets  
✅ **Proper ignore_index support** with valid masking  
✅ **Label smoothing** implementation  
✅ **Class weighting** for both target formats  
✅ **All reduction modes** (`none`, `mean`, `sum`) with correct tensor sizing  

The edge case handling is particularly well thought out, such as maintaining original batch size for `reduction="none"` with ignored indices (lines 88-94).



Minor style improvement for the elif chain:

```diff
-    elif reduction == "mean":
+    if reduction == "mean":
         return (
             loss.mean() if loss.numel() > 0 else torch.tensor(0.0, device=input.device)
         )
-    elif reduction == "sum":
+    if reduction == "sum":
         return loss.sum()
-    else:
-        raise ValueError(f"Invalid reduction mode: {reduction}")
+    raise ValueError(f"Invalid reduction mode: {reduction}")
```
🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 86-102: Unnecessary "elif" after "return", remove the leading "el" from "elif"

(R1705)
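For readers without the file open, a minimal sketch of what a StableMax-based cross-entropy looks like (a hypothetical simplification, not the PR's exact code: the real `stablemax_cross_entropy` also handles `ignore_index`, class weights, label smoothing, and one-hot targets):

```python
import torch


def s(x: torch.Tensor) -> torch.Tensor:
    # StableMax scaling from the paper: x + 1 for x >= 0, 1 / (1 - x) for x < 0.
    return torch.where(x >= 0, x + 1, 1 / (1 - x))


def stablemax_cross_entropy_sketch(
    input: torch.Tensor, target: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
    # Log-probabilities under StableMax instead of softmax.
    scaled = s(input)
    log_probs = torch.log(scaled) - torch.log(scaled.sum(dim=-1, keepdim=True))
    # Negative log-likelihood of the target class indices.
    loss = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Invalid reduction mode: {reduction}")
```

The only change relative to ordinary cross-entropy is the normalizer: `s(x)` replaces `exp(x)`, so extreme logits no longer push the normalization into floating-point absorption.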

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a769fb8 and 4b5b935.

📒 Files selected for processing (4)
  • src/axolotl/integrations/stablemax/README.md (1 hunks)
  • src/axolotl/integrations/stablemax/__init__.py (1 hunks)
  • src/axolotl/integrations/stablemax/args.py (1 hunks)
  • src/axolotl/integrations/stablemax/stablemax.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/integrations/stablemax/args.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/integrations/stablemax/__init__.py (3)
src/axolotl/integrations/base.py (1)
  • BasePlugin (41-238)
src/axolotl/integrations/stablemax/args.py (1)
  • StableMaxArgs (4-12)
src/axolotl/integrations/stablemax/stablemax.py (1)
  • stablemax_cross_entropy (15-102)
src/axolotl/integrations/stablemax/stablemax.py (1)
src/axolotl/integrations/kd/topk_logprob/forward_kl.py (1)
  • loss (56-158)
🪛 LanguageTool
src/axolotl/integrations/stablemax/README.md

[uncategorized] ~21-~21: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ...s(x_i) / sum_j s(x_j) ``` This prevents floating point absorption errors that can halt learnin...

(EN_COMPOUND_ADJECTIVE_INTERNAL)

🪛 markdownlint-cli2 (0.17.2)
src/axolotl/integrations/stablemax/README.md

4-4: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


6-6: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


18-18: Fenced code blocks should have a language specified
null

(MD040, fenced-code-language)

🪛 Ruff (0.11.9)
src/axolotl/integrations/stablemax/__init__.py

3-3: torch imported but unused

(F401)

🪛 Pylint (3.3.7)
src/axolotl/integrations/stablemax/stablemax.py

[refactor] 86-102: Unnecessary "elif" after "return", remove the leading "el" from "elif"

(R1705)

⏰ Context from checks skipped due to timeout of 90000ms (7)
  • GitHub Check: preview
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
🔇 Additional comments (3)
src/axolotl/integrations/stablemax/__init__.py (1)

31-61: Excellent conflict detection and global patching documentation.

The implementation properly addresses the concerns raised in previous reviews:

  • Clear documentation of global patching behavior with appropriate warnings
  • Comprehensive conflict detection for Liger and CutCrossEntropy integrations
  • Descriptive error messages to guide users

This provides a safe and well-documented approach to the inherently risky global patching strategy.
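The shape of such a conflict check can be illustrated with a small, self-contained sketch (the set contents and function name below are illustrative assumptions, not the actual axolotl plugin API):

```python
# Hypothetical sketch: refuse to enable StableMax alongside integrations that
# also patch cross-entropy globally. Names here are illustrative only.
CONFLICTING_INTEGRATIONS = {"liger", "cut_cross_entropy"}


def check_stablemax_conflicts(enabled: set) -> None:
    # Fail fast with a descriptive message instead of silently double-patching.
    overlap = CONFLICTING_INTEGRATIONS & enabled
    if overlap:
        raise ValueError(
            f"StableMax conflicts with {sorted(overlap)}: both replace the "
            "cross-entropy loss globally; enable only one of them."
        )
```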

src/axolotl/integrations/stablemax/README.md (1)

1-42: Comprehensive and well-structured documentation.

The README provides excellent coverage of:

  • Clear warnings about global patching behavior and compatibility issues
  • Mathematical explanation of the StableMax function
  • Proper usage instructions with configuration examples
  • References to the academic paper and related resources

This documentation effectively addresses the integration's complexity and potential risks.

🧰 Tools
🪛 LanguageTool

[uncategorized] ~21-~21: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ...s(x_i) / sum_j s(x_j) ``` This prevents floating point absorption errors that can halt learnin...

(EN_COMPOUND_ADJECTIVE_INTERNAL)

🪛 markdownlint-cli2 (0.17.2)

4-4: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


6-6: Blank line inside blockquote
null

(MD028, no-blanks-blockquote)


18-18: Fenced code blocks should have a language specified
null

(MD040, fenced-code-language)

src/axolotl/integrations/stablemax/stablemax.py (1)

5-12: StableMax function implementation is mathematically correct.

The piecewise function correctly implements the paper's specification:

  • For x ≥ 0: s(x) = x + 1
  • For x < 0: s(x) = 1 / (1 - x)
  • Proper normalization by sum

The torch.where approach is efficient and the normalization maintains probability distribution properties.
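Under that specification, the whole forward pass fits in a few lines (a sketch assuming the paper's piecewise s(x); not the exact file contents):

```python
import torch


def s(x: torch.Tensor) -> torch.Tensor:
    # Piecewise scaling: grows linearly for x >= 0, decays as 1 / (1 - x) for x < 0.
    return torch.where(x >= 0, x + 1, 1 / (1 - x))


def stablemax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Normalize so the outputs form a probability distribution along `dim`.
    scaled = s(x)
    return scaled / scaled.sum(dim=dim, keepdim=True)
```

Because s(x) is strictly positive, continuous at 0 (both branches give 1), and monotonically increasing, the outputs are valid probabilities that preserve the ordering of the logits without the exponential's overflow behavior.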

@winglian
Copy link
Collaborator

winglian commented Jun 7, 2025

Can you run `pre-commit run --all-files`?

@ehartford
Copy link
Collaborator Author

Can you run `pre-commit run --all-files`?

I'm a bit busy and my cat died. If you want StableMax, here it is.

@ehartford
Copy link
Collaborator Author

BTW, I think this will address the problem of the model getting stuck repeating the same token over and over.

@ehartford
Copy link
Collaborator Author

Also, if you use StableMax you don't need to use `weight_decay`.

@winglian
Copy link
Collaborator

@ehartford I tried to lint/clean up this PR, but maintainers don't have access to modify this PR. Can you update access please? Unfortunately we can't merge this when we can't run basic CI against it. Thanks!

Most PRs with correct access look like this to me:
Screenshot 2025-06-18 at 12 56 58 PM

but this PR shows:
Screenshot 2025-06-18 at 12 57 14 PM

@ehartford
Copy link
Collaborator Author

Sorry for the trouble.
I will give it a try.
