Add Selective Activation Checkpointing #785


Merged · 24 commits into pytorch:main · Apr 23, 2024

Conversation

@lessw2020 (Contributor) commented Apr 17, 2024

Context

This PR updates activation checkpointing (AC) to support selective-layer and selective-op activation checkpointing.
It preserves the previously available options of full AC or none.
This is controlled in the YAML config via:
enable_activation_checkpointing: bool
ac_mode: ['full', 'selective']
ac_option: [int, 'op']

If ac_mode is 'selective', the type of selective AC is determined by ac_option.
An integer means checkpoint every x'th layer (i.e. 2 = checkpoint every other layer, 3 = every third, etc.).
'op' means run selective op AC, where checkpointing is filtered by the op policy.
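
For illustration, here is a minimal sketch of the every-x-layer selection (a hypothetical helper, not this PR's API), assuming the model exposes its transformer blocks as a layers ModuleList as torchtune's TransformerDecoder does:

from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

def apply_layer_frequency_ac(model: nn.Module, every_x_layer: int) -> None:
    # Sketch only: every_x_layer=1 wraps every block (full AC), 2 wraps every
    # other block, 3 every third, and so on.
    if every_x_layer < 1:
        raise ValueError("every_x_layer must be a positive integer")
    for layer_id, block in enumerate(model.layers):
        if layer_id % every_x_layer == 0:
            model.layers[layer_id] = ptd_checkpoint_wrapper(block)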

As a general data point, on llama-13B selective AC 2 (every other layer) improved throughput by +10% over no AC.

I updated the testing for llama3-8B, adjusting the batch size under each setting to land at around 91GB of memory. I used 8 GPUs so that model params would have less impact and the batch size (and thus the activations) could be tuned at a finer grain. This is not always perfectly achievable since activations are chunky, but the net result was that selective AC 3 gave the highest throughput, followed by no AC. Selective AC 3 was +9% better throughput vs the original full-only option.
For A100s, 4090s, etc. the actual best combination will vary, but the point here is that selective AC generally provides better throughput options than the simple binary of full AC (True/False).

[Screenshot, 2024-04-19: llama3-8B throughput comparison across AC settings]

Changelog

  • ...

Test plan

This code is largely a port of the original source in torchtitan, where it has already been tested. However, I ran all four styles (none, full, selective AC op, selective AC 2) as shown above.
I also verified that the new implementation of full AC matched the memory savings of the previous full implementation.

pytorch-bot bot commented Apr 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/785

✅ No Failures

As of commit 1412378 with merge base a79554e:
💚 Looks good so far! There are no failures yet. 💚

@facebook-github-bot added the CLA Signed label Apr 17, 2024
@kartikayk (Contributor) left a comment

Some noob questions, but starting to look good!

@kartikayk (Contributor) left a comment

Thanks so much for patiently addressing all of the comments!

Kartikay Khandelwal added 2 commits April 22, 2024 20:19
@kartikayk merged commit 68f2538 into pytorch:main Apr 23, 2024
# tune download meta-llama/Meta-Llama-3-8B --output-dir /tmp/Meta-Llama-3-8B --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full
Member commented:

Launch command does not seem correct

# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Member commented:

same comment as (9)

# Memory management
enable_activation_checkpointing: False
ac_mode: 'selective' # ['selective', 'full']
ac_option: 2 # [int] = ac every positive int layer
Member commented:

This is an interesting way to configure and IMO it seems a little brittle. Is there any guidance around choosing this setting for users, and what happens if a user selects an integer that exceeds the number of layers in the model?
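
One possible guard, purely illustrative and not something this PR adds, would be to validate ac_option against the model's layer count when selective layer AC is requested:

num_layers = len(model.layers)  # assumes the model exposes a `layers` ModuleList
if ac_mode == "selective" and isinstance(ac_option, int):
    if not 1 <= ac_option <= num_layers:
        raise ValueError(
            f"ac_option={ac_option} must be between 1 and {num_layers} "
            "(the number of transformer layers) for selective layer AC"
        )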

# the older version of AC and this behavior is unchanged
# ac_mode and ac_option together control selective AC. This is only enabled
# when these are set AND ``enable_activation_checkpointing`` is set to False
# We'll clean this up as soon as testing of AC is complete
Member commented:

Is there an issue and an owner?

ac_mode = ac_mode
ac_option = ac_option

if (not enable_activation_checkpointing) and (ac_mode is not None):
Member commented:

I understand why it's this way, but it still reads a little weird: we have a check for not enable_activation_checkpointing, and underneath that check we apply AC (granted, it is selective).
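
As a sketch of the alternative being hinted at here (hypothetical, not what the PR does), the two flags could be collapsed into a single effective mode before any wrapping happens:

# Hypothetical consolidation: derive one mode up front so the wrapping code
# never has to reason about two flags at once.
if enable_activation_checkpointing:
    effective_ac_mode = "full"
elif ac_mode is not None:
    effective_ac_mode = ac_mode  # 'full' or 'selective'
else:
    effective_ac_mode = "none"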

model (nn.Module): Model to setup activation checkpointing.
ac_mode (str): Activation checkpointing mode. ['none', 'full', 'selective']
ac_option (Optional[Union[int, str]]): Activation checkpointing option.
- If ac_mode is 'selective', ac_option can be an integer or a string
Member commented:

Why do we need both int and string?

ac_mode (str): Activation checkpointing mode. ['none', 'full', 'selective']
ac_option (Optional[Union[int, str]]): Activation checkpointing option.
- If ac_mode is 'selective', ac_option can be an integer or a string
representing the number of layers to checkpoint.
Member commented:

Is it the number of layers to checkpoint, or the number of layers to skip between checkpointed layers?

- If ac_mode is 'none' or 'full', ac_option is ignored.
"""

for layer_id, transformer_block in enumerate(model.layers):
Member commented:

If we're assuming the passed-in module has a layers attribute, we should document that requirement.
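
For instance, a hedged sketch of making that requirement explicit (illustrative only, not part of this PR):

# Fail with a clear message if the model does not expose its transformer
# blocks as a `layers` attribute (e.g. an nn.ModuleList).
if not hasattr(model, "layers"):
    raise AttributeError(
        "activation checkpointing setup expects the model to have a "
        "`layers` container of transformer blocks"
    )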

for layer_id, transformer_block in enumerate(model.layers):
if ac_mode in ("full", "selective"):

transformer_block = checkpoint_wrapper(
Member commented:

This in-place modification is quite risky and brittle in general; for example, it will have to be used carefully with FSDP and other wrapping APIs. Ideally we should do this via hooks, but of course that would require a lot of refactoring of our activation checkpointing infra.

"""

for layer_id, transformer_block in enumerate(model.layers):
if ac_mode in ("full", "selective"):
Member commented:

Silently not applying AC when ac_mode is not in this tuple seems bad, especially if it's undocumented?
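
A minimal alternative, assuming we want to fail loudly on unknown modes (illustrative only):

# Reject unrecognized modes instead of silently leaving blocks unwrapped.
if ac_mode not in ("none", "full", "selective"):
    raise ValueError(f"Unsupported ac_mode: {ac_mode!r}")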

checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
from torch.utils.checkpoint import checkpoint
Member commented:

Unused, I think? Not sure why the linter did not pick this up.

"""
every_x_layer = int(ac_style)

if not (every_x_layer >= 0):
Member commented:

every_x_layer < 0 is simpler?


checkpoint_wrapper.__dict__.setdefault("_count", 0)

checkpoint_wrapper._count += 1
Member commented:

I think the bump should be done after the check.

checkpoint_wrapper.__dict__.setdefault("_count", 0)

checkpoint_wrapper._count += 1
if not every_x_layer or checkpoint_wrapper._count % every_x_layer == 0:
Member commented:

every_x_layer == 0 seems better than if not every_x_layer, to guard against every_x_layer being None, which would be an unexpected state for this variable.
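
Putting the last two suggestions together, one illustrative rewrite of this guard (not the PR's final code) could read:

# Compare against 0 explicitly and bump the counter only after deciding
# whether this particular layer gets checkpointed.
count = checkpoint_wrapper.__dict__.setdefault("_count", 0)
should_checkpoint = (every_x_layer == 0) or (count % every_x_layer == 0)
checkpoint_wrapper._count = count + 1
# ...wrap the block with ptd_checkpoint_wrapper only when should_checkpoint is True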
