
Commit 37755b4

Add automatic upgrade for solver type and update examples and doc
1 parent e1416b1 commit 37755b4

18 files changed, +231 -34 lines changed

docs/tutorial/solver.md

Lines changed: 14 additions & 14 deletions
@@ -8,12 +8,12 @@ The responsibilities of learning are divided between the Solver for overseeing t
 
 The Caffe solvers are:
 
-- Stochastic Gradient Descent (`SGD`),
-- AdaDelta (`ADADELTA`),
-- Adaptive Gradient (`ADAGRAD`),
-- Adam (`ADAM`),
-- Nesterov's Accelerated Gradient (`NESTEROV`) and
-- RMSprop (`RMSPROP`)
+- Stochastic Gradient Descent (`type: "SGD"`),
+- AdaDelta (`type: "AdaDelta"`),
+- Adaptive Gradient (`type: "AdaGrad"`),
+- Adam (`type: "Adam"`),
+- Nesterov's Accelerated Gradient (`type: "Nesterov"`) and
+- RMSprop (`type: "RMSProp"`)
 
 The solver
 
@@ -51,7 +51,7 @@ The parameter update $$\Delta W$$ is formed by the solver from the error gradien
 
 ### SGD
 
-**Stochastic gradient descent** (`solver_type: SGD`) updates the weights $$ W $$ by a linear combination of the negative gradient $$ \nabla L(W) $$ and the previous weight update $$ V_t $$.
+**Stochastic gradient descent** (`type: "SGD"`) updates the weights $$ W $$ by a linear combination of the negative gradient $$ \nabla L(W) $$ and the previous weight update $$ V_t $$.
 The **learning rate** $$ \alpha $$ is the weight of the negative gradient.
 The **momentum** $$ \mu $$ is the weight of the previous update.
 
@@ -113,7 +113,7 @@ If learning diverges (e.g., you start to see very large or `NaN` or `inf` loss v
 
 ### AdaDelta
 
-The **AdaDelta** (`solver_type: ADADELTA`) method (M. Zeiler [1]) is a "robust learning rate method". It is a gradient-based optimization method (like SGD). The update formulas are
+The **AdaDelta** (`type: "AdaDelta"`) method (M. Zeiler [1]) is a "robust learning rate method". It is a gradient-based optimization method (like SGD). The update formulas are
 
 $$
 \begin{align}
@@ -125,7 +125,7 @@ E[g^2]_t &= \delta{E[g^2]_{t-1} } + (1-\delta)g_{t}^2
 \end{align}
 $$
 
-and 
+and
 
 $$
 (W_{t+1})_i =
@@ -139,7 +139,7 @@ $$
 
 ### AdaGrad
 
-The **adaptive gradient** (`solver_type: ADAGRAD`) method (Duchi et al. [1]) is a gradient-based optimization method (like SGD) that attempts to "find needles in haystacks in the form of very predictive but rarely seen features," in Duchi et al.'s words.
+The **adaptive gradient** (`type: "AdaGrad"`) method (Duchi et al. [1]) is a gradient-based optimization method (like SGD) that attempts to "find needles in haystacks in the form of very predictive but rarely seen features," in Duchi et al.'s words.
 Given the update information from all previous iterations $$ \left( \nabla L(W) \right)_{t'} $$ for $$ t' \in \{1, 2, ..., t\} $$,
 the update formulas proposed by [1] are as follows, specified for each component $$i$$ of the weights $$W$$:
 
@@ -159,7 +159,7 @@ Note that in practice, for weights $$ W \in \mathcal{R}^d $$, AdaGrad implementa
 
 ### Adam
 
-The **Adam** (`solver_type: ADAM`), proposed in Kingma et al. [1], is a gradient-based optimization method (like SGD). This includes an "adaptive moment estimation" ($$m_t, v_t$$) and can be regarded as a generalization of AdaGrad. The update formulas are
+The **Adam** (`type: "Adam"`), proposed in Kingma et al. [1], is a gradient-based optimization method (like SGD). This includes an "adaptive moment estimation" ($$m_t, v_t$$) and can be regarded as a generalization of AdaGrad. The update formulas are
 
 $$
 (m_t)_i = \beta_1 (m_{t-1})_i + (1-\beta_1)(\nabla L(W_t))_i,\\
@@ -181,7 +181,7 @@ Kingma et al. [1] proposed to use $$\beta_1 = 0.9, \beta_2 = 0.999, \varepsilon 
 
 ### NAG
 
-**Nesterov's accelerated gradient** (`solver_type: NESTEROV`) was proposed by Nesterov [1] as an "optimal" method of convex optimization, achieving a convergence rate of $$ \mathcal{O}(1/t^2) $$ rather than the $$ \mathcal{O}(1/t) $$.
+**Nesterov's accelerated gradient** (`type: "Nesterov"`) was proposed by Nesterov [1] as an "optimal" method of convex optimization, achieving a convergence rate of $$ \mathcal{O}(1/t^2) $$ rather than the $$ \mathcal{O}(1/t) $$.
 Though the required assumptions to achieve the $$ \mathcal{O}(1/t^2) $$ convergence typically will not hold for deep networks trained with Caffe (e.g., due to non-smoothness and non-convexity), in practice NAG can be a very effective method for optimizing certain types of deep learning architectures, as demonstrated for deep MNIST autoencoders by Sutskever et al. [2].
 
 The weight update formulas look very similar to the SGD updates given above:
 
@@ -206,10 +206,10 @@ What distinguishes the method from SGD is the weight setting $$ W $$ on which we
 
 ### RMSprop
 
-The **RMSprop** (`solver_type: RMSPROP`), suggested by Tieleman in a Coursera course lecture, is a gradient-based optimization method (like SGD). The update formulas are
+The **RMSprop** (`type: "RMSProp"`), suggested by Tieleman in a Coursera course lecture, is a gradient-based optimization method (like SGD). The update formulas are
 
 $$
-(v_t)_i = 
+(v_t)_i =
 \begin{cases}
 (v_{t-1})_i + \delta, &(\nabla L(W_t))_i(\nabla L(W_{t-1}))_i > 0\\
 (v_{t-1})_i \cdot (1-\delta), & \text{else}
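In every case the migration is a one-line change to the solver prototxt: the deprecated enum value gives way to a case-sensitive string, while per-solver hyperparameters (delta, rms_decay, momentum, ...) are unchanged. A minimal before/after sketch, mirroring the example files updated below:

# Before: deprecated enum field
solver_type: ADADELTA
delta: 1e-6

# After: string-valued field (case-sensitive)
type: "AdaDelta"
delta: 1e-6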

examples/mnist/lenet_adadelta_solver.prototxt

Lines changed: 1 addition & 1 deletion
@@ -20,5 +20,5 @@ snapshot: 5000
 snapshot_prefix: "examples/mnist/lenet_adadelta"
 # solver mode: CPU or GPU
 solver_mode: GPU
-solver_type: ADADELTA
+type: "AdaDelta"
 delta: 1e-6

examples/mnist/lenet_solver_adam.prototxt

Lines changed: 1 addition & 1 deletion
@@ -22,5 +22,5 @@ max_iter: 10000
 snapshot: 5000
 snapshot_prefix: "examples/mnist/lenet"
 # solver mode: CPU or GPU
-solver_type: ADAM
+type: "Adam"
 solver_mode: GPU

examples/mnist/lenet_solver_rmsprop.prototxt

Lines changed: 1 addition & 1 deletion
@@ -23,5 +23,5 @@ snapshot: 5000
 snapshot_prefix: "examples/mnist/lenet_rmsprop"
 # solver mode: CPU or GPU
 solver_mode: GPU
-solver_type: RMSPROP
+type: "RMSProp"
 rms_decay: 0.98

examples/mnist/mnist_autoencoder_solver_adadelta.prototxt

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@ snapshot: 10000
 snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train"
 # solver mode: CPU or GPU
 solver_mode: GPU
-solver_type: ADADELTA
+type: "AdaDelta"

examples/mnist/mnist_autoencoder_solver_adagrad.prototxt

Lines changed: 1 addition & 1 deletion
@@ -14,4 +14,4 @@ snapshot: 10000
 snapshot_prefix: "examples/mnist/mnist_autoencoder_adagrad_train"
 # solver mode: CPU or GPU
 solver_mode: GPU
-solver_type: ADAGRAD
+type: "AdaGrad"

examples/mnist/mnist_autoencoder_solver_nesterov.prototxt

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@ snapshot_prefix: "examples/mnist/mnist_autoencoder_nesterov_train"
 momentum: 0.95
 # solver mode: CPU or GPU
 solver_mode: GPU
-solver_type: NESTEROV
+type: "Nesterov"

include/caffe/caffe.hpp

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 #include "caffe/solver_factory.hpp"
 #include "caffe/util/benchmark.hpp"
 #include "caffe/util/io.hpp"
+#include "caffe/util/upgrade_proto.hpp"
 #include "caffe/vision_layers.hpp"
 
 #endif  // CAFFE_CAFFE_HPP_

include/caffe/solver.hpp

Lines changed: 0 additions & 4 deletions
@@ -153,10 +153,6 @@ class WorkerSolver : public Solver<Dtype> {
   }
 };
 
-// The solver factory function
-template <typename Dtype>
-Solver<Dtype>* GetSolver(const SolverParameter& param);
-
 }  // namespace caffe
 
 #endif  // CAFFE_SOLVER_HPP_
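The removed declaration has a direct replacement: solvers are now created through the registry in solver_factory.hpp, which dispatches on the string-valued type field. A minimal migration sketch (MakeSolver is an illustrative wrapper, not part of this commit):

#include "caffe/solver_factory.hpp"

// Replaces the old caffe::GetSolver<float>(param) call. CreateSolver
// dispatches on param.type() and returns a raw, caller-owned pointer.
caffe::Solver<float>* MakeSolver(const caffe::SolverParameter& param) {
  return caffe::SolverRegistry<float>::CreateSolver(param);
}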

include/caffe/solver_factory.hpp

Lines changed: 5 additions & 5 deletions
@@ -53,7 +53,7 @@ class Solver;
 template <typename Dtype>
 class SolverRegistry {
  public:
-  typedef shared_ptr<Solver<Dtype> > (*Creator)(const SolverParameter&);
+  typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
   typedef std::map<string, Creator> CreatorRegistry;
 
   static CreatorRegistry& Registry() {
@@ -70,7 +70,7 @@ class SolverRegistry {
   }
 
   // Get a solver using a SolverParameter.
-  static shared_ptr<Solver<Dtype> > CreateSolver(const SolverParameter& param) {
+  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
     const string& type = param.type();
     CreatorRegistry& registry = Registry();
     CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
@@ -112,7 +112,7 @@ template <typename Dtype>
 class SolverRegisterer {
  public:
   SolverRegisterer(const string& type,
-      shared_ptr<Solver<Dtype> > (*creator)(const SolverParameter&)) {
+      Solver<Dtype>* (*creator)(const SolverParameter&)) {
     // LOG(INFO) << "Registering solver type: " << type;
     SolverRegistry<Dtype>::AddCreator(type, creator);
   }
@@ -125,10 +125,10 @@ class SolverRegisterer {
 
 #define REGISTER_SOLVER_CLASS(type) \
   template <typename Dtype> \
-  shared_ptr<Solver<Dtype> > Creator_##type##Solver( \
+  Solver<Dtype>* Creator_##type##Solver( \
       const SolverParameter& param) \
   { \
-    return shared_ptr<Solver<Dtype> >(new type##Solver<Dtype>(param)); \
+    return new type##Solver<Dtype>(param); \
   } \
   REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
 
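With the Creator signature switched to a raw Solver<Dtype>*, registering a custom solver remains a two-step pattern: define the class, then invoke the macro. A sketch under this commit's API; MySolver is hypothetical and exists only to show the expansion:

#include "caffe/solver.hpp"          // declares SGDSolver at this revision
#include "caffe/solver_factory.hpp"

namespace caffe {

// Hypothetical subclass used only to illustrate registration.
template <typename Dtype>
class MySolver : public SGDSolver<Dtype> {
 public:
  explicit MySolver(const SolverParameter& param) : SGDSolver<Dtype>(param) {}
};

// Expands to Creator_MySolver (returning a raw Solver<Dtype>*) and static
// SolverRegisterer registrations that add it under the key "My".
REGISTER_SOLVER_CLASS(My);

}  // namespace caffe

A solver prototxt with type: "My" would then reach MySolver through SolverRegistry<Dtype>::CreateSolver.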

include/caffe/util/upgrade_proto.hpp

Lines changed: 12 additions & 0 deletions
@@ -59,6 +59,18 @@ void ReadNetParamsFromTextFileOrDie(const string& param_file,
 void ReadNetParamsFromBinaryFileOrDie(const string& param_file,
     NetParameter* param);
 
+// Return true iff the solver contains any old solver_type specified as an enum
+bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param);
+
+bool UpgradeSolverType(SolverParameter* solver_param);
+
+// Check for deprecations and upgrade the SolverParameter as needed.
+bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param);
+
+// Read parameters from a file into a SolverParameter proto message.
+void ReadSolverParamsFromTextFileOrDie(const string& param_file,
+    SolverParameter* param);
+
 }  // namespace caffe
 
 #endif  // CAFFE_UTIL_UPGRADE_PROTO_H_
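The intended call order of the first two helpers is check, then upgrade, both on an in-memory message; the file-based pair wraps this up for callers. A short sketch mirroring the new unit test later in this commit (assumes the generated SolverParameter protobuf accessors):

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/upgrade_proto.hpp"

void UpgradeInPlaceExample() {
  caffe::SolverParameter param;
  param.set_solver_type(caffe::SolverParameter_SolverType_RMSPROP);
  if (caffe::SolverNeedsTypeUpgrade(param)) {  // true: legacy enum field set
    caffe::UpgradeSolverType(&param);          // rewrites the field in place
  }
  // Now param.type() == "RMSProp" and param.has_solver_type() is false.
}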

matlab/+caffe/private/caffe_.cpp

Lines changed: 4 additions & 1 deletion
@@ -188,7 +188,10 @@ static void get_solver(MEX_ARGS) {
       "Usage: caffe_('get_solver', solver_file)");
   char* solver_file = mxArrayToString(prhs[0]);
   mxCHECK_FILE_EXIST(solver_file);
-  shared_ptr<Solver<float> > solver(new caffe::SGDSolver<float>(solver_file));
+  SolverParameter solver_param;
+  ReadSolverParamsFromTextFileOrDie(solver_file, &solver_param);
+  shared_ptr<Solver<float> > solver(
+      SolverRegistry<float>::CreateSolver(solver_param));
   solvers_.push_back(solver);
   plhs[0] = ptr_to_handle<Solver<float> >(solver.get());
   mxFree(solver_file);

python/caffe/_caffe.cpp

Lines changed: 2 additions & 2 deletions
@@ -134,8 +134,8 @@ void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
 
 Solver<Dtype>* GetSolverFromFile(const string& filename) {
   SolverParameter param;
-  ReadProtoFromTextFileOrDie(filename, &param);
-  return GetSolver<Dtype>(param);
+  ReadSolverParamsFromTextFileOrDie(filename, &param);
+  return SolverRegistry<Dtype>::CreateSolver(param);
 }
 
 struct NdarrayConverterGenerator {

src/caffe/solver.cpp

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
     : net_(), callbacks_(), root_solver_(root_solver),
       requested_early_exit_(false) {
   SolverParameter param;
-  ReadProtoFromTextFileOrDie(param_file, &param);
+  ReadSolverParamsFromTextFileOrDie(param_file, &param);
   Init(param);
 }
 

src/caffe/test/test_upgrade_proto.cpp

Lines changed: 61 additions & 0 deletions
@@ -2928,4 +2928,65 @@ TEST_F(NetUpgradeTest, TestUpgradeV1LayerType) {
   }
 }
 #endif  // USE_OPENCV
+
+class SolverTypeUpgradeTest : public ::testing::Test {
+ protected:
+  void RunSolverTypeUpgradeTest(
+      const string& input_param_string, const string& output_param_string) {
+    // Test upgrading old solver_type field (enum) to new type field (string)
+    SolverParameter input_param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(
+        input_param_string, &input_param));
+    SolverParameter expected_output_param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(
+        output_param_string, &expected_output_param));
+    SolverParameter actual_output_param = input_param;
+    UpgradeSolverType(&actual_output_param);
+    EXPECT_EQ(expected_output_param.DebugString(),
+        actual_output_param.DebugString());
+  }
+};
+
+TEST_F(SolverTypeUpgradeTest, TestSimple) {
+  const char* old_type_vec[6] = { "SGD", "ADAGRAD", "NESTEROV", "RMSPROP",
+      "ADADELTA", "ADAM" };
+  const char* new_type_vec[6] = { "SGD", "AdaGrad", "Nesterov", "RMSProp",
+      "AdaDelta", "Adam" };
+  for (int i = 0; i < 6; ++i) {
+    const string& input_proto =
+        "net: 'examples/mnist/lenet_train_test.prototxt' "
+        "test_iter: 100 "
+        "test_interval: 500 "
+        "base_lr: 0.01 "
+        "momentum: 0.0 "
+        "weight_decay: 0.0005 "
+        "lr_policy: 'inv' "
+        "gamma: 0.0001 "
+        "power: 0.75 "
+        "display: 100 "
+        "max_iter: 10000 "
+        "snapshot: 5000 "
+        "snapshot_prefix: 'examples/mnist/lenet_rmsprop' "
+        "solver_mode: GPU "
+        "solver_type: " + std::string(old_type_vec[i]) + " ";
+    const string& expected_output_proto =
+        "net: 'examples/mnist/lenet_train_test.prototxt' "
+        "test_iter: 100 "
+        "test_interval: 500 "
+        "base_lr: 0.01 "
+        "momentum: 0.0 "
+        "weight_decay: 0.0005 "
+        "lr_policy: 'inv' "
+        "gamma: 0.0001 "
+        "power: 0.75 "
+        "display: 100 "
+        "max_iter: 10000 "
+        "snapshot: 5000 "
+        "snapshot_prefix: 'examples/mnist/lenet_rmsprop' "
+        "solver_mode: GPU "
+        "type: '" + std::string(new_type_vec[i]) + "' ";
+    this->RunSolverTypeUpgradeTest(input_proto, expected_output_proto);
+  }
+}
+
 }  // NOLINT(readability/fn_size)  // namespace caffe

src/caffe/util/upgrade_proto.cpp

Lines changed: 74 additions & 0 deletions
@@ -937,4 +937,78 @@ void ReadNetParamsFromBinaryFileOrDie(const string& param_file,
   UpgradeNetAsNeeded(param_file, param);
 }
 
+// Return true iff the solver contains any old solver_type specified as an enum
+bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) {
+  if (solver_param.has_solver_type()) {
+    return true;
+  }
+  return false;
+}
+
+bool UpgradeSolverType(SolverParameter* solver_param) {
+  CHECK(!solver_param->has_solver_type() || !solver_param->has_type())
+      << "Failed to upgrade solver: old solver_type field (enum) and new type "
+      << "field (string) cannot be both specified in solver proto text.";
+  if (solver_param->has_solver_type()) {
+    string type;
+    switch (solver_param->solver_type()) {
+    case SolverParameter_SolverType_SGD:
+      type = "SGD";
+      break;
+    case SolverParameter_SolverType_NESTEROV:
+      type = "Nesterov";
+      break;
+    case SolverParameter_SolverType_ADAGRAD:
+      type = "AdaGrad";
+      break;
+    case SolverParameter_SolverType_RMSPROP:
+      type = "RMSProp";
+      break;
+    case SolverParameter_SolverType_ADADELTA:
+      type = "AdaDelta";
+      break;
+    case SolverParameter_SolverType_ADAM:
+      type = "Adam";
+      break;
+    default:
+      LOG(FATAL) << "Unknown SolverParameter solver_type: "
+                 << solver_param->solver_type();
+    }
+    solver_param->set_type(type);
+    solver_param->clear_solver_type();
+  } else {
+    LOG(ERROR) << "Warning: solver type already up to date.";
+    return false;
+  }
+  return true;
+}
+
+// Check for deprecations and upgrade the SolverParameter as needed.
+bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
+  bool success = true;
+  // Try to upgrade old style solver_type enum fields into new string type
+  if (SolverNeedsTypeUpgrade(*param)) {
+    LOG(INFO) << "Attempting to upgrade input file specified using deprecated "
+              << "'solver_type' field (enum): " << param_file;
+    if (!UpgradeSolverType(param)) {
+      success = false;
+      LOG(ERROR) << "Warning: had one or more problems upgrading "
+                 << "SolverType (see above).";
+    } else {
+      LOG(INFO) << "Successfully upgraded file specified using deprecated "
+                << "'solver_type' field (enum) to 'type' field (string).";
+      LOG(WARNING) << "Note that future Caffe releases will only support "
+                   << "'type' field (string) for a solver's type.";
+    }
+  }
+  return success;
+}
+
+// Read parameters from a file into a SolverParameter proto message.
+void ReadSolverParamsFromTextFileOrDie(const string& param_file,
+    SolverParameter* param) {
+  CHECK(ReadProtoFromTextFile(param_file, param))
+      << "Failed to parse SolverParameter file: " << param_file;
+  UpgradeSolverAsNeeded(param_file, param);
+}
+
 }  // namespace caffe
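End to end, ReadSolverParamsFromTextFileOrDie makes the upgrade invisible to callers; tools/caffe.cpp below switches to it with a one-line change. A hedged sketch of the resulting flow (the prototxt path is illustrative, and the explicit delete reflects that the registry now returns a caller-owned raw pointer):

#include "caffe/caffe.hpp"  // now also pulls in util/upgrade_proto.hpp

int main() {
  caffe::SolverParameter param;
  // A legacy file containing "solver_type: ADADELTA" parses and is
  // upgraded in place to type: "AdaDelta", with an INFO log message.
  caffe::ReadSolverParamsFromTextFileOrDie("legacy_solver.prototxt", &param);
  caffe::Solver<float>* solver =
      caffe::SolverRegistry<float>::CreateSolver(param);
  solver->Solve();
  delete solver;
  return 0;
}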

tools/caffe.cpp

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ int train() {
       "but not both.";
 
   caffe::SolverParameter solver_param;
-  caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param);
+  caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);
 
   // If the gpus flag is not provided, allow the mode and device to be set
   // in the solver prototxt.
