Skip to content

Commit 2999c79

Browse files
graham63benson31
authored andcommitted
added binary operators to fill_onnx_node, refactored binary_with constant for consistency with location of operator in vector for regular binary operators
1 parent 54a2ff5 commit 2999c79

File tree

2 files changed

+75
-67
lines changed

2 files changed

+75
-67
lines changed

include/lbann/operators/math/binary_with_constant.hpp

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
180180
AddConstantOperator<T, D> const op)
181181
{
182182
std::vector<onnx::NodeProto> nodes(2UL);
183-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
184-
nodes.front().set_op_type("PostConstant");
185-
nodes.back().set_op_type("Add");
183+
nodes.front().set_op_type("Add");
184+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
185+
nodes.back().set_op_type("PostConstant");
186186
return nodes;
187187
}
188188

@@ -191,9 +191,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
191191
ScaleOperator<T, D> const op)
192192
{
193193
std::vector<onnx::NodeProto> nodes(2UL);
194-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
195-
nodes.front().set_op_type("PostConstant");
196-
nodes.back().set_op_type("Mul");
194+
nodes.front().set_op_type("Mul");
195+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
196+
nodes.back().set_op_type("PostConstant");
197197
return nodes;
198198
}
199199

@@ -202,9 +202,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
202202
SubtractConstantOperator<T, D> const op)
203203
{
204204
std::vector<onnx::NodeProto> nodes(2UL);
205-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
206-
nodes.front().set_op_type("PostConstant");
207-
nodes.back().set_op_type("Sub");
205+
nodes.front().set_op_type("Sub");
206+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
207+
nodes.back().set_op_type("PostConstant");
208208
return nodes;
209209
}
210210

@@ -213,9 +213,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
213213
ConstantSubtractOperator<T, D> const op)
214214
{
215215
std::vector<onnx::NodeProto> nodes(2UL);
216-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
217-
nodes.front().set_op_type("PreConstant");
218-
nodes.back().set_op_type("Sub");
216+
nodes.front().set_op_type("Sub");
217+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
218+
nodes.back().set_op_type("PreConstant");
219219
return nodes;
220220
}
221221

@@ -224,9 +224,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
224224
MaxConstantOperator<T, D> const op)
225225
{
226226
std::vector<onnx::NodeProto> nodes(2UL);
227-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
228-
nodes.front().set_op_type("PreConstant");
229-
nodes.back().set_op_type("Max");
227+
nodes.front().set_op_type("Max");
228+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
229+
nodes.back().set_op_type("PreConstant");
230230
return nodes;
231231
}
232232

@@ -235,9 +235,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
235235
MinConstantOperator<T, D> const op)
236236
{
237237
std::vector<onnx::NodeProto> nodes(2UL);
238-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
239-
nodes.front().set_op_type("PreConstant");
240-
nodes.back().set_op_type("Min");
238+
nodes.front().set_op_type("Min");
239+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
240+
nodes.back().set_op_type("PreConstant");
241241
return nodes;
242242
}
243243

@@ -246,9 +246,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
246246
EqualConstantOperator<T, D> const op)
247247
{
248248
std::vector<onnx::NodeProto> nodes(2UL);
249-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
250-
nodes.front().set_op_type("PreConstant");
251-
nodes.back().set_op_type("Equal");
249+
nodes.front().set_op_type("Equal");
250+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
251+
nodes.back().set_op_type("PreConstant");
252252
return nodes;
253253
}
254254

@@ -257,10 +257,10 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
257257
NotEqualConstantOperator<T, D> const op)
258258
{
259259
std::vector<onnx::NodeProto> nodes(3UL);
260-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
261-
nodes.front().set_op_type("PreConstant");
260+
nodes.front().set_op_type("Equal");
261+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
262+
nodes.back().set_op_type("PreConstant");
262263
nodes.at(1).set_op_type("Not");
263-
nodes.back().set_op_type("Equal");
264264
return nodes;
265265
}
266266

@@ -269,9 +269,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
269269
LessConstantOperator<T, D> const op)
270270
{
271271
std::vector<onnx::NodeProto> nodes(2UL);
272-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
273-
nodes.front().set_op_type("PostConstant");
274-
nodes.back().set_op_type("Less");
272+
nodes.front().set_op_type("Less");
273+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
274+
nodes.back().set_op_type("PostConstant");
275275
return nodes;
276276
}
277277

@@ -280,9 +280,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
280280
LessEqualConstantOperator<T, D> const op)
281281
{
282282
std::vector<onnx::NodeProto> nodes(2UL);
283-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
284-
nodes.front().set_op_type("PostConstant");
285-
nodes.back().set_op_type("LessOrEqual");
283+
nodes.front().set_op_type("LessOrEqual");
284+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
285+
nodes.back().set_op_type("PostConstant");
286286
return nodes;
287287
}
288288

@@ -291,9 +291,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
291291
GreaterConstantOperator<T, D> const op)
292292
{
293293
std::vector<onnx::NodeProto> nodes(2UL);
294-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
295-
nodes.front().set_op_type("PreConstant");
296-
nodes.back().set_op_type("Greater");
294+
nodes.front().set_op_type("Greater");
295+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
296+
nodes.back().set_op_type("PreConstant");
297297
return nodes;
298298
}
299299

@@ -302,9 +302,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
302302
GreaterEqualConstantOperator<T, D> const op)
303303
{
304304
std::vector<onnx::NodeProto> nodes(2UL);
305-
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
306-
nodes.front().set_op_type("PreConstant");
307-
nodes.back().set_op_type("GreaterOrEqual");
305+
nodes.front().set_op_type("GreaterOrEqual");
306+
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
307+
nodes.back().set_op_type("PreConstant");
308308
return nodes;
309309
}
310310

src/layers/operator_layer.cpp

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -48,57 +48,65 @@ template <typename T, typename O, data_layout L, El::Device D>
4848
void OperatorLayer<T, O, L, D>::fill_onnx_node(
4949
onnx::GraphProto& graph) const
5050
{
51-
std::vector<onnx::NodeProto> nodes(2UL);
52-
nodes.front().add_attribute()->set_type(onnx::AttributeProto::FLOAT);
53-
nodes.front().add_attribute()->set_f(El::To<float>(5));
54-
nodes.front().set_op_type("PostConstant");
55-
nodes.back().set_op_type("Add");
51+
const auto& parents = this->get_parent_layers();
52+
auto nodes = m_ops.front()->get_onnx_nodes();
5653

57-
//OperatorPtr op;
58-
//auto nodes = op->get_onnx_nodes();
59-
const auto* parent = this->get_parent_layers()[0];
54+
auto* op_node = graph.add_node();
55+
*op_node = nodes.front();
6056

61-
auto* const_node = graph.add_node();
62-
*const_node = nodes.front();
57+
op_node->set_name(this->get_name());
58+
op_node->set_domain("");
59+
op_node->set_doc_string(this->get_name());
6360

64-
auto* node = graph.add_node();
65-
*node = nodes.back();
66-
node->set_name(this->get_name());
67-
node->set_domain("");
68-
node->set_doc_string(this->get_name());
69-
if(const_node->op_type() == "PostConstant")
61+
//binary operators
62+
if(nodes.size() == 1)
7063
{
71-
node->add_input(parent->get_name() + "_0");
72-
node->add_input(const_node->output(0));
73-
const_node->set_op_type("Constant");
64+
for(auto* parent : parents)
65+
{
66+
size_t idx = parent->find_child_layer_index(*this);
67+
op_node->add_input(parent->get_name() + "_" + std::to_string(idx));
68+
}
7469
}
75-
else if(const_node->op_type() == "PreConstant")
70+
// Binary w/ constant operators
71+
else if(nodes.size() == 2 || nodes.size() == 3)
7672
{
77-
node->add_input(const_node->output(0));
78-
node->add_input(parent->get_name() + "_0");
73+
auto* const_node = graph.add_node();
74+
*const_node = nodes.back();
75+
if(const_node->op_type() == "PostConstant")
76+
{
77+
op_node->add_input(parents[0]->get_name() + "_0");
78+
op_node->add_input(const_node->output(0));
79+
}
80+
else if(const_node->op_type() == "PreConstant")
81+
{
82+
op_node->add_input(const_node->output(0));
83+
op_node->add_input(parents[0]->get_name() + "_0");
84+
}
85+
else
86+
LBANN_ERROR("Unknown onnx op type for constant.");
87+
7988
const_node->set_op_type("Constant");
8089
}
8190
else
82-
LBANN_ERROR("Unknown onnx op type for constant.");
91+
LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ", nodes.size());
8392

8493
// Not equal operator
8594
if(nodes.size() == 3)
8695
{
87-
node->add_output("EqualOperator");
96+
op_node->add_output("EqualOperator");
8897
auto* not_node = graph.add_node();
89-
not_node->add_input(node->output(0));
90-
not_node->add_output(this->get_child_layers()[0]->get_name() + "_0");
98+
not_node->add_input(op_node->output(0));
9199
not_node->set_name("Not operator");
92100
not_node->set_op_type("Not");
93101
not_node->set_domain("");
94102
not_node->set_doc_string("Not node for not equal operation.");
103+
op_node = not_node;
95104
}
96-
else if(nodes.size() == 2)
97-
{
98-
node->add_output(this->get_child_layers()[0]->get_name() + "_0");
105+
106+
for (auto const* child : this->get_child_layers()) {
107+
auto idx = this->find_child_layer_index(*child);
108+
op_node->add_output(this->get_name() + "_" + std::to_string(idx));
99109
}
100-
else
101-
LBANN_ERROR("Expected two or three nodes for binary constant operation, received ", nodes.size());
102110
}
103111
#endif // LBANN_HAS_ONNX
104112

0 commit comments

Comments
 (0)