From c132135aaca8a00be5981eb4b4a5bff14d962648 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 27 Jun 2024 15:36:44 -0700 Subject: [PATCH 1/4] error --- scripts/train/train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scripts/train/train.py b/scripts/train/train.py index 655b5de938..200b2e201e 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -322,6 +322,12 @@ def main(cfg: DictConfig) -> Trainer: if train_cfg.metadata is not None: # Flatten the metadata for logging logged_cfg.pop('metadata', None) + common_keys = set(logged_cfg.keys()) & set(train_cfg.metadata.keys()) + if common_keys: + raise ValueError( + f'Keys {common_keys} are already present in the config. Please rename them in metadata.', + ) + logged_cfg.update(train_cfg.metadata, merge=True) if mosaicml_logger is not None: mosaicml_logger.log_metrics(train_cfg.metadata) From 86ac3bbf5d47943d9ccb14235933d3ee8881f83d Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 28 Jun 2024 16:40:51 -0700 Subject: [PATCH 2/4] allow turning off metadata flattening --- llmfoundry/utils/config_utils.py | 1 + scripts/train/train.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 79cdc225b2..a06a6f38c9 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -171,6 +171,7 @@ class TrainConfig: # Metadata metadata: Optional[Dict[str, Any]] = None + flatten_metadata: bool = True run_name: Optional[str] = None # Resumption diff --git a/scripts/train/train.py b/scripts/train/train.py index 200b2e201e..4f7c93af88 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -320,15 +320,18 @@ def main(cfg: DictConfig) -> Trainer: loggers.append(mosaicml_logger) if train_cfg.metadata is not None: - # Flatten the metadata for logging - logged_cfg.pop('metadata', None) - common_keys = set(logged_cfg.keys()) & set(train_cfg.metadata.keys()) - if common_keys: - raise ValueError( - f'Keys {common_keys} are already present in the config. Please rename them in metadata.', - ) + # Optionally flatten the metadata for logging + if train_cfg.flatten_metadata: + logged_cfg.pop('metadata', None) + common_keys = set(logged_cfg.keys() + ) & set(train_cfg.metadata.keys()) + if common_keys: + raise ValueError( + f'Keys {common_keys} are already present in the config. Please rename them in metadata.', + ) + + logged_cfg.update(train_cfg.metadata, merge=True) - logged_cfg.update(train_cfg.metadata, merge=True) if mosaicml_logger is not None: mosaicml_logger.log_metrics(train_cfg.metadata) mosaicml_logger._flush_metadata(force_flush=True) From db54cab69d7e78daa2864db225f3af5018e09b4d Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 28 Jun 2024 16:49:00 -0700 Subject: [PATCH 3/4] adjust error --- scripts/train/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index 4f7c93af88..8e8ff70a19 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -323,11 +323,14 @@ def main(cfg: DictConfig) -> Trainer: # Optionally flatten the metadata for logging if train_cfg.flatten_metadata: logged_cfg.pop('metadata', None) - common_keys = set(logged_cfg.keys() - ) & set(train_cfg.metadata.keys()) + common_keys = set( + logged_cfg.keys(), + ) & set(train_cfg.metadata.keys()) if common_keys: raise ValueError( - f'Keys {common_keys} are already present in the config. Please rename them in metadata.', + f'Keys {common_keys} are already present in the config. Please rename them in metadata ' + + + 'or set flatten_metadata=False to avoid flattening the metadata in the logged config.', ) logged_cfg.update(train_cfg.metadata, merge=True) From 79fda344d248cd5badb6f06d2d5f24fb25f91078 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 1 Jul 2024 17:23:25 -0400 Subject: [PATCH 4/4] Update scripts/train/train.py Co-authored-by: Irene Dea --- scripts/train/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index 8e8ff70a19..134058a595 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -326,7 +326,7 @@ def main(cfg: DictConfig) -> Trainer: common_keys = set( logged_cfg.keys(), ) & set(train_cfg.metadata.keys()) - if common_keys: + if len(common_keys) > 0: raise ValueError( f'Keys {common_keys} are already present in the config. Please rename them in metadata ' +