Skip to content

Commit 5fbe41e

Browse files
committed
Added --data-path to densenet and alexnet
1 parent 7392542 commit 5fbe41e

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

applications/vision/alexnet.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
parser.add_argument(
2323
'--num-classes', action='store', default=1000, type=int,
2424
help='number of ImageNet classes (default: 1000)', metavar='NUM')
25+
parser.add_argument(
26+
'--data-path', action='store', default=None, type=str,
27+
help='Path to top-level imagenet directory. default: None')
2528
lbann.contrib.args.add_optimizer_arguments(parser)
2629
args = parser.parse_args()
2730

@@ -64,7 +67,8 @@
6467
opt = lbann.contrib.args.create_optimizer(args)
6568

6669
# Setup data reader
67-
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes)
70+
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes,
71+
data_path=args.data_path)
6872

6973
# Setup trainer
7074
trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size)

applications/vision/densenet.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ def get_args():
428428
parser.add_argument("--print-matrix-summary", dest="print_matrix_summary",
429429
action="store_const",
430430
const=True, default=False)
431+
parser.add_argument('--data-path', action='store', default=None, type=str,
432+
help='Path to top-level imagenet directory. default: None')
431433
args = parser.parse_args()
432434
return args
433435

@@ -438,7 +440,7 @@ def set_up_experiment(args,
438440
labels):
439441
algo = lbann.BatchedIterativeOptimizer("sgd", epoch_count=args.num_epochs)
440442

441-
443+
442444
# Set up objective function
443445
cross_entropy = lbann.CrossEntropy([probs, labels])
444446
layers = list(lbann.traverse_layer_graph(input_))
@@ -472,7 +474,9 @@ def set_up_experiment(args,
472474
callbacks=callbacks)
473475

474476
# Set up data reader
475-
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes, small_testing=True)
477+
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes,
478+
small_testing=True,
479+
data_path=args.data_path)
476480

477481
percentage = 0.001 * 2 * (args.mini_batch_size / 16) * 2
478482

0 commit comments

Comments
 (0)