We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3c450ef · commit 10b02e0 — Copy full SHA for 10b02e0
torchtune/models/llama3_2_vision/_convert_weights.py
@@ -148,8 +148,10 @@ def llama3_vision_tune_to_meta(
148
149
# Calculate fusion_interval: layer interval where cross attention layers are fused
150
num_layers = max(_layer_num(k) for k in state_dict if "layers" in k) + 1
151
- num_fusion_layers = (
152
- max(_layer_num(k) for k in state_dict if "cross_attention_layers" in k) + 1
+ # Get the number of unique fusion layers.
+ # Keys have the form decoder.fusion_layer.i. ... where i is the layer number
153
+ num_fusion_layers = len(
154
+ set([k.split(".")[2] for k in state_dict if "fusion_layer" in k])
155
)
156
assert (
157
num_layers % num_fusion_layers == 0
0 commit comments