32
32
from MaxText import multihost_dataloading
33
33
34
34
35
def vision_sft_preprocessing_pipeline(
    dataset,
    config,
    dataloading_host_index,
    dataloading_host_count,
    global_mesh,
    text_columns,
    image_column,
    global_batch_size,
):
  """Build a multi-host data iterator for multimodal SFT from a HF dataset.

  Args:
    dataset: HuggingFace dataset containing the text and image columns.
    config: config object; fields read here include enable_data_shuffling,
      data_shuffle_seed, image_placeholder, model_name, tokenizer_path,
      hf_access_token, and max_target_length.
    dataloading_host_index: index of this host among the data-loading hosts.
    dataloading_host_count: total number of data-loading hosts.
    global_mesh: global device mesh the batches are laid out over.
    text_columns: exactly two column names, [query_column, response_column].
    image_column: name of the dataset column holding the images.
    global_batch_size: batch size across all hosts; divided by
      jax.process_count() to get the per-host batch size.

  Returns:
    A multihost_dataloading.MultiHostDataLoadIterator yielding global
    jax.Array batches.
  """
  # The prompt-masking operation below needs exactly a query and a response column.
  assert len(text_columns) == 2, f"Need two text_columns for query and response, received {text_columns=}"

  if config.enable_data_shuffling:
    dataset = dataset.shuffle(seed=config.data_shuffle_seed)

  # Drop everything the pipeline does not consume.
  dataset = dataset.select_columns(text_columns + [image_column])
  dataset = dataset.map(
      _input_pipeline_utils.reformat_prompt,
      fn_kwargs={"column": text_columns[0], "image_placeholder": config.image_placeholder, "model_name": config.model_name},
  )
  dataset = dataset.map(
      _input_pipeline_utils.reformat_response,
      fn_kwargs={"column": text_columns[1], "model_name": config.model_name},
  )
  # Downstream stages address the image data by the canonical "images" name.
  if image_column != "images":
    dataset = dataset.rename_column(image_column, "images")

  dataset = dataset.map(
      _input_pipeline_utils.pre_process_image_sft,
      fn_kwargs={"image_column": "images", "model_name": config.model_name},
  )

  tokenizer = transformers.AutoTokenizer.from_pretrained(
      config.tokenizer_path,
      add_bos_token=False,
      add_eos_token=False,
      legacy=False,
      token=config.hf_access_token,
  )
  # Padding id preference: pad token, else unk token, else -1 when the
  # tokenizer defines neither.
  if tokenizer.pad_token_id is not None:
    pad_id = tokenizer.pad_token_id
  elif tokenizer.unk_token_id is not None:
    pad_id = tokenizer.unk_token_id
  else:
    pad_id = -1

  dataset = dataset.map(
      _input_pipeline_utils.tokenization,
      batched=True,
      fn_kwargs={
          "hf_tokenizer": tokenizer,
          "truncation": False,
          "max_length": config.max_target_length,
          "column_names": text_columns,
      },
  )
  dataset = dataset.map(
      _input_pipeline_utils.prepare_text_for_image_fusion,
      fn_kwargs={"column_name": text_columns[0], "model_name": config.model_name},
  )

  dataset = _input_pipeline_utils.HFDataSource(
      dataset=dataset,
      dataloading_host_index=dataloading_host_index,
      dataloading_host_count=dataloading_host_count,
      num_threads=1,
      generate_padding_example=True,
      max_target_length=config.max_target_length,
      data_column_names=text_columns,
  )
  operations = []
  operations.append(
      _input_pipeline_utils.SFTPromptMaskingVision(
          query_column=text_columns[0],
          response_column=text_columns[1],
          max_target_length=config.max_target_length,
          unk_id=pad_id,
      )
  )
  # TODO(aireenmei, hengtaoguo): support packing
  operations.append(_input_pipeline_utils.PadToMaxLength(config.max_target_length, pad_id))
  operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=True))
  operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1))
  # NOTE(review): HFDataSource already receives the host index/count, so this
  # sampler appears to exist only to drive the DataLoader sequentially —
  # confirm that per-host sharding is done by the data source.
  dummy_index_sampler = grain.IndexSampler(
      num_records=len(dataset),
      num_epochs=1,
      shard_options=grain.ShardOptions(
          shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False
      ),
      shuffle=False,
      seed=0,
  )

  dataloader = grain.DataLoader(
      data_source=dataset,
      operations=operations,
      sampler=dummy_index_sampler,
      worker_count=1,  # only supports <=1 for now, more workers results in duplicated data
      worker_buffer_size=1,
      read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=128),
  )

  multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh)

  # Return multi-host jax.Array prep iterator
  return multihost_gen
143
+
144
+
35
145
def preprocessing_pipeline (
36
146
dataloading_host_index ,
37
147
dataloading_host_count ,
@@ -212,27 +322,39 @@ def make_hf_train_iterator(
212
322
streaming = True ,
213
323
token = config .hf_access_token ,
214
324
)
215
- train_iter = preprocessing_pipeline (
216
- dataloading_host_index = process_indices_train .index (jax .process_index ()),
217
- dataloading_host_count = len (process_indices_train ),
218
- global_mesh = global_mesh ,
219
- dataset = train_ds ,
220
- data_column_names = config .train_data_columns ,
221
- tokenize = config .tokenize_train_data ,
222
- tokenizer_path = config .tokenizer_path ,
223
- hf_access_token = config .hf_access_token ,
224
- global_batch_size = config .global_batch_size_to_load ,
225
- max_target_length = config .max_target_length ,
226
- shuffle = config .enable_data_shuffling ,
227
- data_shuffle_seed = config .data_shuffle_seed ,
228
- add_bos = config .add_bos ,
229
- add_eos = config .add_eos ,
230
- packing = config .packing ,
231
- generate_padding_example = False ,
232
- use_dpo = config .use_dpo ,
233
- use_sft = config .use_sft ,
234
- sft_train_on_completion_only = config .sft_train_on_completion_only ,
235
- )
325
+ if config .use_sft and config .use_multimodal :
326
+ train_iter = vision_sft_preprocessing_pipeline (
327
+ dataset = train_ds ,
328
+ config = config ,
329
+ dataloading_host_index = process_indices_train .index (jax .process_index ()),
330
+ dataloading_host_count = len (process_indices_train ),
331
+ global_mesh = global_mesh ,
332
+ text_columns = config .train_data_columns ,
333
+ image_column = config .train_image_column ,
334
+ global_batch_size = config .global_batch_size_to_load ,
335
+ )
336
+ else :
337
+ train_iter = preprocessing_pipeline (
338
+ dataloading_host_index = process_indices_train .index (jax .process_index ()),
339
+ dataloading_host_count = len (process_indices_train ),
340
+ global_mesh = global_mesh ,
341
+ dataset = train_ds ,
342
+ data_column_names = config .train_data_columns ,
343
+ tokenize = config .tokenize_train_data ,
344
+ tokenizer_path = config .tokenizer_path ,
345
+ hf_access_token = config .hf_access_token ,
346
+ global_batch_size = config .global_batch_size_to_load ,
347
+ max_target_length = config .max_target_length ,
348
+ shuffle = config .enable_data_shuffling ,
349
+ data_shuffle_seed = config .data_shuffle_seed ,
350
+ add_bos = config .add_bos ,
351
+ add_eos = config .add_eos ,
352
+ packing = config .packing ,
353
+ generate_padding_example = False ,
354
+ use_dpo = config .use_dpo ,
355
+ use_sft = config .use_sft ,
356
+ sft_train_on_completion_only = config .sft_train_on_completion_only ,
357
+ )
236
358
return train_iter
237
359
238
360
@@ -252,25 +374,37 @@ def make_hf_eval_iterator(
252
374
)
253
375
254
376
eval_generate_padding_example = config .eval_steps > 0
255
- eval_iter = preprocessing_pipeline (
256
- dataloading_host_index = process_indices_eval .index (jax .process_index ()),
257
- dataloading_host_count = len (process_indices_eval ),
258
- global_mesh = global_mesh ,
259
- dataset = eval_ds ,
260
- data_column_names = config .eval_data_columns ,
261
- tokenize = config .tokenize_eval_data ,
262
- tokenizer_path = config .tokenizer_path ,
263
- hf_access_token = config .hf_access_token ,
264
- global_batch_size = config .global_batch_size_to_load_eval ,
265
- max_target_length = config .max_target_length ,
266
- shuffle = False ,
267
- data_shuffle_seed = config .data_shuffle_seed ,
268
- add_bos = config .add_bos ,
269
- add_eos = config .add_eos ,
270
- packing = config .packing ,
271
- generate_padding_example = eval_generate_padding_example ,
272
- use_dpo = config .use_dpo ,
273
- use_sft = config .use_sft ,
274
- sft_train_on_completion_only = config .sft_train_on_completion_only ,
275
- )
377
+ if config .use_sft and config .use_multimodal :
378
+ eval_iter = vision_sft_preprocessing_pipeline (
379
+ dataset = eval_ds ,
380
+ config = config ,
381
+ dataloading_host_index = process_indices_eval .index (jax .process_index ()),
382
+ dataloading_host_count = len (process_indices_eval ),
383
+ global_mesh = global_mesh ,
384
+ text_columns = config .eval_data_columns ,
385
+ image_column = config .eval_image_column ,
386
+ global_batch_size = config .global_batch_size_to_load_eval ,
387
+ )
388
+ else :
389
+ eval_iter = preprocessing_pipeline (
390
+ dataloading_host_index = process_indices_eval .index (jax .process_index ()),
391
+ dataloading_host_count = len (process_indices_eval ),
392
+ global_mesh = global_mesh ,
393
+ dataset = eval_ds ,
394
+ data_column_names = config .eval_data_columns ,
395
+ tokenize = config .tokenize_eval_data ,
396
+ tokenizer_path = config .tokenizer_path ,
397
+ hf_access_token = config .hf_access_token ,
398
+ global_batch_size = config .global_batch_size_to_load_eval ,
399
+ max_target_length = config .max_target_length ,
400
+ shuffle = False ,
401
+ data_shuffle_seed = config .data_shuffle_seed ,
402
+ add_bos = config .add_bos ,
403
+ add_eos = config .add_eos ,
404
+ packing = config .packing ,
405
+ generate_padding_example = eval_generate_padding_example ,
406
+ use_dpo = config .use_dpo ,
407
+ use_sft = config .use_sft ,
408
+ sft_train_on_completion_only = config .sft_train_on_completion_only ,
409
+ )
276
410
return eval_iter
0 commit comments