32
32
#include " lbann/operators/elementwise_operator.hpp"
33
33
#include " lbann/utils/cloneable.hpp"
34
34
35
+ #ifdef LBANN_HAS_ONNX
35
36
#include < onnx/onnx-ml.pb.h>
36
- #include < operators.pb.h >
37
+ #endif // LBANN_HAS_ONNX
37
38
38
39
/* * @file
39
40
*
@@ -164,6 +165,9 @@ LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterConstant,
164
165
inline onnx::NodeProto get_constant_node (float val)
165
166
{
166
167
onnx::NodeProto const_node;
168
+ const_node.add_output (" const_val" );
169
+ const_node.set_domain (" " );
170
+ const_node.set_doc_string (" Const value for binary with constant operations" );
167
171
auto * const_val = const_node.add_attribute ();
168
172
const_val->set_name (" value_float" );
169
173
const_val->set_type (onnx::AttributeProto::FLOAT);
@@ -184,79 +188,124 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
184
188
185
189
template <typename T, El::Device D>
186
190
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
187
- ScaleOperator<T, D> const )
191
+ ScaleOperator<T, D> const op )
188
192
{
189
- return {};
193
+ std::vector<onnx::NodeProto> nodes (2UL );
194
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
195
+ nodes.front ().set_op_type (" PostConstant" );
196
+ nodes.back ().set_op_type (" Mul" );
197
+ return nodes;
190
198
}
191
199
192
200
template <typename T, El::Device D>
193
201
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
194
- SubtractConstantOperator<T, D> const )
202
+ SubtractConstantOperator<T, D> const op )
195
203
{
196
- return {};
204
+ std::vector<onnx::NodeProto> nodes (2UL );
205
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
206
+ nodes.front ().set_op_type (" PostConstant" );
207
+ nodes.back ().set_op_type (" Sub" );
208
+ return nodes;
197
209
}
198
210
199
211
template <typename T, El::Device D>
200
212
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
201
- ConstantSubtractOperator<T, D> const )
213
+ ConstantSubtractOperator<T, D> const op )
202
214
{
203
- return {};
215
+ std::vector<onnx::NodeProto> nodes (2UL );
216
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
217
+ nodes.front ().set_op_type (" PreConstant" );
218
+ nodes.back ().set_op_type (" Sub" );
219
+ return nodes;
204
220
}
205
221
206
222
template <typename T, El::Device D>
207
223
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
208
- MaxConstantOperator<T, D> const )
224
+ MaxConstantOperator<T, D> const op )
209
225
{
210
- return {};
226
+ std::vector<onnx::NodeProto> nodes (2UL );
227
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
228
+ nodes.front ().set_op_type (" PreConstant" );
229
+ nodes.back ().set_op_type (" Max" );
230
+ return nodes;
211
231
}
212
232
213
233
template <typename T, El::Device D>
214
234
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
215
- MinConstantOperator<T, D> const )
235
+ MinConstantOperator<T, D> const op )
216
236
{
217
- return {};
237
+ std::vector<onnx::NodeProto> nodes (2UL );
238
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
239
+ nodes.front ().set_op_type (" PreConstant" );
240
+ nodes.back ().set_op_type (" Min" );
241
+ return nodes;
218
242
}
219
243
220
244
template <typename T, El::Device D>
221
245
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
222
- EqualConstantOperator<T, D> const )
246
+ EqualConstantOperator<T, D> const op )
223
247
{
224
- return {};
248
+ std::vector<onnx::NodeProto> nodes (2UL );
249
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
250
+ nodes.front ().set_op_type (" PreConstant" );
251
+ nodes.back ().set_op_type (" Equal" );
252
+ return nodes;
225
253
}
226
254
227
255
template <typename T, El::Device D>
228
256
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
229
- NotEqualConstantOperator<T, D> const )
257
+ NotEqualConstantOperator<T, D> const op )
230
258
{
231
- return {};
259
+ std::vector<onnx::NodeProto> nodes (3UL );
260
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
261
+ nodes.front ().set_op_type (" PreConstant" );
262
+ nodes.at (1 ).set_op_type (" Not" );
263
+ nodes.back ().set_op_type (" Equal" );
264
+ return nodes;
232
265
}
233
266
234
267
template <typename T, El::Device D>
235
268
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
236
- LessConstantOperator<T, D> const )
269
+ LessConstantOperator<T, D> const op )
237
270
{
238
- return {};
271
+ std::vector<onnx::NodeProto> nodes (2UL );
272
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
273
+ nodes.front ().set_op_type (" PostConstant" );
274
+ nodes.back ().set_op_type (" Less" );
275
+ return nodes;
239
276
}
240
277
241
278
template <typename T, El::Device D>
242
279
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
243
- LessEqualConstantOperator<T, D> const )
280
+ LessEqualConstantOperator<T, D> const op )
244
281
{
245
- return {};
282
+ std::vector<onnx::NodeProto> nodes (2UL );
283
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
284
+ nodes.front ().set_op_type (" PostConstant" );
285
+ nodes.back ().set_op_type (" LessOrEqual" );
286
+ return nodes;
246
287
}
247
288
248
289
template <typename T, El::Device D>
249
290
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
250
- GreaterConstantOperator<T, D> const )
291
+ GreaterConstantOperator<T, D> const op )
251
292
{
252
- return {};
293
+ std::vector<onnx::NodeProto> nodes (2UL );
294
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
295
+ nodes.front ().set_op_type (" PreConstant" );
296
+ nodes.back ().set_op_type (" Greater" );
297
+ return nodes;
253
298
}
254
299
255
300
template <typename T, El::Device D>
256
301
std::vector<onnx::NodeProto> get_onnx_nodes_impl (
257
- GreaterEqualConstantOperator<T, D> const )
302
+ GreaterEqualConstantOperator<T, D> const op )
258
303
{
259
- return {};
304
+ std::vector<onnx::NodeProto> nodes (2UL );
305
+ nodes.front () = get_constant_node (El::To<float >(op.get_constant ()));
306
+ nodes.front ().set_op_type (" PreConstant" );
307
+ nodes.back ().set_op_type (" GreaterOrEqual" );
308
+ return nodes;
260
309
}
261
310
262
311
} // namespace lbann
0 commit comments