Commit 10b02e0

Update fusion layer counting logic for Llama 3.2 weight conversion (#1722)
1 parent 3c450ef

File tree

1 file changed (+4, -2 lines)


torchtune/models/llama3_2_vision/_convert_weights.py

Lines changed: 4 additions & 2 deletions
@@ -148,8 +148,10 @@ def llama3_vision_tune_to_meta(
 
     # Calculate fusion_interval: layer interval where cross attention layers are fused
     num_layers = max(_layer_num(k) for k in state_dict if "layers" in k) + 1
-    num_fusion_layers = (
-        max(_layer_num(k) for k in state_dict if "cross_attention_layers" in k) + 1
+    # Get the number of unique fusion layers.
+    # Keys have the form decoder.fusion_layer.i. ... where i is the layer number
+    num_fusion_layers = len(
+        set([k.split(".")[2] for k in state_dict if "fusion_layer" in k])
     )
     assert (
         num_layers % num_fusion_layers == 0
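
For illustration, below is a minimal sketch of what the new counting logic does on a toy state_dict. The key names are made up for demonstration, following the decoder.fusion_layer.i. ... form described in the added comment, and are not taken from an actual torchtune checkpoint.

    # Minimal sketch of the new fusion-layer counting (toy state_dict;
    # key names are illustrative, of the form decoder.fusion_layer.i. ...).
    state_dict = {
        "decoder.fusion_layer.0.attn.q_proj.weight": None,
        "decoder.fusion_layer.0.attn.k_proj.weight": None,
        "decoder.fusion_layer.1.attn.q_proj.weight": None,
        "decoder.layers.0.attn.q_proj.weight": None,  # not a fusion layer
    }

    # Each fusion layer contributes several keys, so count the unique layer
    # indices (the third dot-separated field) rather than parsing a max index.
    num_fusion_layers = len(
        set(k.split(".")[2] for k in state_dict if "fusion_layer" in k)
    )
    print(num_fusion_layers)  # -> 2

Counting distinct indices with len(set(...)) only requires that each fusion layer appear under at least one key; it does not assume a particular key substring ("cross_attention_layers" vs "fusion_layer") maps onto densely numbered layers, which is what the replaced max-based expression relied on.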
