Add Selective Activation Checkpointing #785
Conversation
✅ No failures as of commit 1412378 with merge base a79554e. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/785.
Some noob questions, but starting to look good!
Thanks so much for patiently addressing all of the comments!
# tune download meta-llama/Meta-Llama-3-8B --output-dir /tmp/Meta-Llama-3-8B --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full
Launch command does not seem correct
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
same comment as (9)
# Memory management
enable_activation_checkpointing: False
ac_mode: 'selective'  # ['selective', 'full']
ac_option: 2  # [int] = ac every positive int layer
This is an interesting way to configure it, but IMO it seems a little brittle. Is there any guidance around choosing this setting for users, and what happens if a user selects an integer that exceeds the number of layers in the model?
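For illustration only, a minimal sketch of the kind of guard this question is getting at, using a hypothetical helper (not part of the PR) and assuming ac_option is the every-x-layer integer:

```python
def validate_ac_option(ac_option: int, num_layers: int) -> None:
    # Hypothetical guard: with the modulo-based "checkpoint every ac_option-th
    # layer" rule, a value larger than the layer count wraps no layers at all.
    if not isinstance(ac_option, int) or ac_option <= 0:
        raise ValueError(f"ac_option must be a positive int, got {ac_option!r}")
    if ac_option > num_layers:
        raise ValueError(
            f"ac_option={ac_option} exceeds the model's {num_layers} layers; "
            "no layer would be checkpointed"
        )
```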
# the older version of AC and this behavior is unchanged
# ac_mode and ac_option together control selective AC. This is only enabled
# when these are set AND ``enable_activation_checkpointing`` is set to False
# We'll clean this up as soon as testing of AC is complete
Is there an issue and an owner?
ac_mode = ac_mode
ac_option = ac_option

if (not enable_activation_checkpointing) and (ac_mode is not None):
I understand why it's this way, but it still reads a little weird: we have a check for not enable_activation_checkpointing, and underneath that check we apply AC (granted, it is selective).
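To make the point concrete, a condensed sketch of the branch shape being described (a toy stand-in with hypothetical print output, not the recipe code):

```python
def setup_ac(model, enable_activation_checkpointing: bool, ac_mode, ac_option):
    # Toy sketch of the control flow under discussion: selective AC is applied
    # inside the branch that says the legacy AC flag is *off*, which is the
    # part that reads oddly.
    if (not enable_activation_checkpointing) and (ac_mode is not None):
        print(f"applying {ac_mode} AC with option {ac_option}")  # new path
    elif enable_activation_checkpointing:
        print("applying legacy full AC")  # old path, behavior unchanged
```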
model (nn.Module): Model to setup activation checkpointing.
ac_mode (str): Activation checkpointing mode. ['none', 'full', 'selective']
ac_option (Optional[Union[int, str]]): Activation checkpointing option.
    - If ac_mode is 'selective', ac_option can be an integer or a string
Why do we need both int and string?
ac_mode (str): Activation checkpointing mode. ['none', 'full', 'selective']
ac_option (Optional[Union[int, str]]): Activation checkpointing option.
    - If ac_mode is 'selective', ac_option can be an integer or a string
      representing the number of layers to checkpoint.
Is it the number of layers to checkpoint, or the number of layers to skip between checkpointed layers?
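Going by the PR description and the modulo logic later in the diff, it behaves as a stride rather than a count. A small sketch of which layer indices end up wrapped, assuming the 1-based counter used in the diff:

```python
def wrapped_layer_indices(num_layers: int, every_x_layer: int) -> list:
    # Stride semantics: with every_x_layer=2 on an 8-layer model this returns
    # [1, 3, 5, 7] (every other layer), not "2 layers total".
    count = 0
    wrapped = []
    for layer_id in range(num_layers):
        count += 1
        if count % every_x_layer == 0:
            wrapped.append(layer_id)
    return wrapped

print(wrapped_layer_indices(8, 2))  # [1, 3, 5, 7]
print(wrapped_layer_indices(8, 3))  # [2, 5]
```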
- If ac_mode is 'none' or 'full', ac_option is ignored.
"""

for layer_id, transformer_block in enumerate(model.layers):
We should document whether we're assuming the passed-in module requires a layers attribute.
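If that assumption is intended, one way to make it explicit is a small guard like this (hypothetical, not part of the PR):

```python
from torch import nn


def assert_has_layers(model: nn.Module) -> None:
    # Hypothetical guard making the implicit contract explicit: the model must
    # expose an iterable `layers` container of transformer blocks.
    if not hasattr(model, "layers"):
        raise ValueError(
            f"{type(model).__name__} has no 'layers' attribute; activation "
            "checkpointing setup expects to iterate over model.layers"
        )
```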
for layer_id, transformer_block in enumerate(model.layers):
    if ac_mode in ("full", "selective"):

        transformer_block = checkpoint_wrapper(
This in-place modification is quite risky and brittle in general; for example, it will have to be used carefully with FSDP and other wrapping APIs. Ideally we should do this via hooks, but of course that would require a lot of refactoring around our activation checkpointing infra.
""" | ||
|
||
for layer_id, transformer_block in enumerate(model.layers): | ||
if ac_mode in ("full", "selective"): |
Silently not applying AC if ac_mode is not within this tuple seems bad, especially if it's undocumented?
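A minimal sketch of the kind of explicit failure being suggested, assuming the mode values described elsewhere in this PR (hypothetical helper, not the PR's code):

```python
def check_ac_mode(ac_mode: str) -> None:
    # Hypothetical validation: fail loudly instead of silently skipping AC
    # when ac_mode is not one of the documented values.
    valid_modes = ("none", "full", "selective")
    if ac_mode not in valid_modes:
        raise ValueError(f"Invalid ac_mode {ac_mode!r}; expected one of {valid_modes}")
```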
    checkpoint_wrapper as ptd_checkpoint_wrapper,
    CheckpointImpl,
)
from torch.utils.checkpoint import checkpoint
Unused, I think? Not sure why the linter did not pick this up.
""" | ||
every_x_layer = int(ac_style) | ||
|
||
if not (every_x_layer >= 0): |
every_x_layer < 0 is simpler?
checkpoint_wrapper.__dict__.setdefault("_count", 0)

checkpoint_wrapper._count += 1
Think the bump should be done after the check
checkpoint_wrapper.__dict__.setdefault("_count", 0)

checkpoint_wrapper._count += 1
if not every_x_layer or checkpoint_wrapper._count % every_x_layer == 0:
every_x_layer == 0 seems better than if not every_x_layer, to guard against every_x_layer being None, which is an unexpected state for this variable.
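Putting both of these review suggestions together, a sketch of what the check could look like with an explicit zero comparison and the counter bumped after the decision (note this shifts the selection to 0-based, wrapping layers 0, 2, 4, ... for a stride of 2, whereas the diff's bump-before-check wraps layers 1, 3, 5, ...):

```python
def should_checkpoint(count: int, every_x_layer: int) -> bool:
    # Explicit comparison rather than truthiness; every_x_layer == 0 here
    # means "checkpoint every layer" (i.e. full AC).
    return every_x_layer == 0 or count % every_x_layer == 0


# Wrapping loop with the counter bumped *after* the decision.
count = 0
for layer_id in range(8):  # stand-in for enumerate(model.layers)
    if should_checkpoint(count, every_x_layer=2):
        print(f"wrap layer {layer_id}")  # wraps layers 0, 2, 4, 6
    count += 1
```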
Context
This PR updates activation checkpointing (AC) to support selective-layer and selective-op activation checkpointing.
It preserves the previously available options of full AC or none.
This is controlled in the yaml file via:
enable_activation_checkpointing: bool
ac_mode: ['full', 'selective']
ac_option: [int, 'op']
If ac_mode is 'selective', the type of selective AC is determined by ac_option.
An integer means checkpoint every x'th layer (i.e. 2 = checkpoint every other layer, 3 = every third, etc.).
'op' means run selective op AC, where checkpointing is filtered by the op policy.
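Putting the description above into code form, a rough sketch of how the two knobs combine (the helper and its return strings are illustrative placeholders, not the PR's actual functions):

```python
def resolve_ac_strategy(ac_mode, ac_option):
    # Illustrative mapping of (ac_mode, ac_option) to the behavior described
    # in this PR; the real setup wraps transformer layers instead of returning text.
    if ac_mode in (None, "none"):
        return "no activation checkpointing"
    if ac_mode == "full":
        return "checkpoint every transformer layer"
    if ac_mode == "selective":
        if ac_option == "op":
            return "selective op-level AC, filtered by the op policy"
        if isinstance(ac_option, int) and ac_option > 0:
            return f"checkpoint every {ac_option}'th transformer layer"
        raise ValueError(f"Unsupported ac_option for selective AC: {ac_option!r}")
    raise ValueError(f"Unknown ac_mode: {ac_mode!r}")


print(resolve_ac_strategy("selective", 2))     # checkpoint every 2'th transformer layer
print(resolve_ac_strategy("selective", "op"))  # selective op-level AC, filtered by the op policy
```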
As a general data point, on llama-13B selective AC 2 (every other layer) improved throughput by +10% over no AC.
I updated the testing for llama3-8B, adjusting the batch size under each setting to hit around 91GB. I used 8 GPUs so that model params have less impact and the batch size (and thus activations) can be tuned more finely. This is not always perfectly achievable since activations are chunky, but the net result was that selective AC 3 had the highest throughput, followed by no AC. Selective AC 3 gave +9% better throughput vs. the original full-only option.
For A100s, 4090s, etc., the actual best combination will vary, but the point here is that selective AC generally provides better throughput options than the simple binary of full AC (True/False).
Changelog
Test plan
This code is largely a port from the original source in torchtitan, where it has already been tested. However, I ran all 4 styles (none, full, selective AC op, selective AC 2) as shown above.
I also verified that the new implementation of full AC matches the memory savings of the previous implementation.