Skip to content

Commit 2efbc2b

Browse files
committed
wip
1 parent 3a2b9c0 commit 2efbc2b

File tree

2 files changed

+292
-253
lines changed

2 files changed

+292
-253
lines changed

llm_analysis/config.py

Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -410,21 +410,37 @@ def get_model_config_from_hf(name: str, ) -> ModelConfig:
410410
def read_configs(config_dir_name: str, type="model") -> dict:
411411
"""Read configs from a directory."""
412412
configs = {}
413-
for filename in os.listdir(config_dir_name):
413+
logger.info(f"Reading {type} configs from directory: {config_dir_name}")
414+
415+
if not os.path.exists(config_dir_name):
416+
logger.error(f"Config directory does not exist: {config_dir_name}")
417+
return configs
418+
419+
config_files = os.listdir(config_dir_name)
420+
logger.info(f"Found {len(config_files)} files in {config_dir_name}")
421+
422+
for filename in config_files:
414423
filepath = os.path.join(config_dir_name, filename)
415-
with open(filepath, "r") as f:
416-
config_json = json.load(f)
417-
if type == "model":
418-
config = ModelConfig(**config_json)
419-
elif type == "gpu":
420-
config = GPUConfig(**config_json)
421-
elif type == "dtype":
422-
config = DtypeConfig(**config_json)
423-
else:
424-
assert False, f"unknown config type when reading: {type}"
425-
if config.name not in configs:
426-
configs[config.name] = config
427-
logger.info(f"Loaded {len(configs)} configs from {config_dir_name}")
424+
logger.info(f"Reading config file: {filepath}")
425+
try:
426+
with open(filepath, "r") as f:
427+
config_json = json.load(f)
428+
if type == "model":
429+
config = ModelConfig(**config_json)
430+
elif type == "gpu":
431+
config = GPUConfig(**config_json)
432+
elif type == "dtype":
433+
config = DtypeConfig(**config_json)
434+
else:
435+
assert False, f"unknown config type when reading: {type}"
436+
if config.name not in configs:
437+
configs[config.name] = config
438+
except Exception as e:
439+
logger.error(f"Error reading config file {filepath}: {str(e)}")
440+
441+
logger.info(
442+
f"Successfully loaded {len(configs)} {type} configs from {config_dir_name}"
443+
)
428444
return configs
429445

430446

@@ -481,29 +497,45 @@ def get_hf_models_by_type_and_task(
481497
def populate_model_and_gpu_configs() -> None:
482498
"""Populate model, gpu, and data type configs from the pre-defined json files."""
483499
global model_configs, gpu_configs, dtype_configs
484-
model_configs = read_configs(Path(__file__).parent /
485-
Path(MODEL_CONFIG_DIR_NAME),
486-
type="model")
487-
gpu_configs = read_configs(Path(__file__).parent /
488-
Path(GPU_CONFIG_DIR_NAME),
489-
type="gpu")
490-
491-
dtype_configs = read_configs(Path(__file__).parent /
492-
Path(DTYPE_CONFIG_DIR_NAME),
493-
type="dtype")
500+
501+
logger.info("Starting to populate configs...")
502+
503+
# Get the absolute paths
504+
base_path = Path(__file__).parent
505+
model_path = base_path / MODEL_CONFIG_DIR_NAME
506+
gpu_path = base_path / GPU_CONFIG_DIR_NAME
507+
dtype_path = base_path / DTYPE_CONFIG_DIR_NAME
508+
494509
logger.info(
495-
f"Populated {len(model_configs)} model configs, {len(gpu_configs)} gpu configs, {len(dtype_configs)} dtype configs"
510+
f"Using paths:\n Models: {model_path}\n GPUs: {gpu_path}\n Dtypes: {dtype_path}"
496511
)
497512

513+
model_configs = read_configs(model_path, type="model")
514+
gpu_configs = read_configs(gpu_path, type="gpu")
515+
dtype_configs = read_configs(dtype_path, type="dtype")
516+
517+
logger.info(f"Config population complete:\n"
518+
f" - {len(model_configs)} model configs\n"
519+
f" - {len(gpu_configs)} GPU configs\n"
520+
f" - {len(dtype_configs)} dtype configs")
521+
498522

499523
def list_model_configs() -> None:
500524
"""List all predefined model configs."""
501525
logger.info(model_configs.keys())
502526

503527

504-
def list_gpu_configs() -> None:
505-
"""List all predefined gpu configs."""
506-
logger.info(gpu_configs.keys())
528+
def list_gpu_configs() -> list:
529+
"""List all predefined gpu configs.
530+
531+
Returns:
532+
list: List of available GPU config names
533+
"""
534+
if not gpu_configs:
535+
logger.warning("No GPU configs loaded")
536+
return []
537+
logger.info(f"Available GPU configs: {list(gpu_configs.keys())}")
538+
return list(gpu_configs.keys())
507539

508540

509541
def list_dtype_configs() -> None:

0 commit comments

Comments
 (0)