
Llama3-70b: Full Finetune w/CPU offload + fused optimizer #993


Merged: 12 commits, Jun 1, 2024
110 changes: 110 additions & 0 deletions recipes/configs/llama3/70B_full.yaml
@@ -0,0 +1,110 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama3 70B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --hf-token <HF_TOKEN> --ignore-patterns "original/consolidated*"
#
# To launch on 8 devices, run the following command from root:
# tune run --nproc_per_node 8 full_finetune_distributed --config llama3/70B_full
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 8 full_finetune_distributed --config llama3/70B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config is only tested on an 8xA100 machine.


# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3.llama3_70b

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
  checkpoint_files: [
Contributor: 😢
    model-00001-of-00030.safetensors,
    model-00002-of-00030.safetensors,
    model-00003-of-00030.safetensors,
    model-00004-of-00030.safetensors,
    model-00005-of-00030.safetensors,
    model-00006-of-00030.safetensors,
    model-00007-of-00030.safetensors,
    model-00008-of-00030.safetensors,
    model-00009-of-00030.safetensors,
    model-00010-of-00030.safetensors,
    model-00011-of-00030.safetensors,
    model-00012-of-00030.safetensors,
    model-00013-of-00030.safetensors,
    model-00014-of-00030.safetensors,
    model-00015-of-00030.safetensors,
    model-00016-of-00030.safetensors,
    model-00017-of-00030.safetensors,
    model-00018-of-00030.safetensors,
    model-00019-of-00030.safetensors,
    model-00020-of-00030.safetensors,
    model-00021-of-00030.safetensors,
    model-00022-of-00030.safetensors,
    model-00023-of-00030.safetensors,
    model-00024-of-00030.safetensors,
    model-00025-of-00030.safetensors,
    model-00026-of-00030.safetensors,
    model-00027-of-00030.safetensors,
    model-00028-of-00030.safetensors,
    model-00029-of-00030.safetensors,
    model-00030-of-00030.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3-70b
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  foreach: False
  # Note: highly recommended to use fused=True optimizer flag
  # with CPU offload for faster optimizer step.
  fused: True
Contributor: Just to confirm: is this a necessary change to get the 70B full finetune runnable on 8x A100, or is it more for speeding up CPU offload?

Member Author: This is a speedup-only change (will clarify the amount of speedup we get once I re-run some benchmarks).
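Not part of the PR: a minimal sketch of the fused-CPU-optimizer path being discussed, assuming a PyTorch build (nightly / 2.4+ at the time) where the fused AdamW kernel accepts CPU tensors. With fsdp_cpu_offload enabled the sharded parameters live on CPU, so the optimizer step runs there; fused=True replaces the per-parameter Python loop with a single multi-tensor kernel, which is where the speedup comes from.

```python
import torch

# Sketch only: a small CPU module standing in for a CPU-offloaded FSDP shard.
model = torch.nn.Linear(1024, 1024)

# fused=True needs a PyTorch build whose fused AdamW supports CPU tensors
# (nightly / 2.4+ at the time of this PR); foreach is disabled as in the config.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, fused=True, foreach=False)

loss = model(torch.randn(8, 1024)).sum()
loss.backward()
optimizer.step()       # fused multi-tensor AdamW runs on the CPU parameters
optimizer.zero_grad()
```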


loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1


# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
memory_efficient_fsdp_wrap: True
fsdp_cpu_offload: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama3-finetune
log_every_n_steps: 1
log_peak_memory_stats: False
17 changes: 17 additions & 0 deletions recipes/full_finetune_distributed.py
@@ -17,6 +17,7 @@
from torch import nn
from torch.distributed import init_process_group
from torch.distributed.fsdp import (
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
@@ -103,6 +104,15 @@ def __init__(self, cfg: DictConfig) -> None:
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

        if (
            cfg.get("fsdp_cpu_offload", False)
            and cfg.get("fused", False)
            and not utils.torch_version_ge("2.4.0")
        ):
            raise RuntimeError(
                "Using fused optimizer on CPU is only supported in PyTorch nightly."
            )

        # logging attributes
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
@@ -186,6 +196,7 @@ def setup(self, cfg: DictConfig) -> None:
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False),
            fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
            ac_mode=cfg.get("ac_mode", None),
            ac_option=cfg.get("ac_option", None),
@@ -234,6 +245,7 @@ def _setup_model(
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        memory_efficient_fsdp_wrap: bool,
        fsdp_cpu_offload: bool,
        model_state_dict: Dict[str, Any],
        ac_mode: Optional[str] = None,
        ac_option: Optional[int] = None,
@@ -296,6 +308,7 @@ def _setup_model(
                memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap,
                modules_to_wrap={modules.TransformerDecoderLayer},
            ),
            cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload),
            sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
            device_id=self._device,
            # this recipe does not currently support mixed precision training
@@ -563,6 +576,10 @@ def recipe_main(cfg: DictConfig) -> None:
        )

    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
    if cfg.get("fsdp_cpu_offload", False):
Contributor: Again, not a fan of this being in the recipe code; this should likely be in a utility or something that the user probably doesn't need to reason about. cc @ebsmothers to help brainstorm this a bit.

Contributor: Yeah, +1, can we put this in a utility?

Member Author: Yeah, definitely!

Contributor: Bumping this comment. Also, I don't fully understand the code comment just below: do we only need set_num_threads in the case of the fused optimizer? Looking at the answer.ai code, it seems like the answer is no, but based on the code comment it's not clear to me.

Member Author: We see the biggest gain from setting this when the fused optimizer is used, since it controls the threads used for intra-op parallelism and the CPU optimizer step is the heaviest CPU op. We may see very slight speedups when not using the fused optimizer, but those can be investigated separately from this PR.
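Not part of the PR, just an illustration with assumed numbers: the helper added in this change divides the physical core count by the world size, so on a hypothetical 8-GPU host with 96 cores each rank would get 96 // 8 = 12 intra-op threads.

```python
import os

import torch
import torch.distributed as dist

# Assumed setup: 8 ranks sharing a 96-core host. Dividing by the world size keeps
# the per-rank thread pools from oversubscribing the CPU during the offloaded
# (fused) optimizer step, which is the heaviest CPU op in this recipe.
world_size = dist.get_world_size() if dist.is_initialized() else 1
num_threads = os.cpu_count() // world_size  # e.g. 96 // 8 = 12
torch.set_num_threads(num_threads)
```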

        # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
        # speed up when benchmarking fused AdamW on CPU
        utils.set_torch_num_threads()

    config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)

2 changes: 1 addition & 1 deletion tests/recipes/test_lora_finetune_single_device.py
@@ -24,9 +24,9 @@
    gen_log_file_name,
    get_loss_values_from_metric_logger,
    TOKENIZER_PATHS,
    torch_version_ge,
)
from torchtune import config
from torchtune.utils import torch_version_ge


class TestLoRAFinetuneSingleDeviceRecipe:
7 changes: 0 additions & 7 deletions tests/test_utils.py
@@ -38,13 +38,6 @@
}


def torch_version_ge(version: str) -> bool:
    """
    Check if torch version is greater than or equal to the given version
    """
    return version in torch.__version__ or torch.__version__ >= version


# Inherit from SentencePieceTokenizer class to reuse its tokenize_messages method
class DummyTokenizer(SentencePieceTokenizer):
    def __init__(self):
1 change: 1 addition & 0 deletions torchtune/_recipe_registry.py
@@ -57,6 +57,7 @@ class Recipe:
            Config(name="llama2/7B_full", file_path="llama2/7B_full.yaml"),
            Config(name="llama2/13B_full", file_path="llama2/13B_full.yaml"),
            Config(name="llama3/8B_full", file_path="llama3/8B_full.yaml"),
            Config(name="llama3/70B_full", file_path="llama3/70B_full.yaml"),
Contributor: Obviously not necessary for this particular PR, but just a (maybe controversial) thought: we could consider adding additional metadata to config and/or recipe dataclasses to capture dependencies (like specific torch/ao versions or certain optional deps). A bit tricky if users want to copy and modify, but this way we would validate here and not have to do checks in recipes (or wherever else). Thoughts? @kartikayk @joecummings @pbontrager @RdoubleA
Config(name="mistral/7B_full", file_path="mistral/7B_full.yaml"),
Config(name="gemma/2B_full", file_path="gemma/2B_full.yaml"),
Config(name="phi3/mini_full", file_path="phi3/mini_full.yaml"),
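Not part of this PR: a purely hypothetical sketch of the metadata idea floated in the comment above, with invented field names, just to make the suggestion concrete. Dependencies would be declared on the registry entry and validated there instead of inside each recipe.

```python
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class ConfigWithRequirements:
    """Hypothetical extension of the registry's Config dataclass."""

    name: str
    file_path: str
    min_torch_version: str = "0.0.0"  # e.g. "2.4.0" for fused-on-CPU configs
    optional_deps: List[str] = field(default_factory=list)

    def validate(self) -> None:
        # Naive lexicographic check, for illustration only.
        if torch.__version__ < self.min_torch_version:
            raise RuntimeError(
                f"{self.name} requires torch >= {self.min_torch_version}, "
                f"found {torch.__version__}"
            )
```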
3 changes: 3 additions & 0 deletions torchtune/utils/__init__.py
@@ -22,10 +22,12 @@
    is_distributed,
    lora_fsdp_wrap_policy,
    prepare_model_for_fsdp_with_meta_device,
    set_torch_num_threads,
    validate_no_params_on_meta_device,
)
from ._generation import generate
from ._profiler import profiler
from ._version import torch_version_ge
from .argparse import TuneRecipeArgumentParser
from .collate import padded_collate, padded_collate_dpo
from .constants import ( # noqa
@@ -79,6 +81,7 @@
    "set_seed",
    "validate_expected_param_dtype",
    "TuneRecipeArgumentParser",
    "torch_version_ge",
    "OptimizerInBackwardWrapper",
    "create_optim_in_bwd_wrapper",
    "register_optim_in_bwd_hooks",
16 changes: 16 additions & 0 deletions torchtune/utils/_distributed.py
@@ -117,6 +117,22 @@ def init_distributed(**kwargs: Dict) -> bool: # noqa: DOC106, DOC109
    return False


def set_torch_num_threads() -> None:
    """
    Sets the number of threads used by torch to utilize all physical CPU
    cores for intra-op parallelism. Currently, this function sets num_threads
    to be the number of physical CPU cores divided by the number of GPUs as we
    use one process per GPU, and this avoids CPU oversubscription. Note that this is
    currently a rough approximation, and doesn't take into account environments where
    things like CPU affinity is set.
    """
    num_threads = os.cpu_count() // (
        torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
    )
    torch.set_num_threads(num_threads)
    _log.info(f"Set intra op parallelism no. of threads to {num_threads}")


def get_world_size_and_rank() -> Tuple[int, int]:
"""Function that gets the current world size (aka total number
of ranks) and rank number of the current trainer.
13 changes: 13 additions & 0 deletions torchtune/utils/_version.py
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch


def torch_version_ge(version: str) -> bool:
    """
    Check if torch version is greater than or equal to the given version
    """
    return version in torch.__version__ or torch.__version__ >= version
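Not part of the diff: a small usage sketch of the relocated helper, mirroring the gate added in full_finetune_distributed.py.

```python
from torchtune.utils import torch_version_ge

# Gate the fused-on-CPU optimizer path on the installed PyTorch build.
if torch_version_ge("2.4.0"):
    print("fused optimizer step on CPU-offloaded params is supported")
else:
    print("fused optimizer on CPU needs PyTorch 2.4+ (nightly at the time of this PR)")
```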