-
Notifications
You must be signed in to change notification settings - Fork 41
Expand file tree
/
Copy pathtrain.py
More file actions
111 lines (93 loc) · 4.71 KB
/
train.py
File metadata and controls
111 lines (93 loc) · 4.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import subprocess
import os
from easydict import EasyDict
import GPUtil
import sys, time
# Pool of GPU ids that jobs may be assigned to.  GPUtil is asked for at most
# four devices it currently considers available, never offering ids 6 and 7.
gpu = GPUtil.getAvailable(
    excludeID=[6, 7],
    limit=4,
)
print("gpu available : ", gpu)
# Default command-line arguments for train_dcunet.py.  Several entries are
# overwritten for each run by the sweep loop further down, so the values here
# only act as defaults.  Non-boolean entries are forwarded to the trainer as
# "--key value" pairs; boolean entries are not forwarded that way.
args = EasyDict({
    "gpu": "2",
    "batch_size": 12,
    "train_signal": "/data/jongho/data/2019challenge/dcunet/clean_speech/train/",
    "train_noise": "/data/jongho/data/2019challenge/dcunet/noisy_noise/train/",
    "test_signal": "/data/jongho/data/2019challenge/ss/reference/16k/clean_testset_wav/",
    "test_noise": "/data/jongho/data/2019challenge/ss/reference/16k/noisy_testset_wav/",
    "sequence_length": 16384,
    "num_step": 40000,
    "validation_interval": 500,
    "num_workers": 12,
    "ckpt": "unet/ckpt.pth",
    "model_complexity": 45,
    "lr": 0.01,
    "num_signal": 0,
    "num_noise": 0,
    "optimizer": "adam",
    "lr_decay": 0.5,
    "momentum": 0,
    "multi_gpu": False,
    "complex": True,
    "model_depth": 10,
    "swa": False,
    "loss": "wsdr",
    "log_amp": False,
    "metric": "pesq",
    "train_dataset": "se",
    "valid_dataset": "se",
    "preload": False,
    "padding_mode": "reflect",
})
# Common prefix of every training-set directory listed below.
_challenge_root = "/data/jongho/data/2019challenge"

# Pair selected when train_dataset == "se": the "y" directory is passed to the
# trainer as the training signal and the "x" directory as the training noise.
se_y_train = f"{_challenge_root}/ss/reference/16k/clean_trainset_28spk_wav/"
se_x_train = f"{_challenge_root}/ss/reference/16k/noisy_trainset_28spk_wav/"

# Pair selected for any other train_dataset value.
mix_y_train = f"{_challenge_root}/dcunet/clean_speech/train/"
mix_x_train = f"{_challenge_root}/dcunet/demand/train/"
# Alternative mix pair locations (currently disabled), relative to the root:
#   ss/clean_speech/train/    with  ss/noisy_noise/train/
#   ss/dataset/train/speech/  with  ss/dataset/train/noise/
# Sweep launcher: starts one train_dcunet.py process per hyper-parameter
# combination, each assigned its own GPU id, then waits for all of them and
# reports any that exited with a non-zero status.

# (model_complexity, model_depth) pairs that are deliberately not trained.
_SKIPPED_CONFIGS = {(90, 10), (45, 20)}

# Boolean args that are passed to the trainer as bare "--flag" switches,
# in the order they are appended to the command line.
_SWITCH_ARGS = ("log_amp", "preload", "complex")


def _acquire_gpu(poll_seconds=600):
    """Return the id (as a string) of a free GPU, blocking until one exists.

    Ids are popped from the module-level ``gpu`` list.  When that list is
    empty, GPUtil is polled again every ``poll_seconds`` seconds until it
    reports at least one available device.
    """
    global gpu
    while not gpu:
        print(f"no gpu available, sleep {poll_seconds}s...")
        time.sleep(poll_seconds)
        gpu = GPUtil.getAvailable(limit=4, excludeID=[6, 7])
    return str(gpu.pop())


def _build_command(run_args):
    """Translate ``run_args`` into the argv list for train_dcunet.py."""
    command = ["/miniconda/bin/python", f"{os.getcwd()}/train_dcunet.py"]
    for key, value in run_args.items():
        # Booleans are never sent as "--key value"; the supported ones are
        # appended as bare switches after the loop.
        if not isinstance(value, bool):
            command.extend([f"--{key}", f"{value}"])
    command.extend(f"--{name}" for name in _SWITCH_ARGS if run_args[name])
    return command


launched = []
for model_complexity in [45, 90]:
    for model_depth in [10, 20]:
        # skip config
        if (model_complexity, model_depth) in _SKIPPED_CONFIGS:
            continue
        for use_complex in [False, True]:
            for log_amp in [False]:
                for train_dataset in ['se']:
                    for optimizer, lr in [('adam', 0.01)]:
                        for padding_mode in ['zeros']:
                            args.gpu = _acquire_gpu()
                            args.train_dataset = train_dataset
                            if train_dataset == "se":
                                args.train_signal = se_y_train
                                args.train_noise = se_x_train
                            else:
                                args.train_signal = mix_y_train
                                args.train_noise = mix_x_train
                            args.padding_mode = padding_mode
                            args.model_depth = model_depth
                            args.model_complexity = model_complexity
                            args.optimizer = optimizer
                            args.lr = lr
                            # Keep the switches in args in step with the
                            # sweep so the command is built from args alone.
                            args.log_amp = log_amp
                            args.complex = use_complex
                            args.ckpt = (
                                f"demand_experiment_report/190717_{log_amp}"
                                f"_dp{model_depth}_{train_dataset}"
                                f"_sz{model_complexity}_{padding_mode}"
                                f"_comp_{use_complex}.pth"
                            )

                            command = _build_command(args)
                            print("command : {", command, "}")
                            # Child output is discarded, so the exit code
                            # checked after the sweep is the only failure
                            # signal this script gets.
                            launched.append(
                                subprocess.Popen(
                                    command,
                                    shell=False,
                                    stdout=subprocess.DEVNULL,
                                    stderr=subprocess.DEVNULL,
                                )
                            )
                            # Brief pause between consecutive launches.
                            time.sleep(1)

# Wait for the children rather than sleeping for a fixed period: this keeps
# the launcher alive exactly as long as training runs, reaps every child, and
# surfaces failures that would otherwise be silent.
for process in launched:
    exit_code = process.wait()
    if exit_code != 0:
        print(f"command failed with exit code {exit_code} : {process.args}")