
Commit 6ab9656 ("wip")
1 parent 2efbc2b
3 files changed: +562 -331 lines


llm_analysis/analysis.py

Lines changed: 10 additions & 6 deletions
@@ -228,6 +228,7 @@ def get_TFLOPS_per_gpu(self, wbits, abits) -> float:
         assert (higher_bits == 16
                 ), "weight_bits and activation_bits must be 4, 8, or 16"
         gemm_TFOPS = self.gpu_config.peak_fp16_TFLOPS
+        print('XXXXX', self.gpu_config)
         return gemm_TFOPS * self.flops_efficiency

     def get_pivot(self, wbits, abits) -> float:
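
Context for this hunk: the unchanged return line computes effective throughput as peak_fp16_TFLOPS * flops_efficiency. A minimal arithmetic sketch of that product, using an illustrative flops_efficiency of 0.7 (not a value taken from this commit):

# effective TFLOPS = hardware peak scaled by an achievable-efficiency factor
peak_fp16_TFLOPS = 312       # A100 dense FP16 tensor-core peak
flops_efficiency = 0.7       # illustrative assumption, not from this commit
effective_TFLOPS = peak_fp16_TFLOPS * flops_efficiency  # 218.4 effective TFLOPS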
@@ -1198,6 +1199,8 @@ def get_latency_fwd_per_layer_attn(

         attention_projection_flops, attention_compute_flops = (
             self.get_num_flops_fwd_per_layer_attn(batch_size, seq_len))
+        assert tp_size != 0, "tp_size must be greater than 0"
+        print('XXXXX', self.dtype_config)
         compute_latency = (
             attention_projection_flops / tp_size / (self.get_TFLOPS_per_gpu(
                 self.dtype_config.linear_weight_bits,
@@ -2535,12 +2538,13 @@ def training(
         elif activation_recomputation == ActivationRecomputation.ATTN:
             latency_recompute = num_layers_per_gpu * latency_fwd_per_layer_attn_compute
         elif activation_recomputation == ActivationRecomputation.ATTN_COMPUTE:
-            latency_recompute = (num_layers_per_gpu *
-                                 self.get_num_flops_total_attn_compute(
-                                     batch_size_per_gpu, seq_len) /
-                                 ((self.parallelism_config.tp_size *
-                                   self.parallelism_config.pp_size) *
-                                  self.get_TFLOPS_per_gpu() * 1e12))
+            latency_recompute = (
+                num_layers_per_gpu * self.get_num_flops_total_attn_compute(
+                    batch_size_per_gpu, seq_len) /
+                ((self.parallelism_config.tp_size *
+                  self.parallelism_config.pp_size) * self.get_TFLOPS_per_gpu(
+                      self.dtype_config.weight_bits,
+                      self.dtype_config.activation_bits) * 1e12))
         elif activation_recomputation == ActivationRecomputation.NONE:
             latency_recompute = 0
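
The rewritten ATTN_COMPUTE branch estimates recompute latency as total attention FLOPs divided by the aggregate throughput of the tensor- and pipeline-parallel group. A standalone sketch of that arithmetic with made-up inputs (none of these numbers come from the commit):

num_layers_per_gpu = 4        # hypothetical layer count per pipeline stage
attn_compute_flops = 2.0e12   # hypothetical result of get_num_flops_total_attn_compute
tp_size, pp_size = 8, 2       # hypothetical parallelism degrees
tflops_per_gpu = 218.4        # hypothetical effective TFLOPS from get_TFLOPS_per_gpu
latency_recompute = (num_layers_per_gpu * attn_compute_flops /
                     ((tp_size * pp_size) * tflops_per_gpu * 1e12))  # ~2.3e-3 s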

llm_analysis/config.py

Lines changed: 13 additions & 3 deletions
@@ -193,8 +193,14 @@ class DtypeConfig:
     weight_bits: int = 16  # number of bits for weight
     activation_bits: int = 16  # number of bits for activation
     embedding_bits: int = 16  # number of bits for the embedding
-    linear_weight_bits: int = 16  # number of bits for weight in linear layer
-    linear_activation_bits: int = 16  # number of bits for activation in linear layer
+    linear_weight_bits: int | None = None  # number of bits for weight in linear layer
+    linear_activation_bits: int | None = None  # number of bits for activation in linear layer
+
+    def __post_init__(self):
+        if self.linear_weight_bits is None:
+            self.linear_weight_bits = self.weight_bits
+        if self.linear_activation_bits is None:
+            self.linear_activation_bits = self.activation_bits


 @dataclass
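
With this change, callers can omit the linear-layer bit widths and let __post_init__ fall back to the global weight/activation settings. A minimal usage sketch (assuming the activation branch assigns linear_activation_bits, mirroring the weight branch above):

from llm_analysis.config import DtypeConfig

cfg = DtypeConfig(weight_bits=8, activation_bits=8)
# linear_* fields were left as None, so __post_init__ copies the global values
assert cfg.linear_weight_bits == 8
assert cfg.linear_activation_bits == 8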
@@ -540,7 +546,11 @@ def list_gpu_configs() -> list:

 def list_dtype_configs() -> None:
     """List all predefined data type configs."""
-    logger.info(dtype_configs.keys())
+    if not dtype_configs:
+        logger.warning("No dtype configs loaded")
+        return []
+    logger.info(f"Available dtype configs: {list(dtype_configs.keys())}")
+    return list(dtype_configs.keys())


 def get_model_config_by_name(name_or_path: str) -> ModelConfig:
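
list_dtype_configs now returns the config names in addition to logging them, so callers can iterate over the result. A small usage sketch (which names are available depends on the configs bundled with the package; none are assumed here):

from llm_analysis.config import list_dtype_configs

names = list_dtype_configs()  # logs the available names and returns them as a list
for name in names:
    print(name)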
