diff --git a/train-on-cifar/train-on-cifar.lua b/train-on-cifar/train-on-cifar.lua index 69d25c0..87db1e6 100644 --- a/train-on-cifar/train-on-cifar.lua +++ b/train-on-cifar/train-on-cifar.lua @@ -333,26 +333,26 @@ function train(dataset) -- optimize on current mini-batch if opt.optimization == 'CG' then - config = config or {maxIter = opt.maxIter} - optim.cg(feval, parameters, config) + configParams = configParams or {maxIter = opt.maxIter} + optim.cg(feval, parameters, configParams) elseif opt.optimization == 'LBFGS' then - config = config or {learningRate = opt.learningRate, + configParams = configParams or {learningRate = opt.learningRate, maxIter = opt.maxIter, nCorrection = 10} - optim.lbfgs(feval, parameters, config) + optim.lbfgs(feval, parameters, configParams) elseif opt.optimization == 'SGD' then - config = config or {learningRate = opt.learningRate, + configParams = configParams or {learningRate = opt.learningRate, weightDecay = opt.weightDecay, momentum = opt.momentum, learningRateDecay = 5e-7} - optim.sgd(feval, parameters, config) + optim.sgd(feval, parameters, configParams) elseif opt.optimization == 'ASGD' then - config = config or {eta0 = opt.learningRate, + configParams = configParams or {eta0 = opt.learningRate, t0 = nbTrainingPatches * opt.t0} - _,_,average = optim.asgd(feval, parameters, config) + _,_,average = optim.asgd(feval, parameters, configParams) else error('unknown optimization method')