Skip to content

Commit 97ef306

Browse files
committed
applied patch to expand operator layer macros
1 parent d267de0 commit 97ef306

File tree

4 files changed

+217
-25
lines changed

4 files changed

+217
-25
lines changed

include/lbann/operators/declare_stateless_op.hpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@
3232

3333
#include <operators.pb.h>
3434

35+
#ifdef LBANN_HAS_ONNX
36+
#define ADD_GET_ONNX_NODES_API() \
37+
std::vector<onnx::NodeProto> get_onnx_nodes() const final \
38+
{ \
39+
return get_onnx_nodes_impl(*this); \
40+
}
41+
#else
42+
#define ADD_GET_ONNX_NODES_API()
43+
#endif // LBANN_HAS_ONNX
44+
3545
// These are all single-type operators.
3646

3747
#define LBANN_DECLARE_STATELESS_OPERATOR(OP_NAME, OP_STRING) \
@@ -53,14 +63,18 @@
5363
OP_NAME##Operator& operator=(OP_NAME##Operator&&) = default; \
5464
OP_NAME##Operator& operator=(OP_NAME##Operator const&) = default; \
5565
~OP_NAME##Operator() = default; \
56-
std::string get_type() const final { return OP_STRING; } \
66+
std::string get_type() const final \
67+
{ \
68+
return OP_STRING; \
69+
} \
5770
template <typename ArchiveT> \
5871
void serialize(ArchiveT& ar) \
5972
{ \
6073
using OperatorType = Operator<DataT, DataT, D>; \
6174
ar(::cereal::make_nvp("Operator", \
6275
::cereal::base_class<OperatorType>(this))); \
6376
} \
77+
ADD_GET_ONNX_NODES_API() \
6478
void fp_compute(std::vector<ConstInputTensorType> const& inputs, \
6579
std::vector<OutputTensorType> const& outputs) const final; \
6680
void bp_compute( \
@@ -73,7 +87,8 @@
7387
{ \
7488
msg.mutable_parameters()->PackFrom(lbann_data::OP_NAME##Operator{}); \
7589
} \
76-
void do_fill_description(description&) const final {} \
90+
void do_fill_description(description&) const final \
91+
{} \
7792
}
7893

7994
#define LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(OP_NAME, OP_STRING) \
@@ -98,14 +113,18 @@
98113
OP_NAME##Operator& operator=(OP_NAME##Operator&&) = default; \
99114
OP_NAME##Operator& operator=(OP_NAME##Operator const&) = default; \
100115
~OP_NAME##Operator() = default; \
101-
std::string get_type() const final { return OP_STRING; } \
116+
std::string get_type() const final \
117+
{ \
118+
return OP_STRING; \
119+
} \
102120
template <typename ArchiveT> \
103121
void serialize(ArchiveT& ar) \
104122
{ \
105123
using OperatorType = ElementwiseOperator<DataT, DataT, D>; \
106124
ar(::cereal::make_nvp("ElementwiseOperator", \
107125
::cereal::base_class<OperatorType>(this))); \
108126
} \
127+
ADD_GET_ONNX_NODES_API() \
109128
\
110129
private: \
111130
void \
@@ -119,7 +138,23 @@
119138
{ \
120139
msg.mutable_parameters()->PackFrom(lbann_data::OP_NAME##Operator{}); \
121140
} \
122-
void do_fill_description(description&) const final {} \
141+
void do_fill_description(description&) const final \
142+
{} \
123143
}
124144

145+
namespace lbann {
146+
147+
#ifdef LBANN_HAS_ONNX
148+
// Overloads of this function are used to implement the functions in
149+
// the macro template above.
150+
template <typename OperatorT>
151+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(OperatorT const& op)
152+
{
153+
// The default assumption is that we don't know how to represent
154+
// this operator in ONNX terms yet.
155+
return {};
156+
}
157+
#endif // LBANN_HAS_ONNX
158+
159+
} // namespace lbann
125160
#endif // LBANN_INCLUDE_LBANN_OPERATORS_DECLARE_STATELESS_OP_HPP_INCLUDED

include/lbann/operators/math/binary.hpp

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,34 +29,57 @@
2929

3030
#include "lbann/operators/declare_stateless_op.hpp"
3131

32+
#ifdef LBANN_HAS_ONNX
33+
#define LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(OP_NAME, \
34+
OP_STRING, \
35+
OP_ONNX_NAME) \
36+
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(OP_NAME, OP_STRING); \
37+
template <typename T, El::Device D> \
38+
std::vector<onnx::NodeProto> get_onnx_nodes_impl( \
39+
OP_NAME##Operator<T, D> const& op) \
40+
{ \
41+
std::vector<onnx::NodeProto> nodes(1UL); \
42+
nodes.front().set_op_type(OP_ONNX_NAME); \
43+
return nodes; \
44+
}
45+
#else
46+
#define LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(OP_NAME, \
47+
OP_STRING, \
48+
OP_ONNX_NAME) \
49+
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(OP_NAME, OP_STRING)
50+
#endif // LBANN_HAS_ONNX
51+
3252
namespace lbann {
3353

3454
// Arithmetic operations
35-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Add, "add");
36-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Subtract, "subtract");
37-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Multiply, "multiply");
38-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Divide, "divide");
39-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Mod, "modulo");
40-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Pow, "power");
55+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Add, "add", "Add")
56+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Subtract, "subtract", "Sub")
57+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Multiply, "multiply", "Mul")
58+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Divide, "divide", "Div")
59+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Mod, "modulo", "Mod")
60+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Pow, "power", "Pow")
4161
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(SafeDivide, "safe divide");
4262
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(SquaredDifference,
4363
"squared difference");
4464

4565
// Comparison operations
46-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Max, "maximum");
47-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Min, "minimum");
48-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Equal, "equal");
66+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Max, "maximum", "Max")
67+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Min, "minimum", "Min")
68+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Equal, "equal", "Equal")
4969
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(NotEqual, "not equal");
50-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Less, "less than");
51-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LessEqual, "less than or equal");
52-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(Greater, "greater than");
53-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(GreaterEqual,
54-
"greater than or equal");
70+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Less, "less than", "Less")
71+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LessEqual,
72+
"less than or equal",
73+
"LessOrEqual")
74+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(Greater, "greater than", "Greater")
75+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(GreaterEqual,
76+
"greater than or equal",
77+
"GreaterOrEqual")
5578

5679
// Logical operations
57-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LogicalAnd, "logical and");
58-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LogicalOr, "logical or");
59-
LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(LogicalXor, "logical xor");
80+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LogicalAnd, "logical and", "And")
81+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LogicalOr, "logical or", "Or")
82+
LBANN_DECLARE_STATELESS_EWISE_ONNX_OP(LogicalXor, "logical xor", "Xor")
6083

6184
} // namespace lbann
6285

include/lbann/operators/math/binary_with_constant.hpp

Lines changed: 119 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "lbann/operators/elementwise_operator.hpp"
3333
#include "lbann/utils/cloneable.hpp"
3434

35+
#include <onnx/onnx-ml.pb.h>
3536
#include <operators.pb.h>
3637

3738
/** @file
@@ -50,6 +51,16 @@
5051

5152
#include <operators.pb.h>
5253

54+
#ifdef LBANN_HAS_ONNX
55+
#define ADD_GET_ONNX_NODES_API() \
56+
std::vector<onnx::NodeProto> get_onnx_nodes() const final \
57+
{ \
58+
return get_onnx_nodes_impl(*this); \
59+
}
60+
#else
61+
#define ADD_GET_ONNX_NODES_API()
62+
#endif // LBANN_HAS_ONNX
63+
5364
// These are all single-type operators.
5465

5566
#define LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(OP_NAME, OP_STRING) \
@@ -76,7 +87,10 @@
7687
OP_NAME##Operator& operator=(OP_NAME##Operator&&) = default; \
7788
OP_NAME##Operator& operator=(OP_NAME##Operator const&) = default; \
7889
~OP_NAME##Operator() = default; \
79-
std::string get_type() const final { return OP_STRING; } \
90+
std::string get_type() const final \
91+
{ \
92+
return OP_STRING; \
93+
} \
8094
template <typename ArchiveT> \
8195
void serialize(ArchiveT& ar) \
8296
{ \
@@ -85,7 +99,11 @@
8599
::cereal::base_class<OperatorType>(this)), \
86100
CEREAL_NVP(m_constant)); \
87101
} \
88-
DataT get_constant() const noexcept { return m_constant; } \
102+
ADD_GET_ONNX_NODES_API() \
103+
DataT get_constant() const noexcept \
104+
{ \
105+
return m_constant; \
106+
} \
89107
\
90108
private: \
91109
void \
@@ -117,7 +135,7 @@ namespace lbann {
117135
// x + c -- treated as commutative.
118136
LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(AddConstant, "add constant");
119137

120-
// x + c -- treated as commutative.
138+
// x * c -- treated as commutative.
121139
LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(Scale, "scale");
122140

123141
// x - C -- yes, could be "plus -C", but so could 7-4 be 7+-4, but
@@ -143,5 +161,103 @@ LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterEqualConstant,
143161
LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(GreaterConstant,
144162
"greater than constant");
145163

164+
inline onnx::NodeProto get_constant_node(float val)
165+
{
166+
onnx::NodeProto const_node;
167+
auto* const_val = const_node.add_attribute();
168+
const_val->set_name("value_float");
169+
const_val->set_type(onnx::AttributeProto::FLOAT);
170+
const_val->set_f(val);
171+
return const_node;
172+
}
173+
174+
template <typename T, El::Device D>
175+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
176+
AddConstantOperator<T, D> const op)
177+
{
178+
std::vector<onnx::NodeProto> nodes(2UL);
179+
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
180+
nodes.front().set_op_type("PostConstant");
181+
nodes.back().set_op_type("Add");
182+
return nodes;
183+
}
184+
185+
template <typename T, El::Device D>
186+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
187+
ScaleOperator<T, D> const)
188+
{
189+
return {};
190+
}
191+
192+
template <typename T, El::Device D>
193+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
194+
SubtractConstantOperator<T, D> const)
195+
{
196+
return {};
197+
}
198+
199+
template <typename T, El::Device D>
200+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
201+
ConstantSubtractOperator<T, D> const)
202+
{
203+
return {};
204+
}
205+
206+
template <typename T, El::Device D>
207+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
208+
MaxConstantOperator<T, D> const)
209+
{
210+
return {};
211+
}
212+
213+
template <typename T, El::Device D>
214+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
215+
MinConstantOperator<T, D> const)
216+
{
217+
return {};
218+
}
219+
220+
template <typename T, El::Device D>
221+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
222+
EqualConstantOperator<T, D> const)
223+
{
224+
return {};
225+
}
226+
227+
template <typename T, El::Device D>
228+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
229+
NotEqualConstantOperator<T, D> const)
230+
{
231+
return {};
232+
}
233+
234+
template <typename T, El::Device D>
235+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
236+
LessConstantOperator<T, D> const)
237+
{
238+
return {};
239+
}
240+
241+
template <typename T, El::Device D>
242+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
243+
LessEqualConstantOperator<T, D> const)
244+
{
245+
return {};
246+
}
247+
248+
template <typename T, El::Device D>
249+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
250+
GreaterConstantOperator<T, D> const)
251+
{
252+
return {};
253+
}
254+
255+
template <typename T, El::Device D>
256+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
257+
GreaterEqualConstantOperator<T, D> const)
258+
{
259+
return {};
260+
}
261+
146262
} // namespace lbann
147263
#endif // LBANN_INCLUDE_LBANN_OPERATORS_BINARY_WITH_CONSTANT_HPP_INCLUDED

include/lbann/operators/operator.hpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343

4444
#include <google/protobuf/message.h>
4545

46+
#ifdef LBANN_HAS_ONNX
47+
#include <onnx/onnx_pb.h>
48+
#endif
49+
4650
#include <string>
4751
#include <vector>
4852

@@ -129,6 +133,10 @@ class Operator : public AbstractCloneableBase<Operator<InputT, OutputT, D>>,
129133
template <typename ArchiveT>
130134
void serialize(ArchiveT& ar);
131135

136+
#ifdef LBANN_HAS_ONNX
137+
virtual std::vector<onnx::NodeProto> get_onnx_nodes() const;
138+
#endif
139+
132140
///@}
133141
/** @name Computational interface */
134142
///@{
@@ -163,7 +171,7 @@ class Operator : public AbstractCloneableBase<Operator<InputT, OutputT, D>>,
163171
virtual void set_proto_params(lbann_data::Operator&) const = 0;
164172
/** @brief Concrete operator description. */
165173
virtual void do_fill_description(Description&) const = 0;
166-
};
174+
}; // class Operator
167175

168176
template <typename InputT, typename OutputT, El::Device D>
169177
void Operator<InputT, OutputT, D>::write_proto(
@@ -207,5 +215,15 @@ template <typename ArchiveT>
207215
void Operator<InputT, OutputT, D>::serialize(ArchiveT& ar)
208216
{}
209217

218+
#ifdef LBANN_HAS_ONNX
219+
template <typename InputT, typename OutputT, El::Device D>
220+
std::vector<onnx::NodeProto> Operator<InputT, OutputT, D>::get_onnx_nodes() const
221+
{
222+
// The default assumption is that we don't know how to represent
223+
// this operator in ONNX terms yet.
224+
return {};
225+
}
226+
#endif
227+
210228
} // namespace lbann
211229
#endif // LBANN_OPERATORS_OPERATOR_HPP_INCLUDED

0 commit comments

Comments
 (0)