Allow subselecting the appropriate config for llama4 #1815


Merged
merged 20 commits on May 5, 2025
28 changes: 26 additions & 2 deletions llmfoundry/models/hf/hf_base.py
@@ -56,10 +56,21 @@ class BaseHuggingFaceModel(HuggingFaceModel):
"""Wrapper around HuggingFaceModel.

Base class for HuggingFace based models.

Attributes:
model_cls (type): The model class to use. Default: ``AutoModelForCausalLM``.
subselect_config_attr (optional, str): The attribute to use to subselect the config.
This is used if you want to select only using the text_config or vision_config
for a multimodal model. For example, AutoConfig.from_pretrained on Llama4 produces
a Llama4Config, and to use as a causal LM, we need to get the Llama4TextConfig.
Default: ``None``, which will use whatever AutoConfig produces.
default_train_metrics (tuple): The default training metrics to use.
default_eval_metrics (tuple): The default evaluation metrics to use.
"""

    model_cls: Union[type[_BaseAutoModelClass],
                     type[PreTrainedModel]] = AutoModelForCausalLM
    subselect_config_attr: Optional[str] = None
    default_train_metrics: tuple = ()
    default_eval_metrics: tuple = ()

@@ -171,16 +182,29 @@ def build_config(
        attn_implementation: str,
        config_overrides: dict[str, Any],
    ) -> PretrainedConfig:
        # Necessary due to https://github.com/huggingface/transformers/issues/28056
        use_cache = False

        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            use_auth_token=use_auth_token,
            attn_implementation=attn_implementation,
            torch_dtype=_MASTER_WEIGHTS_PRECISION,
-            use_cache=False,  # Necessary due to https://github.com/huggingface/transformers/issues/28056
+            use_cache=use_cache,
        )

        if cls.subselect_config_attr is not None and hasattr(
            config,
            cls.subselect_config_attr,
        ):
            config = getattr(config, cls.subselect_config_attr)

        # Forward the above overrides to the subselected config too
        config.use_cache = use_cache
        config.attn_implementation = attn_implementation
        config.torch_dtype = _MASTER_WEIGHTS_PRECISION

        set_config_overrides(config, config_overrides)

        return config
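
To make the new subselect_config_attr hook concrete, here is a minimal sketch of the subselection step in isolation, assuming a multimodal checkpoint whose top-level config exposes a text_config attribute; the checkpoint path below is illustrative and not part of this PR.

from transformers import AutoConfig

# Minimal sketch of the subselection added in build_config above.
# 'path/to/llama4-checkpoint' is a hypothetical local path, not a real checkpoint.
config = AutoConfig.from_pretrained('path/to/llama4-checkpoint')
subselect_config_attr = 'text_config'

if subselect_config_attr is not None and hasattr(config, subselect_config_attr):
    # For Llama4, this swaps the top-level Llama4Config for its nested
    # Llama4TextConfig, which is what a causal LM wrapper needs.
    config = getattr(config, subselect_config_attr)

print(type(config).__name__)  # e.g. 'Llama4TextConfig' for a Llama4 checkpoint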
5 changes: 5 additions & 0 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -61,6 +61,11 @@ class ComposerHFCausalLM(BaseHuggingFaceModel):

    model_cls: Union[type[_BaseAutoModelClass],
                     type[PreTrainedModel]] = AutoModelForCausalLM

    # The text_config attr should be correct for most multimodal models, although
    # there is no official standard for this, so it may need to be updated for
    # future transformers versions.
    subselect_config_attr: Optional[str] = 'text_config'
    default_train_metrics: tuple = tuple(DEFAULT_CAUSAL_LM_TRAIN_METRICS)
    default_eval_metrics: tuple = tuple(DEFAULT_CAUSAL_LM_EVAL_METRICS)
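
Since the nested-config attribute is not standardized across multimodal architectures, a subclass can point subselect_config_attr at a different attribute. The sketch below is hypothetical; language_config is a made-up attribute name used only for illustration.

from typing import Optional

from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM


class ComposerHFCustomMultimodalCausalLM(ComposerHFCausalLM):
    """Hypothetical subclass for a multimodal model whose text config is
    nested under a language_config attribute instead of text_config."""

    subselect_config_attr: Optional[str] = 'language_config'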

2 changes: 1 addition & 1 deletion setup.py
@@ -68,7 +68,7 @@
'onnx==1.17.0',
'onnxruntime==1.19.2',
'boto3>=1.21.45,<2',
-'huggingface-hub>=0.19.0,<0.31',
+'huggingface-hub[hf_xet]>=0.30.0,<0.31',
'beautifulsoup4>=4.12.2,<5', # required for model download utils
'tenacity>=8.2.3,<10',
'catalogue>=2,<3',
97 changes: 97 additions & 0 deletions tests/fixtures/models.py
@@ -171,6 +171,93 @@ def tiny_codellama_config_helper(tie_word_embeddings: bool = False):
    return config_object


def tiny_llama4_config_helper():
    pytest.importorskip('transformers')
    from transformers.models.llama4.configuration_llama4 import Llama4Config

    config_dict = {
        'architectures': ['Llama4ForConditionalGeneration',],
        'boi_token_index': 200080,
        'eoi_token_index': 200081,
        'image_token_index': 200092,
        'model_type': 'llama4',
        'text_config': {
            '_attn_implementation_autoset': True,
            'attention_bias': False,
            'attention_chunk_size': 8192,
            'attention_dropout': 0.0,
            'bos_token_id': 200000,
            'eos_token_id': [
                200001,
                200007,
                200008,
            ],
            'for_llm_compressor': False,
            'head_dim': 128,
            'hidden_act': 'silu',
            'hidden_size': 5120,
            'initializer_range': 0.02,
            'interleave_moe_layer_step': 1,
            'intermediate_size': 8192,
            'intermediate_size_mlp': 16384,
            'max_position_embeddings': 10485760,
            'model_type': 'llama4_text',
            'no_rope_layers': [],
            'num_attention_heads': 40,
            'num_experts_per_tok': 1,
            'num_hidden_layers': 48,
            'num_key_value_heads': 8,
            'num_local_experts': 16,
            'output_router_logits': False,
            'pad_token_id': 200018,
            'rms_norm_eps': 1e-05,
            'rope_scaling': {
                'factor': 16.0,
                'high_freq_factor': 1.0,
                'low_freq_factor': 1.0,
                'original_max_position_embeddings': 8192,
                'rope_type': 'llama3',
            },
            'rope_theta': 500000.0,
            'router_aux_loss_coef': 0.001,
            'router_jitter_noise': 0.0,
            'torch_dtype': 'bfloat16',
            'use_cache': True,
            'use_qk_norm': True,
            'vocab_size': 202048,
        },
        'torch_dtype': 'bfloat16',
        'transformers_version': '4.51.0.dev0',
        'vision_config': {
            '_attn_implementation_autoset': True,
            'attention_dropout': 0.0,
            'hidden_act': 'gelu',
            'hidden_size': 1408,
            'image_size': 336,
            'initializer_range': 0.02,
            'intermediate_size': 5632,
            'model_type': 'llama4_vision_model',
            'multi_modal_projector_bias': False,
            'norm_eps': 1e-05,
            'num_attention_heads': 16,
            'num_channels': 3,
            'num_hidden_layers': 34,
            'patch_size': 14,
            'pixel_shuffle_ratio': 0.5,
            'projector_dropout': 0.0,
            'projector_input_dim': 4096,
            'projector_output_dim': 4096,
            'rope_theta': 10000,
            'vision_feature_layer': -1,
            'vision_feature_select_strategy': 'default',
            'vision_output_dim': 4096,
        },
    }

    config_object = Llama4Config(**config_dict)
    return config_object


def tiny_bert_config_helper():
    pytest.importorskip('transformers')
    from transformers.models.bert.configuration_bert import BertConfig
@@ -318,6 +405,11 @@ def _session_tiny_bert_config(): # type: ignore
    return tiny_bert_config_helper()


@pytest.fixture(scope='session')
def _session_tiny_llama4_config(): # type: ignore
    return tiny_llama4_config_helper()


## SESSION TOKENIZERS ##
@pytest.fixture(scope='session')
def _session_tiny_gpt2_tokenizer(tokenizers_assets): # type: ignore
@@ -388,6 +480,11 @@ def tiny_bert_config(_session_tiny_bert_config): # type: ignore
    return copy.deepcopy(_session_tiny_bert_config)


@pytest.fixture
def tiny_llama4_config(_session_tiny_llama4_config): # type: ignore
    return copy.deepcopy(_session_tiny_llama4_config)


## TOKENIZER FIXTURES ##
@pytest.fixture
def tiny_gpt2_tokenizer(_session_tiny_gpt2_tokenizer): # type: ignore
23 changes: 23 additions & 0 deletions tests/models/hf/test_hf_config.py
@@ -15,7 +15,9 @@
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from transformers.models.llama4.configuration_llama4 import Llama4TextConfig

from llmfoundry.models.hf import BaseHuggingFaceModel
from llmfoundry.models.hf.hf_fsdp import rgetattr
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.builders import build_composer_model
@@ -434,3 +436,24 @@ def test_attn_implementation_none(tiny_llama_save_dir: Path):

    # llama config uses _attn_implementation
    assert model.config._attn_implementation == 'eager'  # type: ignore


def test_text_config(tiny_llama4_config: PretrainedConfig, tmp_path: Path):
    save_path = tmp_path / 'model'
    tiny_llama4_config.save_pretrained(save_path)

    class TestModel(BaseHuggingFaceModel):
        subselect_config_attr = 'text_config'

    config = TestModel.build_config(
        pretrained_model_name_or_path=str(save_path),
        trust_remote_code=True,
        use_auth_token=False,
        attn_implementation='eager',
        config_overrides={},
    )

    assert isinstance(config, Llama4TextConfig)
    assert config.attn_implementation == 'eager'
    assert config.use_cache == False
    assert config.torch_dtype == 'float32'