diff --git a/torchtune/_cli/run.py b/torchtune/_cli/run.py index 1aaee2f6d3..208a0e8225 100644 --- a/torchtune/_cli/run.py +++ b/torchtune/_cli/run.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import os import runpy import sys import textwrap @@ -167,6 +168,11 @@ def _run_cmd(self, args: argparse.Namespace): args.recipe = recipe_path args.recipe_args[config_idx] = config_path + # Make sure user code in current directory is importable + # TODO: This is a temporary fix, figure out how to make runpy and torchrun + # run from this directory + sys.path.append(os.getcwd()) + # Execute recipe if self._is_distributed_args(args): if not supports_distributed: