Skip to content

Commit 9f285fe

Browse files
authored
load hparam from yaml file
1 parent 0571215 commit 9f285fe

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from data_utils import TextMelLoader, TextMelCollate
1515
from loss_function import Tacotron2Loss
1616
from logger import Tacotron2Logger
17-
from hparams import create_hparams
17+
from config.hparams import HParam
1818

1919

2020
def reduce_tensor(tensor, n_gpus):
@@ -272,10 +272,12 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
272272
parser.add_argument('--group_name', type=str, default='group_name',
273273
required=False, help='Distributed group name')
274274
parser.add_argument('--hparams', type=str,
275-
required=False, help='comma separated name=value pairs')
275+
required=True, help='path to the yaml config file ')
276276

277277
args = parser.parse_args()
278-
hparams = create_hparams(args.hparams)
278+
hparams = HParam(args.hparams)
279+
with open(args.hparams, "r") as f:
280+
hp_str = "".join(f.readlines())
279281

280282
torch.backends.cudnn.enabled = hparams.cudnn_enabled
281283
torch.backends.cudnn.benchmark = hparams.cudnn_benchmark

0 commit comments

Comments
 (0)