Skip to content

Commit 3b471d5

Browse files
committed
Enable SFT for multimodal llama4
1 parent 86b232e commit 3b471d5

File tree

8 files changed

+353
-88
lines changed

8 files changed

+353
-88
lines changed

MaxText/configs/base.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,9 @@ expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else
409409
eval_per_device_batch_size: 0.0
410410
max_corpus_chars: 10_000_000
411411
train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
412+
train_image_column: 'image'
412413
eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
414+
eval_image_column: 'image'
413415
packing: True
414416
num_epoch: 1 # only grain and tfds pipeline supports num_epoch > 1
415417

@@ -732,7 +734,7 @@ dtype_mm: "float32" # Data type for multimodal model's vision encoder
732734
remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder. Check `remat_policy` for options.
733735
image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
734736
image_path: "" # Local image path used for decoding
735-
737+
image_placeholder: "<|image|>"
736738

737739
### llama4 multi modal configs
738740
hidden_size_for_vit: 1408
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
base_config: "base.yml"
16+
17+
use_sft: True
18+
use_multimodal: True
19+
# For vision, the prompt contains an image; we only train on completion tokens
20+
sft_train_on_completion_only: True
21+
packing: False # packing is not supported yet
22+
freeze_vision_encoder_params: True
23+
learning_rate: 2.e-5
24+
25+
# -------------- HF pipeline --------------
26+
dataset_type: hf
27+
hf_path: 'HuggingFaceM4/ChartQA'
28+
train_split: 'train'
29+
hf_eval_split: 'val'
30+
train_data_columns: ['query', 'label'] # the first column is prompt, second column is completion
31+
eval_data_columns: ['query', 'label'] # the first column is prompt, second column is completion

MaxText/decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def main(argv: Sequence[str]) -> None:
100100
prefill_length = config.max_prefill_predict_length
101101
processor_output = multimodal_utils.PreprocessorOutput()
102102
if config.use_multimodal:
103-
text = multimodal_utils.reformat_prompt(text, config.model_name)
103+
text = multimodal_utils.reformat_prompt(text, image_placeholder=config.image_placeholder, model_name=config.model_name)
104104
# TODO(hengtaoguo): Support multiple images as input.
105105
images = multimodal_utils.load_image_from_path(config.image_path)
106106
processor_output = multimodal_utils.pre_process_image(images, model_name=config.model_name)

MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 176 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,116 @@
3232
from MaxText import multihost_dataloading
3333

3434

35+
def vision_sft_preprocessing_pipeline(
36+
dataset,
37+
config,
38+
dataloading_host_index,
39+
dataloading_host_count,
40+
global_mesh,
41+
text_columns,
42+
image_column,
43+
global_batch_size,
44+
):
45+
"""pipeline for multimodal SFT with HF dataset"""
46+
47+
assert len(text_columns) == 2, f"Need two text_columns for query and response, received {text_columns=}"
48+
49+
if config.enable_data_shuffling:
50+
dataset = dataset.shuffle(seed=config.data_shuffle_seed)
51+
52+
dataset = dataset.select_columns(text_columns + [image_column])
53+
dataset = dataset.map(
54+
_input_pipeline_utils.reformat_prompt,
55+
fn_kwargs={"column": text_columns[0], "image_placeholder": config.image_placeholder, "model_name": config.model_name},
56+
)
57+
dataset = dataset.map(
58+
_input_pipeline_utils.reformat_response,
59+
fn_kwargs={"column": text_columns[1], "model_name": config.model_name},
60+
)
61+
if image_column != "images":
62+
dataset = dataset.rename_column(image_column, "images")
63+
64+
dataset = dataset.map(
65+
_input_pipeline_utils.pre_process_image_sft,
66+
fn_kwargs={"image_column": "images", "model_name": config.model_name},
67+
)
68+
69+
tokenizer = transformers.AutoTokenizer.from_pretrained(
70+
config.tokenizer_path,
71+
add_bos_token=False,
72+
add_eos_token=False,
73+
legacy=False,
74+
token=config.hf_access_token,
75+
)
76+
if tokenizer.pad_token_id is not None:
77+
pad_id = tokenizer.pad_token_id
78+
elif tokenizer.unk_token_id is not None:
79+
pad_id = tokenizer.unk_token_id
80+
else:
81+
pad_id = -1
82+
83+
dataset = dataset.map(
84+
_input_pipeline_utils.tokenization,
85+
batched=True,
86+
fn_kwargs={
87+
"hf_tokenizer": tokenizer,
88+
"truncation": False,
89+
"max_length": config.max_target_length,
90+
"column_names": text_columns,
91+
},
92+
)
93+
dataset = dataset.map(
94+
_input_pipeline_utils.prepare_text_for_image_fusion,
95+
fn_kwargs={"column_name": text_columns[0], "model_name": config.model_name},
96+
)
97+
98+
dataset = _input_pipeline_utils.HFDataSource(
99+
dataset=dataset,
100+
dataloading_host_index=dataloading_host_index,
101+
dataloading_host_count=dataloading_host_count,
102+
num_threads=1,
103+
generate_padding_example=True,
104+
max_target_length=config.max_target_length,
105+
data_column_names=text_columns,
106+
)
107+
operations = []
108+
operations.append(
109+
_input_pipeline_utils.SFTPromptMaskingVision(
110+
query_column=text_columns[0],
111+
response_column=text_columns[1],
112+
max_target_length=config.max_target_length,
113+
unk_id=pad_id,
114+
)
115+
)
116+
# TODO(aireenmei, hengtaoguo): support packing
117+
operations.append(_input_pipeline_utils.PadToMaxLength(config.max_target_length, pad_id))
118+
operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=True))
119+
operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1))
120+
dummy_index_sampler = grain.IndexSampler(
121+
num_records=len(dataset),
122+
num_epochs=1,
123+
shard_options=grain.ShardOptions(
124+
shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False
125+
),
126+
shuffle=False,
127+
seed=0,
128+
)
129+
130+
dataloader = grain.DataLoader(
131+
data_source=dataset,
132+
operations=operations,
133+
sampler=dummy_index_sampler,
134+
worker_count=1, # only supports <=1 for now, more workers results in duplicated data
135+
worker_buffer_size=1,
136+
read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=128),
137+
)
138+
139+
multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh)
140+
141+
# Return multi-host jax.Array prep iterator
142+
return multihost_gen
143+
144+
35145
def preprocessing_pipeline(
36146
dataloading_host_index,
37147
dataloading_host_count,
@@ -212,27 +322,39 @@ def make_hf_train_iterator(
212322
streaming=True,
213323
token=config.hf_access_token,
214324
)
215-
train_iter = preprocessing_pipeline(
216-
dataloading_host_index=process_indices_train.index(jax.process_index()),
217-
dataloading_host_count=len(process_indices_train),
218-
global_mesh=global_mesh,
219-
dataset=train_ds,
220-
data_column_names=config.train_data_columns,
221-
tokenize=config.tokenize_train_data,
222-
tokenizer_path=config.tokenizer_path,
223-
hf_access_token=config.hf_access_token,
224-
global_batch_size=config.global_batch_size_to_load,
225-
max_target_length=config.max_target_length,
226-
shuffle=config.enable_data_shuffling,
227-
data_shuffle_seed=config.data_shuffle_seed,
228-
add_bos=config.add_bos,
229-
add_eos=config.add_eos,
230-
packing=config.packing,
231-
generate_padding_example=False,
232-
use_dpo=config.use_dpo,
233-
use_sft=config.use_sft,
234-
sft_train_on_completion_only=config.sft_train_on_completion_only,
235-
)
325+
if config.use_sft and config.use_multimodal:
326+
train_iter = vision_sft_preprocessing_pipeline(
327+
dataset=train_ds,
328+
config=config,
329+
dataloading_host_index=process_indices_train.index(jax.process_index()),
330+
dataloading_host_count=len(process_indices_train),
331+
global_mesh=global_mesh,
332+
text_columns=config.train_data_columns,
333+
image_column=config.train_image_column,
334+
global_batch_size=config.global_batch_size_to_load,
335+
)
336+
else:
337+
train_iter = preprocessing_pipeline(
338+
dataloading_host_index=process_indices_train.index(jax.process_index()),
339+
dataloading_host_count=len(process_indices_train),
340+
global_mesh=global_mesh,
341+
dataset=train_ds,
342+
data_column_names=config.train_data_columns,
343+
tokenize=config.tokenize_train_data,
344+
tokenizer_path=config.tokenizer_path,
345+
hf_access_token=config.hf_access_token,
346+
global_batch_size=config.global_batch_size_to_load,
347+
max_target_length=config.max_target_length,
348+
shuffle=config.enable_data_shuffling,
349+
data_shuffle_seed=config.data_shuffle_seed,
350+
add_bos=config.add_bos,
351+
add_eos=config.add_eos,
352+
packing=config.packing,
353+
generate_padding_example=False,
354+
use_dpo=config.use_dpo,
355+
use_sft=config.use_sft,
356+
sft_train_on_completion_only=config.sft_train_on_completion_only,
357+
)
236358
return train_iter
237359

238360

@@ -252,25 +374,37 @@ def make_hf_eval_iterator(
252374
)
253375

254376
eval_generate_padding_example = config.eval_steps > 0
255-
eval_iter = preprocessing_pipeline(
256-
dataloading_host_index=process_indices_eval.index(jax.process_index()),
257-
dataloading_host_count=len(process_indices_eval),
258-
global_mesh=global_mesh,
259-
dataset=eval_ds,
260-
data_column_names=config.eval_data_columns,
261-
tokenize=config.tokenize_eval_data,
262-
tokenizer_path=config.tokenizer_path,
263-
hf_access_token=config.hf_access_token,
264-
global_batch_size=config.global_batch_size_to_load_eval,
265-
max_target_length=config.max_target_length,
266-
shuffle=False,
267-
data_shuffle_seed=config.data_shuffle_seed,
268-
add_bos=config.add_bos,
269-
add_eos=config.add_eos,
270-
packing=config.packing,
271-
generate_padding_example=eval_generate_padding_example,
272-
use_dpo=config.use_dpo,
273-
use_sft=config.use_sft,
274-
sft_train_on_completion_only=config.sft_train_on_completion_only,
275-
)
377+
if config.use_sft and config.use_multimodal:
378+
eval_iter = vision_sft_preprocessing_pipeline(
379+
dataset=eval_ds,
380+
config=config,
381+
dataloading_host_index=process_indices_eval.index(jax.process_index()),
382+
dataloading_host_count=len(process_indices_eval),
383+
global_mesh=global_mesh,
384+
text_columns=config.eval_data_columns,
385+
image_column=config.eval_image_column,
386+
global_batch_size=config.global_batch_size_to_load_eval,
387+
)
388+
else:
389+
eval_iter = preprocessing_pipeline(
390+
dataloading_host_index=process_indices_eval.index(jax.process_index()),
391+
dataloading_host_count=len(process_indices_eval),
392+
global_mesh=global_mesh,
393+
dataset=eval_ds,
394+
data_column_names=config.eval_data_columns,
395+
tokenize=config.tokenize_eval_data,
396+
tokenizer_path=config.tokenizer_path,
397+
hf_access_token=config.hf_access_token,
398+
global_batch_size=config.global_batch_size_to_load_eval,
399+
max_target_length=config.max_target_length,
400+
shuffle=False,
401+
data_shuffle_seed=config.data_shuffle_seed,
402+
add_bos=config.add_bos,
403+
add_eos=config.add_eos,
404+
packing=config.packing,
405+
generate_padding_example=eval_generate_padding_example,
406+
use_dpo=config.use_dpo,
407+
use_sft=config.use_sft,
408+
sft_train_on_completion_only=config.sft_train_on_completion_only,
409+
)
276410
return eval_iter

MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import tensorflow as tf
2828
from MaxText import max_logging
2929
from MaxText import tokenizer
30+
from MaxText import multimodal_utils
3031

3132
Features = Dict[str, tf.Tensor]
3233
AUTOTUNE = tf.data.experimental.AUTOTUNE
@@ -68,6 +69,37 @@ def add_segmentation_and_position(x, data_columns, padding_token=0):
6869
########## Functions used by HF pipeline
6970

7071

72+
def reformat_prompt(example, column, image_placeholder, model_name):
73+
"""reformat prompt for multimodal SFT"""
74+
example[column] = multimodal_utils.reformat_prompt(example[column], image_placeholder, model_name)
75+
return example
76+
77+
78+
def reformat_response(example, column, model_name):
79+
"""reformat response for multimodal SFT"""
80+
example[column] = multimodal_utils.reformat_response(example[column], model_name)
81+
return example
82+
83+
84+
def pre_process_image_sft(example, image_column, model_name):
85+
"""pre-process image for multimodal SFT"""
86+
image = multimodal_utils.convert_to_RGB(example[image_column])
87+
# TODO(aireenmei, hengtaoguo): add support for different image sizes
88+
image = multimodal_utils.resize_image(image, model_name)
89+
image = np.array(image)
90+
example[image_column] = multimodal_utils.pre_process_image(image, model_name)
91+
return example
92+
93+
94+
def prepare_text_for_image_fusion(example, column_name, model_name):
95+
"""prepare text for image fusion for multimodal SFT"""
96+
example[column_name] = multimodal_utils.prepare_text_for_image_fusion(
97+
example[column_name], model_name, processor_output=example["images"]
98+
)
99+
example["images"] = example["images"].pixel_values
100+
return example
101+
102+
71103
def combine_columns(example, columns, data_column):
72104
"""Combine columns such as 'prompt' and 'completion' for sft training"""
73105
assert len(columns) > 1
@@ -192,6 +224,26 @@ def map(self, element):
192224
}
193225

194226

227+
@dataclasses.dataclass
228+
class SFTPromptMaskingVision(grain.MapTransform):
229+
"""SFT prompt masking for multimodal"""
230+
231+
def __init__(self, query_column, response_column, max_target_length, unk_id):
232+
self.query_column = query_column
233+
self.response_column = response_column
234+
self.max_target_length = max_target_length
235+
self.unk_id = unk_id
236+
237+
def map(self, element):
238+
inputs = np.concatenate((element[self.query_column], element[self.response_column]))
239+
targets = np.concatenate((np.asarray([self.unk_id] * len(element[self.query_column])), element[self.response_column]))
240+
return {
241+
"inputs": np.asarray(inputs[: self.max_target_length], dtype=np.int32),
242+
"targets": np.asarray(targets[: self.max_target_length], dtype=np.int32),
243+
"images": element["images"],
244+
}
245+
246+
195247
@dataclasses.dataclass
196248
class HFNormalizeFeatures(grain.MapTransform):
197249
"""Normalize feature keys for HuggingFace input"""
@@ -413,10 +465,12 @@ def _pad(x, max_length, pad_id):
413465

414466
data_columns = list(element.keys())
415467
for data_column in data_columns:
416-
element[f"{data_column}_segmentation"] = (element[data_column] != self.pad_id).astype(np.int32)
417-
element[f"{data_column}_position"] = np.arange(element[data_column].shape[0], dtype=np.int32)
468+
if data_column != "images":
469+
element[f"{data_column}_segmentation"] = (element[data_column] != self.pad_id).astype(np.int32)
470+
element[f"{data_column}_position"] = np.arange(element[data_column].shape[0], dtype=np.int32)
418471
for key, _ in element.items():
419-
element[key] = _pad(element[key], self.max_length, self.pad_id)
472+
if key != "images":
473+
element[key] = _pad(element[key], self.max_length, self.pad_id)
420474
return element
421475

422476

0 commit comments

Comments
 (0)