Commit 846e630

Merge pull request #195 from Jaskr616/master
Add distillation for fine-tuning
2 parents 2c38d03 + 1fd71f7 commit 846e630

File tree

9 files changed: +359 -6 lines changed


cn_clip/training/main.py

Lines changed: 41 additions & 1 deletion
@@ -243,10 +243,50 @@ def main():
     # only do so if it is the 0th worker.
     args.should_save = (args.logs is not None and args.logs != '' and args.logs.lower() != 'none') and is_master(args)
 
+    # load teacher model for distillation
+    if args.distllation:
+        try:
+            from modelscope.models import Model
+        except ImportError:
+            raise ImportError("modelscope is not installed. Please install it by `pip install modelscope`.")
+
+        teacher_model_dict = {
+            "damo/multi-modal_team-vit-large-patch14_multi-modal-similarity": {"model": "image_model"},
+            "damo/multi-modal_rleg-vit-large-patch14": {"model": "encode_image"},
+            "damo/multi-modal_clip-vit-huge-patch14_zh": {"clip_model": "encode_image"},
+            "damo/multi-modal_clip-vit-large-patch14_zh": {"clip_model": "encode_image"},
+        }
+        assert args.teacher_model_name in teacher_model_dict, "Error: Valid teacher model name has not been built."
+
+        teacher_model = Model.from_pretrained(args.teacher_model_name)
+        for k, v in teacher_model.state_dict().items():
+            v.requires_grad = False
+
+        # map the teacher-specific feature-extraction method to the common name `get_feature`
+        mapping = teacher_model_dict[args.teacher_model_name]
+        if "model" in mapping and hasattr(teacher_model, "model"):
+            model_instance = getattr(teacher_model, "model")
+            if hasattr(model_instance, mapping["model"]):
+                setattr(teacher_model, "get_feature", getattr(model_instance, mapping["model"]))
+        elif "clip_model" in mapping and hasattr(teacher_model, "clip_model"):
+            model_instance = getattr(teacher_model, "clip_model")
+            if hasattr(model_instance, mapping["clip_model"]):
+                setattr(teacher_model, "get_feature", getattr(model_instance, mapping["clip_model"]))
+
+        teacher_model.cuda(args.local_device_rank)
+        teacher_model = torch.nn.parallel.DistributedDataParallel(teacher_model, device_ids=[args.local_device_rank])
+        logging.info(f"Teacher model loaded from {args.teacher_model_name}")
+    else:
+        teacher_model = None
+
+
     for epoch in range(start_epoch, args.max_epochs):
         if is_master(args) == 0:
             logging.info(f'Start epoch {epoch + 1}')
-        num_steps_this_epoch = train(model, data, epoch, optimizer, scaler, scheduler, args, steps)
+        if args.distllation:
+            num_steps_this_epoch = train(model, data, epoch, optimizer, scaler, scheduler, args, steps, teacher_model)
+        else:
+            num_steps_this_epoch = train(model, data, epoch, optimizer, scaler, scheduler, args, steps)
         steps += num_steps_this_epoch
 
         if args.val_data is not None and args.valid_epoch_interval is not None and ((epoch + 1) % args.valid_epoch_interval) == 0:
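For reference, here is a minimal standalone sketch of what the block above does for one teacher (not part of the commit; it assumes `modelscope` is installed, that the model weights can be downloaded, and the dummy input shape is merely illustrative — whether an encoder accepts a raw pixel tensor directly is teacher-specific):

```python
import torch
from modelscope.models import Model

# The TEAM teacher exposes its image encoder as `model.image_model`;
# the mapping dict above aliases it to the common name `get_feature`.
teacher_name = "damo/multi-modal_team-vit-large-patch14_multi-modal-similarity"
teacher_model = Model.from_pretrained(teacher_name)
teacher_model.get_feature = teacher_model.model.image_model
teacher_model.eval()

with torch.no_grad():
    images = torch.randn(4, 3, 224, 224)  # a dummy batch of preprocessed images
    output = teacher_model.get_feature(images)
    # some teachers return a tuple, in which case the image features come first
    teacher_image_features = output[0] if isinstance(output, tuple) else output
```

During training the teacher is additionally moved to the local GPU and wrapped in `DistributedDataParallel`, so the feature call in `train.py` becomes `teacher_model.module.get_feature(images)`.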

cn_clip/training/params.py

Lines changed: 19 additions & 0 deletions
@@ -205,6 +205,25 @@ def parse_args():
         default=123,
         help="Random seed."
     )
+    # arguments for distillation
+    parser.add_argument(
+        "--distllation",
+        default=False,
+        action="store_true",
+        help="If true, use knowledge distillation to fine-tune the image tower."
+    )
+    parser.add_argument(
+        "--teacher-model-name",
+        type=str,
+        default=None,
+        help="The name of the teacher model."
+    )
+    parser.add_argument(
+        "--kd_loss_weight",
+        type=float,
+        default=0.5,
+        help="Weight of the KD loss."
+    )
     args = parser.parse_args()
     args.aggregate = not args.skip_aggregate
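As a quick illustration (not part of the commit), the three new flags parse as shown below, with `--kd_loss_weight` falling back to 0.5 when omitted:

```python
import argparse

parser = argparse.ArgumentParser()
# mirrors the three options added above
parser.add_argument("--distllation", default=False, action="store_true",
                    help="If true, use knowledge distillation to fine-tune the image tower.")
parser.add_argument("--teacher-model-name", type=str, default=None,
                    help="The name of the teacher model.")
parser.add_argument("--kd_loss_weight", type=float, default=0.5,
                    help="Weight of the KD loss.")

# e.g. the values a fine-tuning script might pass
args = parser.parse_args([
    "--distllation",
    "--teacher-model-name", "damo/multi-modal_team-vit-large-patch14_multi-modal-similarity",
])
print(args.distllation, args.teacher_model_name, args.kd_loss_weight)
# True damo/multi-modal_team-vit-large-patch14_multi-modal-similarity 0.5
```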

cn_clip/training/train.py

Lines changed: 85 additions & 5 deletions
@@ -10,19 +10,41 @@
 from torch.cuda.amp import autocast
 import torch.distributed.nn
 import torch.distributed as dist
+import torch.nn.functional as F
 
 from cn_clip.clip.model import convert_state_dict
 
 
 def is_master(args):
     return args.rank == 0
 
-def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features=None, accum_text_features=None, accum_idx=-1):
+def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features=None, accum_text_features=None, accum_idx=-1, teacher_model=None, teacher_accum_image_features=None):
     if args.accum_freq == 1:
         image_features, text_features, logit_scale = model(images, texts, args.mask_ratio)
+
+        if args.distllation:
+            with torch.no_grad():
+                # different teacher models return differently shaped outputs
+                output = teacher_model.module.get_feature(images)
+                if isinstance(output, tuple):
+                    teacher_image_features = output[0]
+                else:
+                    teacher_image_features = output
     else:
         assert accum_image_features and accum_text_features and accum_idx != -1
         chunk_image_features, chunk_text_features, logit_scale = model(images, texts, args.mask_ratio)
+
+        if args.distllation:
+            with torch.no_grad():
+                # different teacher models return differently shaped outputs
+                output = teacher_model.module.get_feature(images)
+                if isinstance(output, tuple):
+                    teacher_chunk_image_features = output[0]
+                else:
+                    teacher_chunk_image_features = output
+            teacher_image_features = torch.cat(
+                teacher_accum_image_features[:accum_idx] + [teacher_chunk_image_features] + teacher_accum_image_features[accum_idx + 1:])
+
         image_features = torch.cat(
             accum_image_features[:accum_idx] + [chunk_image_features] + accum_image_features[accum_idx + 1:])
         text_features = torch.cat(
@@ -36,13 +58,17 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
         if args.gather_with_grad:
             all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
             all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+
+            if args.distllation:
+                all_teacher_image_features = torch.cat(torch.distributed.nn.all_gather(teacher_image_features), dim=0)
         else:
             gathered_image_features = [
                 torch.zeros_like(image_features) for _ in range(world_size)
             ]
             gathered_text_features = [
                 torch.zeros_like(text_features) for _ in range(world_size)
             ]
+
             dist.all_gather(gathered_image_features, image_features)
             dist.all_gather(gathered_text_features, text_features)
 
@@ -61,10 +87,25 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
         logits_per_image = logit_scale * all_image_features @ all_text_features.t()
         logits_per_text = logits_per_image.t()
 
+        if args.distllation:
+            gathered_teacher_image_features = [
+                torch.zeros_like(teacher_image_features) for _ in range(world_size)
+            ]
+            dist.all_gather(gathered_teacher_image_features, teacher_image_features)
+            all_teacher_image_features = torch.cat(
+                [teacher_image_features]
+                + gathered_teacher_image_features[:rank]
+                + gathered_teacher_image_features[rank + 1:]
+            )
+            kd_loss = cosineSimilarityLoss(all_teacher_image_features, all_image_features)
+
     else:
         logits_per_image = logit_scale * image_features @ text_features.t()
         logits_per_text = logit_scale * text_features @ image_features.t()
 
+        if args.distllation:
+            kd_loss = cosineSimilarityLoss(teacher_image_features, image_features)
+
     ground_truth = torch.arange(len(logits_per_image)).long()
     ground_truth = ground_truth.cuda(args.local_device_rank, non_blocking=True)
 
@@ -79,6 +120,9 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
         t2i_acc = (logits_per_text.argmax(-1) == ground_truth).sum() / len(logits_per_text)
         acc = {"i2t": i2t_acc, "t2i": t2i_acc}
 
+    if args.distllation:
+        total_loss += kd_loss * args.kd_loss_weight
+
     return total_loss, acc
 
 def freeze_vision_bn(args, model):
@@ -89,7 +133,7 @@ def freeze_vision_bn(args, model):
         if isinstance(m, nn.BatchNorm2d):
             m.eval()
 
-def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained_steps):
+def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained_steps, teacher_model=None):
     # os.environ["WDS_EPOCH"] = str(epoch)
 
     model.train()
@@ -112,6 +156,8 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
 
     if args.accum_freq > 1:
         accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
+        if args.distllation:
+            teacher_accum_image_features = []
 
     end = time.time()
     epoch_trained_steps = 0
@@ -142,22 +188,36 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
             # with automatic mixed precision.
             if args.precision == "amp":
                 with autocast():
-                    total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
+                    if args.distllation:
+                        total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, teacher_model=teacher_model)
+                    else:
+                        total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
                 scaler.scale(total_loss).backward()
                 scaler.step(optimizer)
                 scaler.update()
 
            else:
-                total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
+                if args.distllation:
+                    total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, teacher_model=teacher_model)
+                else:
+                    total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
                 total_loss.backward()
                 optimizer.step()
        else:
            # First, cache the features without any gradient tracking.
            with torch.no_grad():
                with autocast(enabled=(args.precision == "amp")):
                    chunk_image_features, chunk_text_features, _ = model(images, texts)
+                    if args.distllation:
+                        output = teacher_model.module.get_feature(images)
+                        if len(output) == 2:
+                            teacher_chunk_image_features = output[0]
+                        else:
+                            teacher_chunk_image_features = output
                accum_image_features.append(chunk_image_features)
                accum_text_features.append(chunk_text_features)
+                if args.distllation:
+                    teacher_accum_image_features.append(teacher_chunk_image_features)
 
            accum_images.append(images)
            accum_texts.append(texts)
@@ -177,7 +237,10 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
                with autocast(enabled=(args.precision == "amp")):
                    # `total_loss` and `acc` are coarsely sampled, taking only the last result in the loop.
                    # Although each result should be the same in theory, it will be slightly different in practice
-                    total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features, accum_text_features, j)
+                    if args.distllation:
+                        total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features, accum_text_features, j, teacher_model, teacher_accum_image_features)
+                    else:
+                        total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features, accum_text_features, j)
                if args.precision == "amp":
                    scaler.scale(total_loss).backward()
                else:
@@ -192,6 +255,8 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
        # reset gradient accum, if enabled
        if args.accum_freq > 1:
            accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
+            if args.distllation:
+                teacher_accum_image_features = []
 
        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        m.logit_scale.data = torch.clamp(m.logit_scale.data, 0, 4.6052)
@@ -337,3 +402,18 @@ def evaluate(model, data, epoch, args, steps):
            f"logit_scale: {model.module.logit_scale.data:.3f} | "
            f"Valid Batch Size: {batch_size}"
        )
+
+def cosineSimilarityLoss(feature1, feature2):
+    scale_factor_h = feature1.shape[0] / feature2.size(0)
+    scale_factor_w = feature1.shape[1] / feature2.size(1)
+
+    feature2_interpolated = F.interpolate(feature2.unsqueeze(0).unsqueeze(0),
+                                          size=(feature1.shape[0], feature1.shape[1]),
+                                          mode='bilinear',
+                                          align_corners=False)
+    feature2_interpolated = feature2_interpolated.squeeze(0).squeeze(0)
+
+    cosine_sim = F.cosine_similarity(feature1, feature2_interpolated, dim=1)
+    similarity_loss = 1 - cosine_sim.mean()
+    return similarity_loss
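A quick usage sketch of `cosineSimilarityLoss` (not part of the commit, assuming Chinese-CLIP is installed): when the teacher and student feature widths differ, the student features are bilinearly interpolated to the teacher's shape before the row-wise cosine similarity is averaged. The dimensions below are illustrative, e.g. a ViT-L teacher distilling into a ViT-B student:

```python
import torch

from cn_clip.training.train import cosineSimilarityLoss

torch.manual_seed(0)
teacher_features = torch.randn(8, 768)  # e.g. image features from a ViT-L/14 teacher
student_features = torch.randn(8, 512)  # e.g. image features from a ViT-B/16 student

# the student features are resized to (8, 768) internally, then
# loss = 1 - mean cosine similarity, so the value lies in [0, 2]
kd_loss = cosineSimilarityLoss(teacher_features, student_features)
print(kd_loss)
```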

distillation.md

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
[**Chinese**](distillation.md) | [**English**](distillation_En.md)

# Improving Chinese-CLIP Image Retrieval with Knowledge Distillation

Chinese-CLIP can be fine-tuned together with knowledge distillation to further improve its image retrieval (image-to-image) capability. All of the teacher models come from [ModelScope](https://github.com/modelscope/modelscope).

## Requirements

+ Nvidia GPUs with the **Turing**, **Ampere**, **Ada**, or **Hopper** architecture (e.g. H100, A100, RTX 3090, T4, RTX 2080); see [this table](https://en.wikipedia.org/wiki/CUDA#GPUs_supported) for the GPU models of each Nvidia architecture.
+ CUDA 11.4 or later.
+ PyTorch 1.12 or later.
+ The other dependencies listed in [requirements.txt](requirements.txt).
+ **ModelScope**: install it with `pip install modelscope`.

## Use It in Chinese-CLIP!

Applying knowledge distillation to the image tower during Chinese-CLIP fine-tuning is straightforward. Add the `--distllation` option to the fine-tuning shell script, then set the `--teacher-model-name` option to the teacher model you want to use. The four teacher models currently supported are listed below.
<table border="1" width="120%">
    <tr align="center">
        <td><b>Teacher model</b></td><td><b>Model description</b></td>
    </tr>
    <tr align="center">
        <td>damo/multi-modal_team-vit-large-patch14_multi-modal-similarity</td><td><a href="https://www.modelscope.cn/models/damo/multi-modal_team-vit-large-patch14_multi-modal-similarity/summary">TEAM image-text retrieval model, Chinese, large</a></td>
    </tr>
    <tr align="center">
        <td>damo/multi-modal_rleg-vit-large-patch14</td><td><a href="https://www.modelscope.cn/models/damo/multi-modal_rleg-vit-large-patch14/summary">RLEG generative multi-modal representation model, English, large</a></td>
    </tr>
    <tr align="center">
        <td>damo/multi-modal_clip-vit-huge-patch14_zh</td><td><a href="https://www.modelscope.cn/models/damo/multi-modal_clip-vit-huge-patch14_zh/summary">CLIP model, Chinese, general domain, huge</a></td>
    </tr>
    <tr align="center">
        <td>damo/multi-modal_clip-vit-large-patch14_zh</td><td><a href="https://www.modelscope.cn/models/damo/multi-modal_clip-vit-large-patch14_zh/summary">CLIP model, Chinese, general domain, large</a></td>
    </tr>
</table>
<br>

Finally, set the `--kd_loss_weight` option to the weight of the distillation loss; the default value is 0.5.

The options are defined as follows (the resulting training objective is sketched right after this list):
+ `distllation`: whether to enable knowledge distillation to fine-tune the image tower.
+ `teacher-model-name`: which teacher model to use. The four teacher models above are currently supported, e.g. `damo/multi-modal_team-vit-large-patch14_multi-modal-similarity`.
+ `kd_loss_weight` (optional): weight of the distillation loss; the default value is 0.5.
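With distillation enabled, the overall objective implemented in `cn_clip/training/train.py` is the original contrastive loss plus the weighted distillation term (a sketch of the objective; `w_kd` is `--kd_loss_weight`):

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{contrastive}} + w_{kd}\left(1 - \frac{1}{N}\sum_{i=1}^{N}\cos\!\left(f^{\text{teacher}}_{i},\, f^{\text{student}}_{i}\right)\right)
$$

where the teacher and student image features of the i-th sample are compared after the student features are interpolated to the teacher's dimensionality.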
We provide a sample script, `run_scripts/muge_finetune_vit-b-16_rbt-base_distllation.sh`, which uses the TEAM image-text retrieval model (Chinese, large) as the teacher.

## Results

Below are the Top-10 image retrieval results of our model (fine-tuning + distillation), the pre-trained model, and the plain fine-tuned model. The image in the upper-left corner is the query, and the Top-1 to Top-10 retrieval results follow in order on the right. The support set for this experiment contains 100,000 e-commerce images (shoes, clothes, trousers, and other items).

Advantages of our approach:
+ It meets the basic requirements of a retrieval task: image similarity is achieved well while category similarity is preserved.
+ Good performance and fast inference: through distillation, the base model reaches retrieval quality similar to that of a large model, and when deployed on CPU the retrieval inference time stays within 100 ms.

<p style="text-align: center;">
<img src="examples/image_retrieval_result1.jpg" width="400" /><br>
<img src="examples/image_retrieval_result3.jpg" width="400" /><br>
<img src="examples/image_retrieval_result2.jpg" width="400" /><br>
</p>

## Todo
A Jupyter Notebook for this solution will be released on the Alibaba Cloud website.
