Skip to content

Commit 1439ec5

Browse files
committed
added fill_onnx_node() to operator layer
1 parent 97ef306 commit 1439ec5

File tree

3 files changed

+135
-23
lines changed

3 files changed

+135
-23
lines changed

include/lbann/layers/operator_layer.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ class OperatorLayer final : public data_type_layer<InputT, OutputT>
8181
data_layout get_data_layout() const final;
8282
El::Device get_device_allocation() const final;
8383

84+
#ifdef LBANN_HAS_ONNX
85+
void fill_onnx_node(onnx::GraphProto& graph) const override;
86+
#endif //LBANN_HAS_ONNX
87+
8488
void fp_compute() final;
8589
void bp_compute() final;
8690

include/lbann/operators/math/binary_with_constant.hpp

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@
3232
#include "lbann/operators/elementwise_operator.hpp"
3333
#include "lbann/utils/cloneable.hpp"
3434

35+
#ifdef LBANN_HAS_ONNX
3536
#include <onnx/onnx-ml.pb.h>
36-
#include <operators.pb.h>
37+
#endif // LBANN_HAS_ONNX
3738

3839
/** @file
3940
*
@@ -164,6 +165,9 @@ LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterConstant,
164165
inline onnx::NodeProto get_constant_node(float val)
165166
{
166167
onnx::NodeProto const_node;
168+
const_node.add_output("const_val");
169+
const_node.set_domain("");
170+
const_node.set_doc_string("Const value for binary with constant operations");
167171
auto* const_val = const_node.add_attribute();
168172
const_val->set_name("value_float");
169173
const_val->set_type(onnx::AttributeProto::FLOAT);
@@ -184,79 +188,124 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
184188

185189
template <typename T, El::Device D>
186190
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
187-
ScaleOperator<T, D> const)
191+
ScaleOperator<T, D> const op)
188192
{
189-
return {};
193+
std::vector<onnx::NodeProto> nodes(2UL);
194+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
195+
nodes.front().set_op_type("PostConstant");
196+
nodes.back().set_op_type("Mul");
197+
return nodes;
190198
}
191199

192200
template <typename T, El::Device D>
193201
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
194-
SubtractConstantOperator<T, D> const)
202+
SubtractConstantOperator<T, D> const op)
195203
{
196-
return {};
204+
std::vector<onnx::NodeProto> nodes(2UL);
205+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
206+
nodes.front().set_op_type("PostConstant");
207+
nodes.back().set_op_type("Sub");
208+
return nodes;
197209
}
198210

199211
template <typename T, El::Device D>
200212
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
201-
ConstantSubtractOperator<T, D> const)
213+
ConstantSubtractOperator<T, D> const op)
202214
{
203-
return {};
215+
std::vector<onnx::NodeProto> nodes(2UL);
216+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
217+
nodes.front().set_op_type("PreConstant");
218+
nodes.back().set_op_type("Sub");
219+
return nodes;
204220
}
205221

206222
template <typename T, El::Device D>
207223
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
208-
MaxConstantOperator<T, D> const)
224+
MaxConstantOperator<T, D> const op)
209225
{
210-
return {};
226+
std::vector<onnx::NodeProto> nodes(2UL);
227+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
228+
nodes.front().set_op_type("PreConstant");
229+
nodes.back().set_op_type("Max");
230+
return nodes;
211231
}
212232

213233
template <typename T, El::Device D>
214234
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
215-
MinConstantOperator<T, D> const)
235+
MinConstantOperator<T, D> const op)
216236
{
217-
return {};
237+
std::vector<onnx::NodeProto> nodes(2UL);
238+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
239+
nodes.front().set_op_type("PreConstant");
240+
nodes.back().set_op_type("Min");
241+
return nodes;
218242
}
219243

220244
template <typename T, El::Device D>
221245
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
222-
EqualConstantOperator<T, D> const)
246+
EqualConstantOperator<T, D> const op)
223247
{
224-
return {};
248+
std::vector<onnx::NodeProto> nodes(2UL);
249+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
250+
nodes.front().set_op_type("PreConstant");
251+
nodes.back().set_op_type("Equal");
252+
return nodes;
225253
}
226254

227255
template <typename T, El::Device D>
228256
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
229-
NotEqualConstantOperator<T, D> const)
257+
NotEqualConstantOperator<T, D> const op)
230258
{
231-
return {};
259+
std::vector<onnx::NodeProto> nodes(3UL);
260+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
261+
nodes.front().set_op_type("PreConstant");
262+
nodes.at(1).set_op_type("Not");
263+
nodes.back().set_op_type("Equal");
264+
return nodes;
232265
}
233266

234267
template <typename T, El::Device D>
235268
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
236-
LessConstantOperator<T, D> const)
269+
LessConstantOperator<T, D> const op)
237270
{
238-
return {};
271+
std::vector<onnx::NodeProto> nodes(2UL);
272+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
273+
nodes.front().set_op_type("PostConstant");
274+
nodes.back().set_op_type("Less");
275+
return nodes;
239276
}
240277

241278
template <typename T, El::Device D>
242279
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
243-
LessEqualConstantOperator<T, D> const)
280+
LessEqualConstantOperator<T, D> const op)
244281
{
245-
return {};
282+
std::vector<onnx::NodeProto> nodes(2UL);
283+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
284+
nodes.front().set_op_type("PostConstant");
285+
nodes.back().set_op_type("LessOrEqual");
286+
return nodes;
246287
}
247288

248289
template <typename T, El::Device D>
249290
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
250-
GreaterConstantOperator<T, D> const)
291+
GreaterConstantOperator<T, D> const op)
251292
{
252-
return {};
293+
std::vector<onnx::NodeProto> nodes(2UL);
294+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
295+
nodes.front().set_op_type("PreConstant");
296+
nodes.back().set_op_type("Greater");
297+
return nodes;
253298
}
254299

255300
template <typename T, El::Device D>
256301
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
257-
GreaterEqualConstantOperator<T, D> const)
302+
GreaterEqualConstantOperator<T, D> const op)
258303
{
259-
return {};
304+
std::vector<onnx::NodeProto> nodes(2UL);
305+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
306+
nodes.front().set_op_type("PreConstant");
307+
nodes.back().set_op_type("GreaterOrEqual");
308+
return nodes;
260309
}
261310

262311
} // namespace lbann

src/layers/operator_layer.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,63 @@ namespace lbann {
4343

4444
#include "lbann/macros/instantiate_device.hpp"
4545

46+
#ifdef LBANN_HAS_ONNX
47+
template <typename T, typename O, data_layout L, El::Device D>
48+
void OperatorLayer<T, O, L, D>::fill_onnx_node(
49+
onnx::GraphProto& graph) const
50+
{
51+
std::vector<onnx::NodeProto> nodes(2UL);
52+
nodes.front().add_attribute()->set_type(onnx::AttributeProto::FLOAT);
53+
nodes.front().add_attribute()->set_f(El::To<float>(5));
54+
nodes.front().set_op_type("PostConstant");
55+
nodes.back().set_op_type("Add");
56+
57+
//OperatorPtr op;
58+
//auto nodes = op->get_onnx_nodes();
59+
const auto* parent = this->get_parent_layers()[0];
60+
61+
auto* const_node = graph.add_node();
62+
*const_node = nodes.front();
63+
64+
auto* node = graph.add_node();
65+
*node = nodes.back();
66+
node->set_name(this->get_name());
67+
node->set_domain("");
68+
node->set_doc_string(this->get_name());
69+
if(const_node->op_type() == "PostConstant")
70+
{
71+
node->add_input(parent->get_name() + "_0");
72+
node->add_input(const_node->output(0));
73+
const_node->set_op_type("Constant");
74+
}
75+
else if(const_node->op_type() == "PreConstant")
76+
{
77+
node->add_input(const_node->output(0));
78+
node->add_input(parent->get_name() + "_0");
79+
const_node->set_op_type("Constant");
80+
}
81+
else
82+
LBANN_ERROR("Unknown onnx op type for constant.");
83+
84+
// Not equal operator
85+
if(nodes.size() == 3)
86+
{
87+
node->add_output("EqualOperator");
88+
auto* not_node = graph.add_node();
89+
not_node->add_input(node->output(0));
90+
not_node->add_output(this->get_child_layers()[0]->get_name() + "_0");
91+
not_node->set_name("Not operator");
92+
not_node->set_op_type("Not");
93+
not_node->set_domain("");
94+
not_node->set_doc_string("Not node for not equal operation.");
95+
}
96+
else if(nodes.size() == 2)
97+
{
98+
node->add_output(this->get_child_layers()[0]->get_name() + "_0");
99+
}
100+
else
101+
LBANN_ERROR("Expected two or three nodes for binary constant operation, received ", nodes.size());
102+
}
103+
#endif // LBANN_HAS_ONNX
104+
46105
} // namespace lbann

0 commit comments

Comments
 (0)