@@ -228,6 +228,7 @@ def get_TFLOPS_per_gpu(self, wbits, abits) -> float:
             assert (higher_bits == 16
                     ), "weight_bits and activation_bits must be 4, 8, or 16"
             gemm_TFOPS = self.gpu_config.peak_fp16_TFLOPS
+        print('XXXXX', self.gpu_config)
         return gemm_TFOPS * self.flops_efficiency
 
     def get_pivot(self, wbits, abits) -> float:
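For context, the assert above guards the fp16 path of the per-GPU TFLOPS lookup. Below is a minimal, self-contained sketch of that selection logic, assuming a hypothetical GPUConfig stand-in with peak_i4/i8/fp16_TFLOPS fields and a flops_efficiency factor; the int4/int8 branches and the A100-style peak numbers are assumptions, not taken from this diff.

# Hypothetical stand-in for the library's GPU config; the peak_fp16_TFLOPS
# field mirrors the attribute visible in the hunk, the rest is assumed.
from dataclasses import dataclass

@dataclass
class GPUConfig:
    peak_i4_TFLOPS: float = 1248.0   # assumed int4 peak (A100-like)
    peak_i8_TFLOPS: float = 624.0    # assumed int8 peak (A100-like)
    peak_fp16_TFLOPS: float = 312.0  # A100 dense fp16/bf16 peak

def get_TFLOPS_per_gpu(gpu: GPUConfig, wbits: int, abits: int,
                       flops_efficiency: float = 0.7) -> float:
    # GEMMs run at the wider of the weight and activation dtypes.
    higher_bits = max(wbits, abits)
    if higher_bits == 4:
        gemm_TFOPS = gpu.peak_i4_TFLOPS
    elif higher_bits == 8:
        gemm_TFOPS = gpu.peak_i8_TFLOPS
    else:
        assert higher_bits == 16, (
            "weight_bits and activation_bits must be 4, 8, or 16")
        gemm_TFOPS = gpu.peak_fp16_TFLOPS
    return gemm_TFOPS * flops_efficiency

print(get_TFLOPS_per_gpu(GPUConfig(), 16, 16))  # ~218.4 with the assumed numbers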
@@ -1198,6 +1199,8 @@ def get_latency_fwd_per_layer_attn(
         attention_projection_flops, attention_compute_flops = (
             self.get_num_flops_fwd_per_layer_attn(batch_size, seq_len))
+        assert tp_size != 0, "tp_size must be greater than 0"
+        print('XXXXX', self.dtype_config)
         compute_latency = (
             attention_projection_flops / tp_size / (self.get_TFLOPS_per_gpu(
                 self.dtype_config.linear_weight_bits,
@@ -2535,12 +2538,13 @@ def training(
         elif activation_recomputation == ActivationRecomputation.ATTN:
             latency_recompute = num_layers_per_gpu * latency_fwd_per_layer_attn_compute
         elif activation_recomputation == ActivationRecomputation.ATTN_COMPUTE:
-            latency_recompute = (num_layers_per_gpu *
-                                 self.get_num_flops_total_attn_compute(
-                                     batch_size_per_gpu, seq_len) /
-                                 ((self.parallelism_config.tp_size *
-                                   self.parallelism_config.pp_size) *
-                                  self.get_TFLOPS_per_gpu() * 1e12))
+            latency_recompute = (
+                num_layers_per_gpu * self.get_num_flops_total_attn_compute(
+                    batch_size_per_gpu, seq_len) /
+                ((self.parallelism_config.tp_size *
+                  self.parallelism_config.pp_size) * self.get_TFLOPS_per_gpu(
+                      self.dtype_config.weight_bits,
+                      self.dtype_config.activation_bits) * 1e12))
         elif activation_recomputation == ActivationRecomputation.NONE:
             latency_recompute = 0
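The rewritten ATTN_COMPUTE branch now feeds the configured weight and activation bits into get_TFLOPS_per_gpu instead of calling it with no arguments, which would raise a TypeError given the signature shown in the first hunk. As a rough sanity check of the expression itself, here is a back-of-the-envelope sketch with made-up numbers; the FLOP count, layer count, and parallelism sizes are assumptions, not values from the repository.

# Assumed inputs for illustration only.
num_layers_per_gpu = 24
attn_compute_flops = 5.5e12   # assumed get_num_flops_total_attn_compute(...) result
tp_size, pp_size = 4, 2
tflops_per_gpu = 218.4        # e.g. get_TFLOPS_per_gpu(16, 16) at 0.7 flops efficiency

# Same shape as the added lines: attention-recompute FLOPs spread over
# tp_size * pp_size GPUs, each delivering tflops_per_gpu * 1e12 FLOPS.
latency_recompute = (num_layers_per_gpu * attn_compute_flops /
                     ((tp_size * pp_size) * tflops_per_gpu * 1e12))
print(f"{latency_recompute:.3f} s")  # ~0.076 s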