Add distillation for fine-tuning #195

Merged 18 commits on Sep 8, 2023
Changes from all commits
42 changes: 41 additions & 1 deletion cn_clip/training/main.py
@@ -243,10 +243,50 @@ def main():
    # only do so if it is the 0th worker.
    args.should_save = (args.logs is not None and args.logs != '' and args.logs.lower() != 'none') and is_master(args)

    # load the teacher model for distillation
    if args.distllation:
        try:
            from modelscope.models import Model
        except ImportError:
            raise ImportError("modelscope is not installed. Please install it by `pip install modelscope`.")

        # supported teacher models, each mapped to the attribute path of its image-feature extractor
        teacher_model_dict = {
            "damo/multi-modal_team-vit-large-patch14_multi-modal-similarity": {"model": "image_model"},
            "damo/multi-modal_rleg-vit-large-patch14": {"model": "encode_image"},
            "damo/multi-modal_clip-vit-huge-patch14_zh": {"clip_model": "encode_image"},
            "damo/multi-modal_clip-vit-large-patch14_zh": {"clip_model": "encode_image"},
        }
        assert args.teacher_model_name in teacher_model_dict, "Error: the given teacher model name is not supported."

        teacher_model = Model.from_pretrained(args.teacher_model_name)
        # freeze the teacher; state_dict() returns detached tensors, so flip
        # requires_grad on the parameters themselves
        for param in teacher_model.parameters():
            param.requires_grad = False

        # map each teacher's differently named feature-extraction function to the common name `get_feature`
        mapping = teacher_model_dict[args.teacher_model_name]
        if "model" in mapping and hasattr(teacher_model, "model"):
            model_instance = getattr(teacher_model, "model")
            if hasattr(model_instance, mapping["model"]):
                setattr(teacher_model, "get_feature", getattr(model_instance, mapping["model"]))
        elif "clip_model" in mapping and hasattr(teacher_model, "clip_model"):
            model_instance = getattr(teacher_model, "clip_model")
            if hasattr(model_instance, mapping["clip_model"]):
                setattr(teacher_model, "get_feature", getattr(model_instance, mapping["clip_model"]))

        teacher_model.cuda(args.local_device_rank)
        teacher_model = torch.nn.parallel.DistributedDataParallel(teacher_model, device_ids=[args.local_device_rank])
        logging.info(f"Teacher model loaded from {args.teacher_model_name}")
    else:
        teacher_model = None


    for epoch in range(start_epoch, args.max_epochs):
        if is_master(args):
            logging.info(f'Start epoch {epoch + 1}')
        if args.distllation:
            num_steps_this_epoch = train(model, data, epoch, optimizer, scaler, scheduler, args, steps, teacher_model)
        else:
            num_steps_this_epoch = train(model, data, epoch, optimizer, scaler, scheduler, args, steps)
        steps += num_steps_this_epoch

        if args.val_data is not None and args.valid_epoch_interval is not None and ((epoch + 1) % args.valid_epoch_interval) == 0:
19 changes: 19 additions & 0 deletions cn_clip/training/params.py
@@ -205,6 +205,25 @@ def parse_args():
        default=123,
        help="Random seed."
    )
    # arguments for distillation
    parser.add_argument(
        "--distllation",
        default=False,
        action="store_true",
        help="If true, distill the image encoder from a teacher model during fine-tuning."
    )
    parser.add_argument(
        "--teacher-model-name",
        type=str,
        default=None,
        help="The name of the teacher model."
    )
    parser.add_argument(
        "--kd_loss_weight",
        type=float,
        default=0.5,
        help="Weight of the KD loss."
    )
    args = parser.parse_args()
    args.aggregate = not args.skip_aggregate

90 changes: 85 additions & 5 deletions cn_clip/training/train.py
@@ -10,19 +10,41 @@
from torch.cuda.amp import autocast
import torch.distributed.nn
import torch.distributed as dist
import torch.nn.functional as F

from cn_clip.clip.model import convert_state_dict


def is_master(args):
    return args.rank == 0

def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features=None, accum_text_features=None, accum_idx=-1, teacher_model=None, teacher_accum_image_features=None):
    if args.accum_freq == 1:
        image_features, text_features, logit_scale = model(images, texts, args.mask_ratio)

        if args.distllation:
            with torch.no_grad():
                # different teacher models return their features in different formats
                output = teacher_model.module.get_feature(images)
                if isinstance(output, tuple):
                    teacher_image_features = output[0]
                else:
                    teacher_image_features = output
    else:
        assert accum_image_features and accum_text_features and accum_idx != -1
        chunk_image_features, chunk_text_features, logit_scale = model(images, texts, args.mask_ratio)

        if args.distllation:
            with torch.no_grad():
                # different teacher models return their features in different formats
                output = teacher_model.module.get_feature(images)
                if isinstance(output, tuple):
                    teacher_chunk_image_features = output[0]
                else:
                    teacher_chunk_image_features = output
            teacher_image_features = torch.cat(
                teacher_accum_image_features[:accum_idx] + [teacher_chunk_image_features] + teacher_accum_image_features[accum_idx + 1:])

        image_features = torch.cat(
            accum_image_features[:accum_idx] + [chunk_image_features] + accum_image_features[accum_idx + 1:])
        text_features = torch.cat(
@@ -36,13 +58,17 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
        if args.gather_with_grad:
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)

            if args.distllation:
                all_teacher_image_features = torch.cat(torch.distributed.nn.all_gather(teacher_image_features), dim=0)
        else:
            gathered_image_features = [
                torch.zeros_like(image_features) for _ in range(world_size)
            ]
            gathered_text_features = [
                torch.zeros_like(text_features) for _ in range(world_size)
            ]

            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
@@ -61,10 +87,25 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
        logits_per_image = logit_scale * all_image_features @ all_text_features.t()
        logits_per_text = logits_per_image.t()

        if args.distllation:
            # gather the teacher image features from all workers (no gradients needed)
            gathered_teacher_image_features = [
                torch.zeros_like(teacher_image_features) for _ in range(world_size)
            ]
            dist.all_gather(gathered_teacher_image_features, teacher_image_features)
            all_teacher_image_features = torch.cat(
                [teacher_image_features]
                + gathered_teacher_image_features[:rank]
                + gathered_teacher_image_features[rank + 1:]
            )
            kd_loss = cosineSimilarityLoss(all_teacher_image_features, all_image_features)

    else:
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        if args.distllation:
            kd_loss = cosineSimilarityLoss(teacher_image_features, image_features)

    ground_truth = torch.arange(len(logits_per_image)).long()
    ground_truth = ground_truth.cuda(args.local_device_rank, non_blocking=True)

@@ -79,6 +120,9 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
    t2i_acc = (logits_per_text.argmax(-1) == ground_truth).sum() / len(logits_per_text)
    acc = {"i2t": i2t_acc, "t2i": t2i_acc}

    if args.distllation:
        # add the weighted distillation loss on top of the contrastive loss
        total_loss += kd_loss * args.kd_loss_weight

    return total_loss, acc

def freeze_vision_bn(args, model):
@@ -89,7 +133,7 @@ def freeze_vision_bn(args, model):
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained_steps, teacher_model=None):
    # os.environ["WDS_EPOCH"] = str(epoch)

    model.train()
@@ -112,6 +156,8 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained

    if args.accum_freq > 1:
        accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
        if args.distllation:
            teacher_accum_image_features = []

    end = time.time()
    epoch_trained_steps = 0
@@ -142,22 +188,36 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
            # with automatic mixed precision.
            if args.precision == "amp":
                with autocast():
                    if args.distllation:
                        total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, teacher_model=teacher_model)
                    else:
                        total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
                scaler.update()

            else:
                if args.distllation:
                    total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, teacher_model=teacher_model)
                else:
                    total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
                total_loss.backward()
                optimizer.step()
        else:
            # First, cache the features without any gradient tracking.
            with torch.no_grad():
                with autocast(enabled=(args.precision == "amp")):
                    chunk_image_features, chunk_text_features, _ = model(images, texts)
                    if args.distllation:
                        # different teacher models return their features in different formats
                        output = teacher_model.module.get_feature(images)
                        if isinstance(output, tuple):
                            teacher_chunk_image_features = output[0]
                        else:
                            teacher_chunk_image_features = output
                accum_image_features.append(chunk_image_features)
                accum_text_features.append(chunk_text_features)
                if args.distllation:
                    teacher_accum_image_features.append(teacher_chunk_image_features)

            accum_images.append(images)
            accum_texts.append(texts)
@@ -177,7 +237,10 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
                with autocast(enabled=(args.precision == "amp")):
                    # `total_loss` and `acc` are coarsely sampled, taking only the last result in the loop.
                    # Although each result should be the same in theory, it will be slightly different in practice.
                    if args.distllation:
                        total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features, accum_text_features, j, teacher_model, teacher_accum_image_features)
                    else:
                        total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features, accum_text_features, j)
if args.precision == "amp":
scaler.scale(total_loss).backward()
else:
@@ -192,6 +255,8 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
        # reset gradient accum, if enabled
        if args.accum_freq > 1:
            accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
            if args.distllation:
                teacher_accum_image_features = []

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        m.logit_scale.data = torch.clamp(m.logit_scale.data, 0, 4.6052)
@@ -337,3 +402,18 @@ def evaluate(model, data, epoch, args, steps):
f"logit_scale: {model.module.logit_scale.data:.3f} | "
f"Valid Batch Size: {batch_size}"
)

def cosineSimilarityLoss(feature1, feature2):
    # interpolate feature2 to feature1's shape so that teacher and student
    # features of different dimensions can be compared
    feature2_interpolated = F.interpolate(feature2.unsqueeze(0).unsqueeze(0),
                                          size=(feature1.shape[0], feature1.shape[1]),
                                          mode='bilinear',
                                          align_corners=False)
    feature2_interpolated = feature2_interpolated.squeeze(0).squeeze(0)

    # the loss is 1 minus the mean cosine similarity across the batch
    cosine_sim = F.cosine_similarity(feature1, feature2_interpolated, dim=1)
    similarity_loss = 1 - cosine_sim.mean()
    return similarity_loss
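
For intuition, here is a minimal standalone sketch (with hypothetical feature shapes, not taken from the PR) of what `cosineSimilarityLoss` computes when the teacher and student embedding widths differ: the second argument is bilinearly resized to the first argument's shape, then the rows are compared by cosine similarity.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: 4 teacher features of width 768, 4 student features of width 512.
teacher_feats = torch.randn(4, 768)
student_feats = torch.randn(4, 512)

# Mirror cosineSimilarityLoss: resize the student features to the teacher's
# (batch, dim) shape via bilinear interpolation, then compare row-wise.
resized = F.interpolate(student_feats.unsqueeze(0).unsqueeze(0),
                        size=(teacher_feats.shape[0], teacher_feats.shape[1]),
                        mode='bilinear',
                        align_corners=False).squeeze(0).squeeze(0)
loss = 1 - F.cosine_similarity(teacher_feats, resized, dim=1).mean()
print(loss.item())  # a scalar in [0, 2]; 0 means perfectly aligned features
```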
64 changes: 64 additions & 0 deletions distillation.md
@@ -0,0 +1,64 @@
[**中文说明**](distillation.md) | [**English**](distillation_En.md)

# Improving Chinese-CLIP Image Retrieval with Knowledge Distillation

Chinese-CLIP can be fine-tuned together with knowledge distillation to further improve its image-to-image retrieval ability. All of the teacher models used come from [ModelScope](https://github.com/modelscope/modelscope).

## Requirements

+ Nvidia GPUs with the **Turing**, **Ampere**, **Ada**, or **Hopper** architecture (such as H100, A100, RTX 3090, T4, and RTX 2080); see [this table](https://en.wikipedia.org/wiki/CUDA#GPUs_supported) for the GPU models of each Nvidia architecture.
+ CUDA 11.4 or above.
+ PyTorch 1.12 or above.
+ The other dependencies required by [requirements.txt](requirements.txt).
+ **ModelScope**: install ModelScope by running `pip install modelscope`; a quick load test is sketched below.
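
As a quick sanity check that ModelScope is installed and can fetch a teacher model, the sketch below loads one of the supported teachers with the same `Model.from_pretrained` call that the training code uses (the first run downloads the weights from the ModelScope hub):

```python
# Sanity-check sketch: load one of the supported teacher models via ModelScope.
from modelscope.models import Model

teacher = Model.from_pretrained("damo/multi-modal_clip-vit-large-patch14_zh")
print(type(teacher).__name__)  # prints the ModelScope wrapper class of the teacher
```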

## Using It in Chinese-CLIP

Applying knowledge distillation to the image encoder during Chinese-CLIP fine-tuning is not complicated: add the `--distllation` option to the fine-tuning shell script,
and set the `--teacher-model-name` option to the name of the teacher model to use. The four teacher models listed below are currently supported.
<table border="1" width="120%">
<tr align="center">
<td><b>Teacher model</b></td><td><b>模型介绍</b></td>
</tr>
<tr align="center">
<td>damo/multi-modal_team-vit-large-patch14_multi-modal-similarity</td><td><a href="https://www.modelscope.cn/models/damo/multi-modal_team-vit-large-patch14_multi-modal-similarity/summary">TEAM图文检索模型-中文-large</a></td>
</tr>
<tr align="center">
<td>damo/multi-modal_rleg-vit-large-patch14</td><td><a href="https://www.modelscope.cn/models/damo/multi-modal_rleg-vit-large-patch14/summary">RLEG生成式多模态表征模型-英文-large
</a></td>
</tr>
<tr align="center">
<td>damo/multi-modal_clip-vit-huge-patch14_zh</td><td><a href="https://www.modelscope.cn/models/damo/multi-modal_clip-vit-huge-patch14_zh/summary">CLIP模型-中文-通用领域-huge</a></td>
</tr>
<tr align="center">
<td>damo/multi-modal_clip-vit-large-patch14_zh</td><td><a href="https://www.modelscope.cn/models/damo/multi-modal_clip-vit-large-patch14_zh/summary">CLIP模型-中文-通用领域-large</a></td>
</tr>
</table>
<br>

Finally, set the weight of the distillation loss with the `--kd_loss_weight` option; the default value is 0.5.


The options are defined as follows:
+ `distllation`: whether to enable knowledge distillation to fine-tune the image encoder.
+ `teacher-model-name`: the teacher model to use. The four teacher models above are currently supported; for example, `damo/multi-modal_team-vit-large-patch14_multi-modal-similarity`.
+ `kd_loss_weight` (optional): the weight of the distillation loss; the default value is 0.5. The resulting training objective is sketched right after this list.
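
Concretely, combining the pieces from `train.py`, the fine-tuning objective adds the weighted distillation term to the contrastive loss, where the distillation term is one minus the mean cosine similarity between the teacher and student image features:

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{contrastive}} + w_{\text{kd}} \cdot \left( 1 - \frac{1}{N} \sum_{i=1}^{N} \cos\left( f_i^{\text{teacher}},\ f_i^{\text{student}} \right) \right)
$$

Here $w_{\text{kd}}$ is `kd_loss_weight` and $N$ is the (gathered) batch size.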

We provide the sample script `run_scripts/muge_finetune_vit-b-16_rbt-base_distllation.sh`, which uses the TEAM image-text retrieval model (Chinese, large) as the teacher model.

## Results

Below are the top-10 image retrieval results of our model (fine-tune + distillation), the pre-trained model, and the fine-tuned model. The image in the top-left corner is the query, and the images to its right are the top-1 through top-10 retrieval results, in order. The support set for this experiment contains 100,000 e-commerce images (covering shoes, clothes, pants, and other items).

Advantages of our approach:
+ It meets the basic requirements of a retrieval task: image similarity is captured well while category consistency is preserved.
+ It is both effective and fast: through distillation, the base model achieves retrieval quality similar to that of the large model, and when deployed on CPU, retrieval inference time stays within 100 ms.

<p style="text-align: center;">
<img src="examples/image_retrieval_result1.jpg" width="400" /><br>
<img src="examples/image_retrieval_result3.jpg" width="400" /><br>
<img src="examples/image_retrieval_result2.jpg" width="400" /><br>
</p>


## Todo

A Jupyter Notebook for the related solution will be released on the Alibaba Cloud website.