
Commit d841e40

fix allreduce latency and memory usage calculation when using tp (#28)
* fix allreduce latency and mem usage when tp is in use
* update latency calculation in allgather
1 parent: dfd4da9 · commit: d841e40


llm_analysis/analysis.py

Lines changed: 16 additions & 13 deletions
@@ -475,9 +475,9 @@ def get_memory_optimizer_state_and_gradient_per_layer(
 
         memory_optimizer_state_others_per_layer = op_bytes_per_params * (
             (self.get_num_params_per_layer_attn() +
-             +self.get_num_params_per_layer_router() +
-             self.get_num_params_per_layer_layernorm())
-        ) / self.parallelism_config.tp_size / sharded_dp_size
+             +self.get_num_params_per_layer_router()) /
+            self.parallelism_config.tp_size +
+            self.get_num_params_per_layer_layernorm()) / sharded_dp_size
 
         memory_optimizer_state_per_layer = memory_optimizer_state_mlp_per_layer + memory_optimizer_state_others_per_layer
 
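For reference, the corrected non-MLP optimizer-state term can be written as a standalone formula: layernorm weights are replicated across tensor-parallel ranks, so only the attention and router parameters are divided by tp_size before the whole term is divided by the sharded data-parallel size. The sketch below is illustrative only; the function name, argument names, and example numbers are hypothetical and not part of llm_analysis.

# Minimal sketch of the corrected formula (hypothetical names and values, not the repo's API).
def optimizer_state_others_per_layer(op_bytes_per_param: float,
                                     n_params_attn: int,
                                     n_params_router: int,
                                     n_params_layernorm: int,
                                     tp_size: int,
                                     sharded_dp_size: int) -> float:
    # attention and router weights are sharded by tensor parallelism;
    # layernorm weights are replicated across tp ranks, so they are only
    # reduced by the (ZeRO-style) data-parallel sharding factor
    return op_bytes_per_param * (
        (n_params_attn + n_params_router) / tp_size +
        n_params_layernorm) / sharded_dp_size

# Hypothetical example: hidden_dim 4096, no MoE router, 4-way TP, optimizer
# states sharded over 8 data-parallel ranks, 8 bytes of optimizer state per param.
print(optimizer_state_others_per_layer(8, 4 * 4096**2, 0, 2 * 4096, 4, 8))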
@@ -1218,9 +1218,9 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
         elems_per_all_reduce = (2 * batch_size * seq_len *
                                 self.model_config.hidden_dim * (tp_size - 1) /
                                 tp_size)
-        latency_per_all_reduce = (
-            elems_per_all_reduce * dtype_bytes /
-            (self.gpu_config.intra_node_bandwidth_in_GB_per_sec * 10**9))
+        # assuming tp_size <= number of GPUs per node, thus using intra-node bandwidth
+        latency_per_all_reduce = (elems_per_all_reduce * dtype_bytes /
+                                  (self.get_intra_node_bandwidth() * 10**9))
 
         return max(
             latency_per_all_reduce,
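The tensor-parallel all-reduce latency above follows the usual ring all-reduce volume of 2 * (tp_size - 1) / tp_size times the activation size, divided by intra-node bandwidth (the added comment assumes the TP group fits inside a node). A small, self-contained sketch of that estimate, with illustrative names and hypothetical example values:

# Minimal sketch of the ring all-reduce latency estimate (illustrative names,
# hypothetical example values; not the repo's API).
def allreduce_latency(batch_size: int, seq_len: int, hidden_dim: int,
                      tp_size: int, dtype_bytes: int,
                      intra_node_bw_GB_per_sec: float) -> float:
    # a ring all-reduce moves 2 * (tp_size - 1) / tp_size of the tensor per rank
    # (reduce-scatter phase + all-gather phase)
    elems_per_all_reduce = (2 * batch_size * seq_len * hidden_dim *
                            (tp_size - 1) / tp_size)
    return elems_per_all_reduce * dtype_bytes / (intra_node_bw_GB_per_sec * 10**9)

# Hypothetical example: batch 1, seq 2048, hidden 4096, fp16 (2 bytes),
# TP=8 within a node with 300 GB/s intra-node bandwidth.
print(f"{allreduce_latency(1, 2048, 4096, 8, 2, 300) * 1e6:.1f} us")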
@@ -1230,6 +1230,7 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
     def get_latency_fwd_per_layer_shared_dp_comm(self) -> float:
         dp_size = self.parallelism_config.dp_size
         ep_size = self.parallelism_config.ep_size
+        tp_size = self.parallelism_config.tp_size
 
         def time_allgather(S, n, B):
             # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allgather
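The body of time_allgather is not part of this hunk; it only appears as context. Based on the NCCL performance note it links, an all-gather of S bytes across n ranks at bus bandwidth B (bytes/s) is commonly modeled as S * (n - 1) / (n * B). A standalone sketch under that assumption (not copied from the repo):

# Sketch of the all-gather time model from the linked NCCL performance notes
# (assumed to match the context function above; not copied from the repo).
def time_allgather(S: float, n: float, B: float) -> float:
    # each rank receives (n - 1) / n of the total S bytes at bus bandwidth B
    return S * (n - 1) / (n * B)

# Hypothetical example: 1 GB of sharded weights, 8 ranks, 300 GB/s intra-node.
print(f"{time_allgather(1e9, 8, 300e9) * 1e3:.2f} ms")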
@@ -1243,15 +1244,17 @@ def time_allgather(S, n, B):
             self.get_num_params_per_layer_layernorm()
         ) * self.dtype_config.weight_bits / BITS_PER_BYTE
 
-        latency_allgather_params_mlp = time_allgather(
-            params_bytes_mlp, dp_size / ep_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+        # assuming tp and dp are preferred when sharding intra node, pp is only applied across nodes
+        # when (dp_size * tp_size) <= 8, the data parallel processes are within a node
+        bandwidth = self.get_intra_node_bandwidth() if (
+            dp_size * tp_size) <= 8 else self.get_inter_node_bandwidth()
+
+        latency_allgather_params_mlp = time_allgather(params_bytes_mlp,
+                                                      dp_size / ep_size,
+                                                      bandwidth * 10**9)
 
         latency_allgather_params_non_mlp = time_allgather(
-            params_bytes_non_mlp, dp_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+            params_bytes_non_mlp, dp_size, bandwidth * 10**9)
 
         latency_fwd_per_layer_shared_dp_comm = latency_allgather_params_mlp + latency_allgather_params_non_mlp
 
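The bandwidth selection added in this hunk assumes tensor-parallel and data-parallel groups are packed within a node first, with pipeline parallelism only across nodes, so the data-parallel all-gather stays intra-node whenever dp_size * tp_size fits on one node. A standalone sketch of that choice; the function name, the gpus_per_node=8 default, and the example bandwidths are assumptions, not part of llm_analysis:

# Minimal sketch of the bandwidth choice for the shared data-parallel all-gather
# (hypothetical names; gpus_per_node=8 mirrors the hard-coded "<= 8" check above).
def shared_dp_allgather_bandwidth(dp_size: int, tp_size: int,
                                  intra_node_bw_GB_per_sec: float,
                                  inter_node_bw_GB_per_sec: float,
                                  gpus_per_node: int = 8) -> float:
    # tp and dp groups are assumed to be placed within a node first (pp only
    # across nodes), so the dp group is intra-node iff dp_size * tp_size
    # fits on a single node
    if dp_size * tp_size <= gpus_per_node:
        return intra_node_bw_GB_per_sec * 10**9
    return inter_node_bw_GB_per_sec * 10**9

# Hypothetical example: dp=2, tp=4 fits on one 8-GPU node -> intra-node (300 GB/s);
# dp=4, tp=4 spans nodes -> inter-node (25 GB/s).
print(shared_dp_allgather_bandwidth(2, 4, 300, 25))
print(shared_dp_allgather_bandwidth(4, 4, 300, 25))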