10
10
from torch .cuda .amp import autocast
11
11
import torch .distributed .nn
12
12
import torch .distributed as dist
13
+ import torch .nn .functional as F
13
14
14
15
from cn_clip .clip .model import convert_state_dict
15
16
16
17
17
18
def is_master (args ):
18
19
return args .rank == 0
19
20
20
- def get_loss (model , images , texts , loss_img , loss_txt , args , accum_image_features = None , accum_text_features = None , accum_idx = - 1 ):
21
+ def get_loss (model , images , texts , loss_img , loss_txt , args , accum_image_features = None , accum_text_features = None , accum_idx = - 1 , teacher_model = None , teacher_accum_image_features = None ):
21
22
if args .accum_freq == 1 :
22
23
image_features , text_features , logit_scale = model (images , texts , args .mask_ratio )
24
+
25
+ if args .distllation :
26
+ with torch .no_grad ():
27
+ # different teacher model has different output
28
+ output = teacher_model .module .get_feature (images )
29
+ if (isinstance (output , tuple )):
30
+ teacher_image_features = output [0 ]
31
+ else :
32
+ teacher_image_features = output
23
33
else :
24
34
assert accum_image_features and accum_text_features and accum_idx != - 1
25
35
chunk_image_features , chunk_text_features , logit_scale = model (images , texts , args .mask_ratio )
36
+
37
+ if args .distllation :
38
+ with torch .no_grad ():
39
+ # different teacher model has different output
40
+ output = teacher_model .module .get_feature (images )
41
+ if (isinstance (output , tuple )):
42
+ teacher_chunk_image_features = output [0 ]
43
+ else :
44
+ teacher_chunk_image_features = output
45
+ teacher_image_features = torch .cat (
46
+ teacher_accum_image_features [:accum_idx ] + [teacher_chunk_image_features ] + teacher_accum_image_features [accum_idx + 1 :])
47
+
26
48
image_features = torch .cat (
27
49
accum_image_features [:accum_idx ] + [chunk_image_features ] + accum_image_features [accum_idx + 1 :])
28
50
text_features = torch .cat (
@@ -36,13 +58,17 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
36
58
if args .gather_with_grad :
37
59
all_image_features = torch .cat (torch .distributed .nn .all_gather (image_features ), dim = 0 )
38
60
all_text_features = torch .cat (torch .distributed .nn .all_gather (text_features ), dim = 0 )
61
+
62
+ if args .distllation :
63
+ all_teacher_image_features = torch .cat (torch .distributed .nn .all_gather (teacher_image_features ), dim = 0 )
39
64
else :
40
65
gathered_image_features = [
41
66
torch .zeros_like (image_features ) for _ in range (world_size )
42
67
]
43
68
gathered_text_features = [
44
69
torch .zeros_like (text_features ) for _ in range (world_size )
45
70
]
71
+
46
72
dist .all_gather (gathered_image_features , image_features )
47
73
dist .all_gather (gathered_text_features , text_features )
48
74
@@ -61,10 +87,25 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
61
87
logits_per_image = logit_scale * all_image_features @ all_text_features .t ()
62
88
logits_per_text = logits_per_image .t ()
63
89
90
+ if args .distllation :
91
+ gathered_teacher_image_features = [
92
+ torch .zeros_like (teacher_image_features ) for _ in range (world_size )
93
+ ]
94
+ dist .all_gather (gathered_teacher_image_features , teacher_image_features )
95
+ all_teacher_image_features = torch .cat (
96
+ [teacher_image_features ]
97
+ + gathered_teacher_image_features [:rank ]
98
+ + gathered_teacher_image_features [rank + 1 :]
99
+ )
100
+ kd_loss = cosineSimilarityLoss (all_teacher_image_features , all_image_features )
101
+
64
102
else :
65
103
logits_per_image = logit_scale * image_features @ text_features .t ()
66
104
logits_per_text = logit_scale * text_features @ image_features .t ()
67
105
106
+ if args .distllation :
107
+ kd_loss = cosineSimilarityLoss (teacher_image_features , image_features )
108
+
68
109
ground_truth = torch .arange (len (logits_per_image )).long ()
69
110
ground_truth = ground_truth .cuda (args .local_device_rank , non_blocking = True )
70
111
@@ -79,6 +120,9 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
79
120
t2i_acc = (logits_per_text .argmax (- 1 ) == ground_truth ).sum () / len (logits_per_text )
80
121
acc = {"i2t" : i2t_acc , "t2i" : t2i_acc }
81
122
123
+ if args .distllation :
124
+ total_loss += kd_loss * args .kd_loss_weight
125
+
82
126
return total_loss , acc
83
127
84
128
def freeze_vision_bn (args , model ):
@@ -89,7 +133,7 @@ def freeze_vision_bn(args, model):
89
133
if isinstance (m , nn .BatchNorm2d ):
90
134
m .eval ()
91
135
92
- def train (model , data , epoch , optimizer , scaler , scheduler , args , global_trained_steps ):
136
+ def train (model , data , epoch , optimizer , scaler , scheduler , args , global_trained_steps , teacher_model = None ):
93
137
# os.environ["WDS_EPOCH"] = str(epoch)
94
138
95
139
model .train ()
@@ -112,6 +156,8 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
112
156
113
157
if args .accum_freq > 1 :
114
158
accum_images , accum_texts , accum_image_features , accum_text_features = [], [], [], []
159
+ if args .distllation :
160
+ teacher_accum_image_features = []
115
161
116
162
end = time .time ()
117
163
epoch_trained_steps = 0
@@ -142,22 +188,36 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
142
188
# with automatic mixed precision.
143
189
if args .precision == "amp" :
144
190
with autocast ():
145
- total_loss , acc = get_loss (model , images , texts , loss_img , loss_txt , args )
191
+ if args .distllation :
192
+ total_loss , acc = get_loss (model , images , texts , loss_img , loss_txt , args , teacher_model = teacher_model )
193
+ else :
194
+ total_loss , acc = get_loss (model , images , texts , loss_img , loss_txt , args )
146
195
scaler .scale (total_loss ).backward ()
147
196
scaler .step (optimizer )
148
197
scaler .update ()
149
198
150
199
else :
151
- total_loss , acc = get_loss (model , images , texts , loss_img , loss_txt , args )
200
+ if args .distllation :
201
+ total_loss , acc = get_loss (model , images , texts , loss_img , loss_txt , args , teacher_model = teacher_model )
202
+ else :
203
+ total_loss , acc = get_loss (model , images , texts , loss_img , loss_txt , args )
152
204
total_loss .backward ()
153
205
optimizer .step ()
154
206
else :
155
207
# First, cache the features without any gradient tracking.
156
208
with torch .no_grad ():
157
209
with autocast (enabled = (args .precision == "amp" )):
158
210
chunk_image_features , chunk_text_features , _ = model (images , texts )
211
+ if args .distllation :
212
+ output = teacher_model .module .get_feature (images )
213
+ if (len (output ) == 2 ):
214
+ teacher_chunk_image_features = output [0 ]
215
+ else :
216
+ teacher_chunk_image_features = output
159
217
accum_image_features .append (chunk_image_features )
160
218
accum_text_features .append (chunk_text_features )
219
+ if args .distllation :
220
+ teacher_accum_image_features .append (teacher_chunk_image_features )
161
221
162
222
accum_images .append (images )
163
223
accum_texts .append (texts )
@@ -177,7 +237,10 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
177
237
with autocast (enabled = (args .precision == "amp" )):
178
238
# `total_loss` and `acc` are coarsely sampled, taking only the last result in the loop.
179
239
# Although each result should be the same in theory, it will be slightly different in practice
180
- total_loss , acc = get_loss (model , images , texts , loss_img , loss_txt , args , accum_image_features , accum_text_features , j )
240
+ if args .distllation :
241
+ total_loss , acc = get_loss (model , images , texts , loss_img , loss_txt , args , accum_image_features , accum_text_features , j , teacher_model , teacher_accum_image_features )
242
+ else :
243
+ total_loss , acc = get_loss (model , images , texts , loss_img , loss_txt , args , accum_image_features , accum_text_features , j )
181
244
if args .precision == "amp" :
182
245
scaler .scale (total_loss ).backward ()
183
246
else :
@@ -192,6 +255,8 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
192
255
# reset gradient accum, if enabled
193
256
if args .accum_freq > 1 :
194
257
accum_images , accum_texts , accum_image_features , accum_text_features = [], [], [], []
258
+ if args .distllation :
259
+ teacher_accum_image_features = []
195
260
196
261
# Note: we clamp to 4.6052 = ln(100), as in the original paper.
197
262
m .logit_scale .data = torch .clamp (m .logit_scale .data , 0 , 4.6052 )
@@ -337,3 +402,18 @@ def evaluate(model, data, epoch, args, steps):
337
402
f"logit_scale: { model .module .logit_scale .data :.3f} | "
338
403
f"Valid Batch Size: { batch_size } "
339
404
)
405
+
406
+ def cosineSimilarityLoss (feature1 , feature2 ):
407
+ scale_factor_h = feature1 .shape [0 ] / feature2 .size (0 )
408
+ scale_factor_w = feature1 .shape [1 ] / feature2 .size (1 )
409
+
410
+ feature2_interpolated = F .interpolate (feature2 .unsqueeze (0 ).unsqueeze (0 ),
411
+ size = (feature1 .shape [0 ], feature1 .shape [1 ]),
412
+ mode = 'bilinear' ,
413
+ align_corners = False )
414
+ feature2_interpolated = feature2_interpolated .squeeze (0 ).squeeze (0 )
415
+
416
+
417
+ cosine_sim = F .cosine_similarity (feature1 , feature2_interpolated , dim = 1 )
418
+ similarity_loss = 1 - cosine_sim .mean ()
419
+ return similarity_loss
0 commit comments