Commit 1e618b0

Merge pull request #26 from zwkkk/master
Add FLIP training support (添加FLIP训练功能)
2 parents 1924b1b + 0b6d4c2

File tree

6 files changed: +203 / -9 lines changed


cn_clip/clip/model.py

Lines changed: 25 additions & 7 deletions
@@ -240,12 +240,30 @@ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: i
     def set_grad_checkpointing(self, enable=True):
         self.transformer.grad_checkpointing = enable

-    def forward(self, x: torch.Tensor):
+    def random_masking(self, x, mask_ratio):
+        N, L, D = x.shape  # batch, length, dim
+        len_keep = int((L - 1) * (1 - mask_ratio))
+
+        noise = torch.rand(N, L - 1, device=x.device)
+        ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(N, L - 1, device=x.device,
+                                                               dtype=int)
+        ids_keep = ids_shuffle[:, :len_keep]
+
+        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+        x0 = x[:, 0, :]
+        x0 = x0.reshape(N, 1, D)
+        x_masked_add = torch.cat([x0, x_masked], axis=1)
+        return x_masked_add
+
+    def forward(self, x: torch.Tensor, mask_ratio: float = 0.0):
         x = self.conv1(x)  # shape = [*, width, grid, grid]
         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
         x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
         x = x + self.positional_embedding.to(x.dtype)
+        if mask_ratio != 0:
+            x = self.random_masking(x, mask_ratio)
         x = self.ln_pre(x)

         x = x.permute(1, 0, 2)  # NLD -> LND

@@ -282,7 +300,7 @@ def __init__(self,
                  text_type_vocab_size: int,
                  tokenizer = _tokenizer,
                  # vision head width, added this param for ViT-H
-                 vision_head_width: int = 64,
+                 vision_head_width: int = 64,
                  ):
         super().__init__()

@@ -357,23 +375,23 @@ def set_grad_checkpointing(self, enable=True):
     def dtype(self):
         return self.visual.conv1.weight.dtype

-    def encode_image(self, image):
-        return self.visual(image.type(self.dtype))
+    def encode_image(self, image, mask_ratio=0):
+        return self.visual(image.type(self.dtype), mask_ratio)

     def encode_text(self, text):
         pad_index = self.tokenizer.vocab['[PAD]']
         attn_mask = text.ne(pad_index).type(self.dtype)
         x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype)  # [batch_size, seq_length, hidden_size]
         return x[:, 0, :] @ self.text_projection

-    def forward(self, image, text):
+    def forward(self, image, text, mask_ratio=0):
         assert image is not None or text is not None, "text and image cannot both be None!"

         if image is None:
             return self.encode_text(text)
         elif text is None:
             return self.encode_image(image)
-        image_features = self.encode_image(image)
+        image_features = self.encode_image(image, mask_ratio)
         text_features = self.encode_text(text)

         image_features = image_features / image_features.norm(dim=-1, keepdim=True)

@@ -500,4 +518,4 @@ def parse(x):
 to_2tuple = _ntuple(2)
 to_3tuple = _ntuple(3)
 to_4tuple = _ntuple(4)
-to_ntuple = lambda n, x: _ntuple(n)(x)
+to_ntuple = lambda n, x: _ntuple(n)(x)
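
For orientation, here is a minimal standalone sketch (not part of the commit) of what the new random_masking does to the visual token sequence, assuming ViT-B/16 defaults (a 224×224 input with 16×16 patches gives 196 patch tokens plus one [CLS] token); the batch size and tensor values are illustrative:

import torch

# Illustrative shapes: batch of 2, 1 [CLS] token + 196 patch tokens, width 768.
N, L, D = 2, 197, 768
mask_ratio = 0.5
x = torch.randn(N, L, D)

# Same recipe as random_masking above: keep int((L - 1) * (1 - mask_ratio)) randomly
# chosen patch tokens and always re-attach the [CLS] token at the front.
len_keep = int((L - 1) * (1 - mask_ratio))  # 98 patch tokens survive
noise = torch.rand(N, L - 1)
ids_keep = (torch.argsort(noise, dim=1) + 1)[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
x_out = torch.cat([x[:, :1, :], x_masked], dim=1)
print(x_out.shape)  # torch.Size([2, 99, 768]): only 99 of 197 tokens enter the transformer

With mask_ratio=0.5 the image encoder therefore processes roughly half the tokens per image during training, which is where FLIP-style training gets its speedup.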

cn_clip/training/main.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def main():
             model_info['vision_layers'] = eval(model_info['vision_layers'])
         for k, v in json.load(ft).items():
             model_info[k] = v
-
+
     model = CLIP(**model_info)
     if args.clip_weight_path is not None:
         assert os.path.exists(args.clip_weight_path), "Pretrained CLIP weight not exists!"

cn_clip/training/params.py

Lines changed: 6 additions & 0 deletions
@@ -129,6 +129,12 @@ def parse_args():
         default="ViT-B-16",
         help="Name of the vision backbone to use.",
     )
+    parser.add_argument(
+        "--mask_ratio",
+        default=0,
+        type=float,
+        help="mask ratio of patches.",
+    )
     parser.add_argument(
         "--clip-weight-path",
         default=None,
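
A small sketch (not part of the commit) of how the new flag behaves once parsed: it is read as a float and defaults to 0, so FLIP masking stays disabled unless --mask_ratio is passed explicitly:

import argparse

# Mirrors the argument added in params.py above.
parser = argparse.ArgumentParser()
parser.add_argument("--mask_ratio", default=0, type=float, help="mask ratio of patches.")

print(parser.parse_args([]).mask_ratio)                       # 0 -> FLIP disabled
print(parser.parse_args(["--mask_ratio", "0.5"]).mask_ratio)  # 0.5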

cn_clip/training/train.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ def is_master(args):
     return args.rank == 0

 def get_loss(model, images, texts, loss_img, loss_txt, args):
-    image_features, text_features, logit_scale = model(images, texts)
+    image_features, text_features, logit_scale = model(images, texts, args.mask_ratio)
     logit_scale = logit_scale.mean()
     if args.aggregate:
         world_size = dist.get_world_size()
New file: FLIP fine-tuning run script for Flickr30k-CN

Lines changed: 85 additions & 0 deletions
#!/usr/bin/env bash

# Guide:
# This script supports distributed training on multi-GPU workers (as well as single-worker training).
# Please set the options below according to the comments.
# For multi-GPU worker training, these options should be set manually on each worker.
# After setting the options, please run the script on each worker.
# Command: bash run_scripts/muge_finetune_vit-b-16_rbt-base.sh ${DATAPATH}

# Number of GPUs per GPU worker
GPUS_PER_NODE=8
# Number of GPU workers; for single-worker training, please set to 1
WORKER_CNT=1
# The IP address of the rank-0 worker; for single-worker training, please set to localhost
export MASTER_ADDR=XX.XX.XX.XX
# The port for communication
export MASTER_PORT=8514
# The rank of this worker, should be in {0, ..., WORKER_CNT-1}; for single-worker training, please set to 0
export RANK=0

export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip/

DATAPATH=${1}

# data options
train_data=${DATAPATH}/datasets/Flickr30k-CN/lmdb/train
val_data=${DATAPATH}/datasets/Flickr30k-CN/lmdb/valid # if val_data is not specified, validation is automatically disabled

# restore options
resume=${DATAPATH}/pretrained_weights/clip_cn_vit-b-16.pt # or specify your custom ckpt path to resume
reset_data_offset="--reset-data-offset"
reset_optimizer="--reset-optimizer"
# reset_optimizer=""

# output options
output_base_dir=${DATAPATH}/experiments/
name=flickr30k_finetune_vit-b-16_roberta-base_bs128_8gpu
save_step_frequency=999999 # disable it
save_epoch_frequency=1
log_interval=1
report_training_batch_acc="--report-training-batch-acc"
# report_training_batch_acc=""

# training hyper-params
context_length=52
warmup=100
batch_size=128
valid_batch_size=128
lr=5e-5
wd=0.001
max_epochs=3
valid_step_interval=150
valid_epoch_interval=1
vision_model=ViT-B-16
text_model=RoBERTa-wwm-ext-base-chinese
mask_ratio=0.5 # FLIP: ratio of image patches to mask during training
use_augment="--use-augment"
# use_augment=""

python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
        --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} cn_clip/training/main.py \
        --train-data=${train_data} \
        --val-data=${val_data} \
        --resume=${resume} \
        ${reset_data_offset} \
        ${reset_optimizer} \
        --logs=${output_base_dir} \
        --name=${name} \
        --save-step-frequency=${save_step_frequency} \
        --save-epoch-frequency=${save_epoch_frequency} \
        --log-interval=${log_interval} \
        ${report_training_batch_acc} \
        --context-length=${context_length} \
        --warmup=${warmup} \
        --batch-size=${batch_size} \
        --valid-batch-size=${valid_batch_size} \
        --valid-step-interval=${valid_step_interval} \
        --valid-epoch-interval=${valid_epoch_interval} \
        --lr=${lr} \
        --wd=${wd} \
        --max-epochs=${max_epochs} \
        --vision-model=${vision_model} \
        --mask_ratio=${mask_ratio} \
        ${use_augment} \
        --text-model=${text_model}
New file: FLIP fine-tuning run script for MUGE

Lines changed: 85 additions & 0 deletions
#!/usr/bin/env bash

# Guide:
# This script supports distributed training on multi-GPU workers (as well as single-worker training).
# Please set the options below according to the comments.
# For multi-GPU worker training, these options should be set manually on each worker.
# After setting the options, please run the script on each worker.
# Command: bash run_scripts/muge_finetune_vit-b-16_rbt-base.sh ${DATAPATH}

# Number of GPUs per GPU worker
GPUS_PER_NODE=8
# Number of GPU workers; for single-worker training, please set to 1
WORKER_CNT=1
# The IP address of the rank-0 worker; for single-worker training, please set to localhost
export MASTER_ADDR=XX.XX.XX.XX
# The port for communication
export MASTER_PORT=8514
# The rank of this worker, should be in {0, ..., WORKER_CNT-1}; for single-worker training, please set to 0
export RANK=0

export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip/

DATAPATH=${1}

# data options
train_data=${DATAPATH}/datasets/MUGE/lmdb/train
val_data=${DATAPATH}/datasets/MUGE/lmdb/valid # if val_data is not specified, validation is automatically disabled

# restore options
resume=${DATAPATH}/pretrained_weights/clip_cn_vit-b-16.pt # or specify your custom ckpt path to resume
reset_data_offset="--reset-data-offset"
reset_optimizer="--reset-optimizer"
# reset_optimizer=""

# output options
output_base_dir=${DATAPATH}/experiments/
name=muge_finetune_vit-b-16_roberta-base_bs128_8gpu
save_step_frequency=999999 # disable it
save_epoch_frequency=1
log_interval=1
report_training_batch_acc="--report-training-batch-acc"
# report_training_batch_acc=""

# training hyper-params
context_length=52
warmup=100
batch_size=128
valid_batch_size=128
lr=5e-5
wd=0.001
max_epochs=3
valid_step_interval=150
valid_epoch_interval=1
vision_model=ViT-B-16
text_model=RoBERTa-wwm-ext-base-chinese
mask_ratio=0.5 # FLIP: ratio of image patches to mask during training
use_augment="--use-augment"
# use_augment=""

python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
        --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} cn_clip/training/main.py \
        --train-data=${train_data} \
        --val-data=${val_data} \
        --resume=${resume} \
        ${reset_data_offset} \
        ${reset_optimizer} \
        --logs=${output_base_dir} \
        --name=${name} \
        --save-step-frequency=${save_step_frequency} \
        --save-epoch-frequency=${save_epoch_frequency} \
        --log-interval=${log_interval} \
        ${report_training_batch_acc} \
        --context-length=${context_length} \
        --warmup=${warmup} \
        --batch-size=${batch_size} \
        --valid-batch-size=${valid_batch_size} \
        --valid-step-interval=${valid_step_interval} \
        --valid-epoch-interval=${valid_epoch_interval} \
        --lr=${lr} \
        --wd=${wd} \
        --max-epochs=${max_epochs} \
        --vision-model=${vision_model} \
        --mask_ratio=${mask_ratio} \
        ${use_augment} \
        --text-model=${text_model}
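
For a quick single-worker smoke test without editing the full scripts, here is a condensed sketch of the same launch command (the data path and run name are illustrative placeholders; every flag used here appears in the scripts above, and flags omitted here are assumed to fall back to their defaults):

export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip/
DATAPATH=/path/to/data  # placeholder

python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --node_rank=0 \
        --master_addr=localhost --master_port=8514 cn_clip/training/main.py \
        --train-data=${DATAPATH}/datasets/MUGE/lmdb/train \
        --val-data=${DATAPATH}/datasets/MUGE/lmdb/valid \
        --resume=${DATAPATH}/pretrained_weights/clip_cn_vit-b-16.pt \
        --reset-data-offset --reset-optimizer \
        --logs=${DATAPATH}/experiments/ --name=muge_flip_test \
        --context-length=52 --warmup=100 --batch-size=128 --valid-batch-size=128 \
        --lr=5e-5 --wd=0.001 --max-epochs=3 \
        --vision-model=ViT-B-16 --text-model=RoBERTa-wwm-ext-base-chinese \
        --mask_ratio=0.5 --use-augment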
