Llama3-70b: Full Finetune w/CPU offload + fused optimizer #993
New file: llama3/70B_full.yaml

@@ -0,0 +1,110 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama3 70B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --hf-token <HF_TOKEN> --ignore-patterns "original/consolidated*"
#
# To launch on 8 devices, run the following command from root:
#   tune run --nproc_per_node 8 full_finetune_distributed --config llama3/70B_full
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nproc_per_node 8 full_finetune_distributed --config llama3/70B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config is only tested on an 8xA100 machine.

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3.llama3_70b

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
  checkpoint_files: [
    model-00001-of-00030.safetensors,
    model-00002-of-00030.safetensors,
    model-00003-of-00030.safetensors,
    model-00004-of-00030.safetensors,
    model-00005-of-00030.safetensors,
    model-00006-of-00030.safetensors,
    model-00007-of-00030.safetensors,
    model-00008-of-00030.safetensors,
    model-00009-of-00030.safetensors,
    model-00010-of-00030.safetensors,
    model-00011-of-00030.safetensors,
    model-00012-of-00030.safetensors,
    model-00013-of-00030.safetensors,
    model-00014-of-00030.safetensors,
    model-00015-of-00030.safetensors,
    model-00016-of-00030.safetensors,
    model-00017-of-00030.safetensors,
    model-00018-of-00030.safetensors,
    model-00019-of-00030.safetensors,
    model-00020-of-00030.safetensors,
    model-00021-of-00030.safetensors,
    model-00022-of-00030.safetensors,
    model-00023-of-00030.safetensors,
    model-00024-of-00030.safetensors,
    model-00025-of-00030.safetensors,
    model-00026-of-00030.safetensors,
    model-00027-of-00030.safetensors,
    model-00028-of-00030.safetensors,
    model-00029-of-00030.safetensors,
    model-00030-of-00030.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3-70b
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  foreach: False
  # Note: highly recommended to use fused=True optimizer flag
  # with CPU offload for faster optimizer step.
  fused: True

Review thread on `fused: True`:
- "Just to confirm: is this a necessary change to get the 70B full finetune runnable on 8x A100, or is it more for speedup of CPU offload?"
- "This is a speedup-only change (will clarify the amount of speedup we get once I re-run some benchmarks)."
A standalone sketch of the fused CPU optimizer path this flag enables follows the config below.

loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
memory_efficient_fsdp_wrap: True
fsdp_cpu_offload: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama3-finetune
log_every_n_steps: 1
log_peak_memory_stats: False
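
As context for `fused: True` above (an illustration, not part of this PR's diff): with `fsdp_cpu_offload: True` the sharded parameters and optimizer state live on CPU, so the optimizer step runs on CPU and benefits from the fused AdamW kernel. A minimal sketch of that path, assuming a PyTorch build new enough to support fused AdamW on CPU (>= 2.4.0, per the recipe's version check):

```python
import torch
import torch.nn as nn

# Parameters stay on CPU, as they do under FSDP's offload_params=True.
model = nn.Linear(4096, 4096)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, foreach=False, fused=True)

loss = model(torch.randn(8, 4096)).sum()
loss.backward()
optimizer.step()       # the optimizer math runs on CPU via the fused kernel
optimizer.zero_grad()
```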

Changes to full_finetune_distributed.py:

@@ -17,6 +17,7 @@
from torch import nn
from torch.distributed import init_process_group
from torch.distributed.fsdp import (
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,

@@ -103,6 +104,15 @@ def __init__(self, cfg: DictConfig) -> None:
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

        if (
            cfg.get("fsdp_cpu_offload", False)
            and cfg.get("fused", False)
            and not utils.torch_version_ge("2.4.0")
        ):
            raise RuntimeError(
                "Using fused optimizer on CPU is only supported in PyTorch nightly."
            )

        # logging attributes
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)

@@ -186,6 +196,7 @@ def setup(self, cfg: DictConfig) -> None:
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False),
            fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
            ac_mode=cfg.get("ac_mode", None),
            ac_option=cfg.get("ac_option", None),

@@ -234,6 +245,7 @@ def _setup_model(
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        memory_efficient_fsdp_wrap: bool,
        fsdp_cpu_offload: bool,
        model_state_dict: Dict[str, Any],
        ac_mode: Optional[str] = None,
        ac_option: Optional[int] = None,

@@ -296,6 +308,7 @@ def _setup_model(
                memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap,
                modules_to_wrap={modules.TransformerDecoderLayer},
            ),
            cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload),
            sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
            device_id=self._device,
            # this recipe does not currently support mixed precision training
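
For reference (an illustration, not code from this PR), the `cpu_offload` argument wired in above can be exercised in isolation. The snippet below assumes a single GPU and spins up a one-process group just to wrap a module with parameter offload enabled:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

# One-process setup so FSDP can be constructed without torchrun.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)

model = torch.nn.Linear(1024, 1024).cuda()
model = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True),  # keep sharded params (and grads) on CPU
    device_id=torch.cuda.current_device(),
)

out = model(torch.randn(2, 1024, device="cuda"))  # params are streamed to GPU for compute
dist.destroy_process_group()
```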

@@ -563,6 +576,10 @@ def recipe_main(cfg: DictConfig) -> None:
        )

    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
    if cfg.get("fsdp_cpu_offload", False):
        # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
        # speed up when benchmarking fused AdamW on CPU
        utils.set_torch_num_threads()

    config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)

Review thread on this block:
- "Again, not a fan of this being in the recipe code; this should likely be in a utility or something the user doesn't need to reason about. cc: @ebsmothers to help brainstorm this a bit."
- "Yeah, +1, can we put this in a utility?"
- "Yeah, definitely!"
- "Bumping this comment. Also, I don't fully understand the code comment just below: do we only need to …"
- "We see the biggest gain setting this when the fused optimizer is used, since it controls the threads used for intra-op parallelism and the CPU optimizer step is the heaviest CPU op. We may see very slight speedups when not using the fused optimizer, but those can be investigated separately from this PR."
A hypothetical sketch of such a utility follows this diff.
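
The `utils.set_torch_num_threads` helper itself is not shown in this diff. A plausible implementation, offered purely as a sketch (the core-splitting policy here is an assumption, not necessarily what torchtune ships), divides the visible CPU cores across ranks so each process gets a fair share of intra-op threads for the CPU optimizer step:

```python
import os

import torch
import torch.distributed as dist


def set_torch_num_threads() -> None:
    # Split the available cores evenly across ranks; e.g. with 8 ranks on a
    # 96-core host each process gets 12 intra-op threads for fused CPU AdamW.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    num_threads = max(1, (os.cpu_count() or 1) // world_size)
    torch.set_num_threads(num_threads)
```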

Changes to the recipe/config registry:

@@ -57,6 +57,7 @@ class Recipe:
        Config(name="llama2/7B_full", file_path="llama2/7B_full.yaml"),
        Config(name="llama2/13B_full", file_path="llama2/13B_full.yaml"),
        Config(name="llama3/8B_full", file_path="llama3/8B_full.yaml"),
        Config(name="llama3/70B_full", file_path="llama3/70B_full.yaml"),
        Config(name="mistral/7B_full", file_path="mistral/7B_full.yaml"),
        Config(name="gemma/2B_full", file_path="gemma/2B_full.yaml"),
        Config(name="phi3/mini_full", file_path="phi3/mini_full.yaml"),

Review comment on the new `llama3/70B_full` entry:
- "Obviously not necessary for this particular PR, but just a (maybe controversial) thought: we could consider adding additional metadata to config and/or recipe dataclasses to capture dependencies (like specific torch/ao versions or certain optional deps). A bit tricky if users want to copy and modify, but this way we would validate in here and not have to do checks in recipes (or wherever else). Thoughts? @kartikayk @joecummings @pbontrager @RdoubleA"
A hypothetical sketch of that idea follows below.
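
To make the suggestion concrete, here is a purely hypothetical sketch (not part of this PR or of torchtune) of dependency metadata attached to the `Config` dataclass, so validation could happen at the registry level rather than inside each recipe:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Config:
    name: str
    file_path: str
    # Hypothetical extras: minimum package versions and optional dependencies a
    # config needs, e.g. {"torch": "2.4.0"} for configs that rely on fused CPU AdamW.
    min_versions: Dict[str, str] = field(default_factory=dict)
    optional_deps: List[str] = field(default_factory=list)
```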

New utility file (adds `torch_version_ge`):

@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch


def torch_version_ge(version: str) -> bool:
    """
    Check if torch version is greater than or equal to the given version
    """
    return version in torch.__version__ or torch.__version__ >= version

Review comment on `torch_version_ge`:
- "😢"
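
The check above relies on a substring test plus lexicographic string comparison against `torch.__version__`: it happens to accept nightlies such as `2.4.0.dev...`, but string ordering breaks for cases like `2.10.0` vs `2.4.0`. A more robust alternative, offered only as a sketch under the assumption that a dependency on `packaging` is acceptable, compares parsed release tuples instead:

```python
import torch
from packaging.version import Version


def torch_version_ge(version: str) -> bool:
    """Return True if the installed torch release is >= `version`.

    Comparing release tuples means nightly/dev builds of 2.4.0 still count as
    >= 2.4.0, while 2.10.0 correctly compares as newer than 2.4.0.
    """
    return Version(torch.__version__).release >= Version(version).release
```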