@@ -475,9 +475,9 @@ def get_memory_optimizer_state_and_gradient_per_layer(

         memory_optimizer_state_others_per_layer = op_bytes_per_params * (
             (self.get_num_params_per_layer_attn() +
-             + self.get_num_params_per_layer_router() +
-             self.get_num_params_per_layer_layernorm())
-        ) / self.parallelism_config.tp_size / sharded_dp_size
+             + self.get_num_params_per_layer_router()) /
+            self.parallelism_config.tp_size +
+            self.get_num_params_per_layer_layernorm()) / sharded_dp_size

         memory_optimizer_state_per_layer = memory_optimizer_state_mlp_per_layer + memory_optimizer_state_others_per_layer
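The regrouping above changes which terms get divided by tp_size: attention and router parameters are sharded across tensor-parallel ranks, while layernorm parameters are replicated per TP rank and only reduced by the sharded data-parallel (ZeRO) factor. A minimal standalone sketch of the new grouping; the parameter counts and the value of op_bytes_per_params below are made-up assumptions, not values from the repo:

# sketch of the corrected formula with hypothetical numbers
op_bytes_per_params = 12            # assumption: e.g. fp32 Adam moments + master weights
num_params_attn = 4 * 4096 * 4096   # made-up attention parameter count
num_params_router = 0               # dense model: no router params
num_params_layernorm = 2 * 4096     # made-up layernorm parameter count
tp_size = 8
sharded_dp_size = 4                 # ZeRO-sharded data-parallel degree

# attn/router params are sharded across TP ranks; layernorm params are replicated
# per TP rank, so only the sharded-DP factor applies to them
memory_optimizer_state_others_per_layer = op_bytes_per_params * (
    (num_params_attn + num_params_router) / tp_size +
    num_params_layernorm) / sharded_dp_size
print(f"{memory_optimizer_state_others_per_layer / 2**20:.2f} MiB")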
@@ -1218,9 +1218,9 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
         elems_per_all_reduce = (2 * batch_size * seq_len *
                                 self.model_config.hidden_dim * (tp_size - 1) /
                                 tp_size)
-        latency_per_all_reduce = (
-            elems_per_all_reduce * dtype_bytes /
-            (self.gpu_config.intra_node_bandwidth_in_GB_per_sec * 10**9))
+        # assuming tp_size <= number of GPUs per node, thus using intra-node bandwidth
+        latency_per_all_reduce = (elems_per_all_reduce * dtype_bytes /
+                                  (self.get_intra_node_bandwidth() * 10**9))

         return max(
             latency_per_all_reduce,
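The 2 * (tp_size - 1) / tp_size factor is the per-rank traffic of a ring all-reduce (a reduce-scatter followed by an all-gather), so the estimate is simply the bytes each rank must move divided by the intra-node link bandwidth. A rough standalone sketch under those assumptions; the dtype size and the 300 GB/s bandwidth are placeholder values:

def allreduce_latency(batch_size, seq_len, hidden_dim, tp_size,
                      dtype_bytes=2, intra_node_bw_GBps=300):
    # each rank sends/receives 2 * (tp_size - 1) / tp_size of the message (ring algorithm)
    elems_per_all_reduce = (2 * batch_size * seq_len * hidden_dim *
                            (tp_size - 1) / tp_size)
    # bandwidth-bound estimate; assumes all tp_size GPUs share one node
    return elems_per_all_reduce * dtype_bytes / (intra_node_bw_GBps * 10**9)

print(allreduce_latency(batch_size=4, seq_len=2048, hidden_dim=4096, tp_size=8))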
@@ -1230,6 +1230,7 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
     def get_latency_fwd_per_layer_shared_dp_comm(self) -> float:
         dp_size = self.parallelism_config.dp_size
         ep_size = self.parallelism_config.ep_size
+        tp_size = self.parallelism_config.tp_size

         def time_allgather(S, n, B):
             # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allgather
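The NCCL performance note linked in that comment gives the all-gather relation: with message size S, n ranks, and per-link bandwidth B, each rank moves S * (n - 1) / n bytes, so a bandwidth-bound time estimate is S * (n - 1) / (n * B). The helper body is not shown in this hunk, so the following is a sketch of what time_allgather presumably computes based on that formula:

def time_allgather(S, n, B):
    # S: bytes gathered, n: number of ranks, B: bandwidth in bytes/s
    # bandwidth-bound estimate from the NCCL all-gather performance note
    return S * (n - 1) / (n * B)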
@@ -1243,15 +1244,17 @@ def time_allgather(S, n, B):
             self.get_num_params_per_layer_layernorm()
         ) * self.dtype_config.weight_bits / BITS_PER_BYTE

-        latency_allgather_params_mlp = time_allgather(
-            params_bytes_mlp, dp_size / ep_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+        # assuming tp and dp are preferred when sharding intra-node; pp is only applied across nodes
+        # when (dp_size * tp_size) <= 8, the data-parallel processes are within a node
+        bandwidth = self.get_intra_node_bandwidth() if (
+            dp_size * tp_size) <= 8 else self.get_inter_node_bandwidth()
+
+        latency_allgather_params_mlp = time_allgather(params_bytes_mlp,
+                                                      dp_size / ep_size,
+                                                      bandwidth * 10**9)

         latency_allgather_params_non_mlp = time_allgather(
-            params_bytes_non_mlp, dp_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+            params_bytes_non_mlp, dp_size, bandwidth * 10**9)

         latency_fwd_per_layer_shared_dp_comm = latency_allgather_params_mlp + latency_allgather_params_non_mlp
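The bandwidth choice follows from the assumed placement: TP and DP groups are packed within a node first and PP only spans nodes, so the data-parallel all-gathers stay on intra-node links as long as dp_size * tp_size still fits on one 8-GPU node. A hedged standalone sketch of that selection combined with the two all-gather terms; the parameter byte counts and the 300 / 25 GB/s bandwidths are placeholder assumptions:

GPUS_PER_NODE = 8
dp_size, ep_size, tp_size = 4, 1, 2
params_bytes_mlp, params_bytes_non_mlp = 2e9, 5e8    # made-up per-layer byte counts

# DP all-gathers stay on intra-node links only if the whole dp*tp group fits on one node
bandwidth_GBps = 300 if dp_size * tp_size <= GPUS_PER_NODE else 25
B = bandwidth_GBps * 10**9

# same bandwidth-bound all-gather estimate as sketched above: S * (n - 1) / (n * B);
# MoE (expert-parallel) MLP params are gathered over dp_size / ep_size ranks
latency_mlp = params_bytes_mlp * (dp_size / ep_size - 1) / ((dp_size / ep_size) * B)
latency_non_mlp = params_bytes_non_mlp * (dp_size - 1) / (dp_size * B)
print(latency_mlp + latency_non_mlp)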