Skip to content

Commit ac9dc3f

Browse files
authored
[geometry] Add Shape::Visit (#21181)
Code that needs to operate on concrete Shape subclasses can now do so without needing to implement an entire Reifier (though that remains a fine and well-supported option when suitable).
1 parent 097d39f commit ac9dc3f

File tree

6 files changed

+203
-145
lines changed

6 files changed

+203
-145
lines changed

geometry/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ drake_cc_library(
397397
],
398398
deps = [
399399
"//common:nice_type_name",
400+
"//common:overloaded",
400401
"//geometry/proximity:make_convex_hull_mesh_impl",
401402
"//geometry/proximity:meshing_utilities",
402403
"//geometry/proximity:obj_to_surface_mesh",
@@ -1014,6 +1015,7 @@ drake_cc_googletest(
10141015
deps = [
10151016
":shape_specification",
10161017
"//common:find_resource",
1018+
"//common:overloaded",
10171019
"//common/test_utilities",
10181020
],
10191021
)

geometry/shape_specification.cc

Lines changed: 43 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "drake/common/drake_throw.h"
1111
#include "drake/common/nice_type_name.h"
12+
#include "drake/common/overloaded.h"
1213
#include "drake/geometry/proximity/make_convex_hull_mesh_impl.h"
1314
#include "drake/geometry/proximity/meshing_utilities.h"
1415
#include "drake/geometry/proximity/obj_to_surface_mesh.h"
@@ -322,68 +323,56 @@ double CalcMeshVolumeFromFile(const MeshType& mesh) {
322323
return internal::CalcEnclosedVolume(surface_mesh);
323324
}
324325

325-
class CalcVolumeReifier final : public ShapeReifier {
326-
public:
327-
CalcVolumeReifier() = default;
328-
329-
using ShapeReifier::ImplementGeometry;
330-
331-
void ImplementGeometry(const Box& box, void*) final {
332-
volume_ = box.width() * box.depth() * box.height();
333-
}
334-
void ImplementGeometry(const Capsule& capsule, void*) final {
335-
volume_ = M_PI * std::pow(capsule.radius(), 2) * capsule.length() +
336-
4.0 / 3.0 * M_PI * std::pow(capsule.radius(), 3);
337-
}
338-
void ImplementGeometry(const Convex& mesh, void*) {
339-
volume_ = CalcMeshVolumeFromFile(mesh);
340-
}
341-
void ImplementGeometry(const Cylinder& cylinder, void*) final {
342-
volume_ = M_PI * std::pow(cylinder.radius(), 2) * cylinder.length();
343-
}
344-
void ImplementGeometry(const Ellipsoid& ellipsoid, void*) final {
345-
volume_ = 4.0 / 3.0 * M_PI * ellipsoid.a() * ellipsoid.b() * ellipsoid.c();
346-
}
347-
void ImplementGeometry(const HalfSpace&, void*) final {
348-
volume_ = std::numeric_limits<double>::infinity();
349-
}
350-
void ImplementGeometry(const Mesh& mesh, void*) {
351-
volume_ = CalcMeshVolumeFromFile(mesh);
352-
}
353-
void ImplementGeometry(const MeshcatCone& cone, void*) final {
354-
volume_ = 1.0 / 3.0 * M_PI * cone.a() * cone.b() * cone.height();
355-
}
356-
void ImplementGeometry(const Sphere& sphere, void*) final {
357-
volume_ = 4.0 / 3.0 * M_PI * std::pow(sphere.radius(), 3);
358-
}
359-
360-
double volume() const { return volume_; }
361-
362-
private:
363-
double volume_{0.0};
364-
};
365-
366326
} // namespace
367327

368328
double CalcVolume(const Shape& shape) {
369-
CalcVolumeReifier reifier;
370-
shape.Reify(&reifier);
371-
return reifier.volume();
329+
return shape.Visit<double>(overloaded{
330+
[](const Box& box) {
331+
return box.width() * box.depth() * box.height();
332+
},
333+
[](const Capsule& capsule) {
334+
return M_PI * std::pow(capsule.radius(), 2) * capsule.length() +
335+
4.0 / 3.0 * M_PI * std::pow(capsule.radius(), 3);
336+
},
337+
[](const Convex& mesh) {
338+
return CalcMeshVolumeFromFile(mesh);
339+
},
340+
[](const Cylinder& cylinder) {
341+
return M_PI * std::pow(cylinder.radius(), 2) * cylinder.length();
342+
},
343+
[](const Ellipsoid& ellipsoid) {
344+
return 4.0 / 3.0 * M_PI * ellipsoid.a() * ellipsoid.b() * ellipsoid.c();
345+
},
346+
[](const HalfSpace&) {
347+
return std::numeric_limits<double>::infinity();
348+
},
349+
[](const Mesh& mesh) {
350+
return CalcMeshVolumeFromFile(mesh);
351+
},
352+
[](const MeshcatCone& cone) {
353+
return 1.0 / 3.0 * M_PI * cone.a() * cone.b() * cone.height();
354+
},
355+
[](const Sphere& sphere) {
356+
return 4.0 / 3.0 * M_PI * std::pow(sphere.radius(), 3);
357+
}});
372358
}
373359

374360
// The NVI function definitions are enough boilerplate to merit a macro to
375361
// implement them, and we might as well toss in the dtor for good measure.
376362

377-
#define DRAKE_DEFINE_SHAPE_SUBCLASS_BOILERPLATE(ShapeType) \
378-
ShapeType::~ShapeType() = default; \
379-
void ShapeType::DoReify(ShapeReifier* shape_reifier, void* user_data) \
380-
const { \
381-
shape_reifier->ImplementGeometry(*this, user_data); \
382-
} \
383-
std::unique_ptr<Shape> ShapeType::DoClone() const { \
384-
return std::unique_ptr<ShapeType>(new ShapeType(*this)); \
385-
} \
386-
std::string_view ShapeType::do_type_name() const { return #ShapeType; }
363+
#define DRAKE_DEFINE_SHAPE_SUBCLASS_BOILERPLATE(ShapeType) \
364+
ShapeType::~ShapeType() = default; \
365+
void ShapeType::DoReify(ShapeReifier* shape_reifier, void* user_data) \
366+
const { \
367+
shape_reifier->ImplementGeometry(*this, user_data); \
368+
} \
369+
std::unique_ptr<Shape> ShapeType::DoClone() const { \
370+
return std::unique_ptr<ShapeType>(new ShapeType(*this)); \
371+
} \
372+
std::string_view ShapeType::do_type_name() const { return #ShapeType; } \
373+
Shape::VariantShapeConstPtr ShapeType::get_variant_this() const { \
374+
return this; \
375+
}
387376

388377
DRAKE_DEFINE_SHAPE_SUBCLASS_BOILERPLATE(Box)
389378
DRAKE_DEFINE_SHAPE_SUBCLASS_BOILERPLATE(Capsule)

geometry/shape_specification.h

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <memory>
44
#include <string>
5+
#include <variant>
56

67
#include "drake/common/drake_copyable.h"
78
#include "drake/common/drake_deprecated.h"
@@ -36,16 +37,15 @@ class Sphere;
3637
//
3738
// When you add a new subclass of Shape to Drake, you must:
3839
//
39-
// 1. Add a virtual function ImplementGeometry() for the new shape in
40+
// 1. Adjust the VariantShapeConstPtr typedef to list the new subclass.
41+
//
42+
// 2. Add a virtual function ImplementGeometry() for the new shape in
4043
// ShapeReifier that invokes the ThrowUnsupportedGeometry method, and add to
4144
// the test for it in shape_specification_test.cc.
4245
//
43-
// 2. Implement ImplementGeometry in derived ShapeReifiers to continue support
44-
// if desired, otherwise ensure unimplemented functions are not hidden in new
45-
// derivations of ShapeReifier with `using`, for example, `using
46-
// ShapeReifier::ImplementGeometry`. Existing subclasses should already have
47-
// this. Otherwise, you might get a runtime error; we do not have an
48-
// automatic way to enforce them at compile time.
46+
// 3. Grep Drake for the line `using ShapeReifier::ImplementGeometry;` (a trick
47+
// that selects a default-throw implementation) and choose which (if any) of
48+
// those reifiers you want to add support for this new shape into.
4949

5050
/** The abstract base class for all shape specifications. Concrete subclasses
5151
exist for specific shapes (e.g., Box, Mesh, etc.).
@@ -76,6 +76,33 @@ class Shape {
7676
/** Returns a string representation of this shape. */
7777
std::string to_string() const { return do_to_string(); }
7878

79+
/** Calls the given `visitor` function with `*this` as the sole argument, but
80+
with `*this` downcast to be the shape's concrete subclass. For example, if
81+
this shape is a %Box then calls `visitor(static_cast<const Box&>(*this))`.
82+
@tparam ReturnType The return type to coerce return values into. When not
83+
`void`, anything returned by the visitor must be implicitly convertible to
84+
this type. When `void`, the return type will be whatever the Vistor's call
85+
operator returns by default.
86+
87+
To see examples of how this is used, you can check the Drake source code,
88+
e.g., check the implementation of CalcVolume() for one example. */
89+
template <typename ReturnType = void, typename Visitor>
90+
decltype(auto) Visit(Visitor&& visitor) const {
91+
if constexpr (std::is_same_v<ReturnType, void>) {
92+
return std::visit(
93+
[&visitor](auto* shape) {
94+
return visitor(*shape);
95+
},
96+
get_variant_this());
97+
} else {
98+
return std::visit(
99+
[&visitor](auto* shape) -> ReturnType {
100+
return visitor(*shape);
101+
},
102+
get_variant_this());
103+
}
104+
}
105+
79106
protected:
80107
/** (Internal use only) Constructor for use by derived classes.
81108
All subclasses of Shape must be marked `final`. */
@@ -102,6 +129,21 @@ class Shape {
102129

103130
/** (Internal use only) NVI for to_string(). */
104131
virtual std::string do_to_string() const = 0;
132+
133+
/** (Internal use only) All concrete subclasses, as const pointers. */
134+
using VariantShapeConstPtr = std::variant< //
135+
const Box*, //
136+
const Capsule*, //
137+
const Convex*, //
138+
const Cylinder*, //
139+
const Ellipsoid*, //
140+
const HalfSpace*, //
141+
const Mesh*, //
142+
const MeshcatCone*, //
143+
const Sphere*>;
144+
145+
/** (Internal use only) NVI-like helper function for Visit(). */
146+
virtual VariantShapeConstPtr get_variant_this() const = 0;
105147
};
106148

107149
/** Definition of a box. The box is centered on the origin of its canonical
@@ -146,6 +188,7 @@ class Box final : public Shape {
146188
std::unique_ptr<Shape> DoClone() const final;
147189
std::string_view do_type_name() const final;
148190
std::string do_to_string() const final;
191+
VariantShapeConstPtr get_variant_this() const final;
149192

150193
Vector3<double> size_;
151194
};
@@ -180,6 +223,7 @@ class Capsule final : public Shape {
180223
std::unique_ptr<Shape> DoClone() const final;
181224
std::string_view do_type_name() const final;
182225
std::string do_to_string() const final;
226+
VariantShapeConstPtr get_variant_this() const final;
183227

184228
double radius_{};
185229
double length_{};
@@ -246,6 +290,7 @@ class Convex final : public Shape {
246290
std::unique_ptr<Shape> DoClone() const final;
247291
std::string_view do_type_name() const final;
248292
std::string do_to_string() const final;
293+
VariantShapeConstPtr get_variant_this() const final;
249294

250295
std::string filename_;
251296
std::string extension_;
@@ -279,6 +324,7 @@ class Cylinder final : public Shape {
279324
std::unique_ptr<Shape> DoClone() const final;
280325
std::string_view do_type_name() const final;
281326
std::string do_to_string() const final;
327+
VariantShapeConstPtr get_variant_this() const final;
282328

283329
double radius_{};
284330
double length_{};
@@ -321,6 +367,7 @@ class Ellipsoid final : public Shape {
321367
std::unique_ptr<Shape> DoClone() const final;
322368
std::string_view do_type_name() const final;
323369
std::string do_to_string() const final;
370+
VariantShapeConstPtr get_variant_this() const final;
324371

325372
Vector3<double> radii_;
326373
};
@@ -361,6 +408,7 @@ class HalfSpace final : public Shape {
361408
std::unique_ptr<Shape> DoClone() const final;
362409
std::string_view do_type_name() const final;
363410
std::string do_to_string() const final;
411+
VariantShapeConstPtr get_variant_this() const final;
364412
};
365413

366414
// TODO(DamrongGuoy): Update documentation when mesh is fully supported (i.e.,
@@ -440,6 +488,7 @@ class Mesh final : public Shape {
440488
std::unique_ptr<Shape> DoClone() const final;
441489
std::string_view do_type_name() const final;
442490
std::string do_to_string() const final;
491+
VariantShapeConstPtr get_variant_this() const final;
443492

444493
// NOTE: Cannot be const to support default copy/move semantics.
445494
std::string filename_;
@@ -488,6 +537,7 @@ class MeshcatCone final : public Shape {
488537
std::unique_ptr<Shape> DoClone() const final;
489538
std::string_view do_type_name() const final;
490539
std::string do_to_string() const final;
540+
VariantShapeConstPtr get_variant_this() const final;
491541

492542
double height_{};
493543
double a_{};
@@ -514,6 +564,7 @@ class Sphere final : public Shape {
514564
std::unique_ptr<Shape> DoClone() const final;
515565
std::string_view do_type_name() const final;
516566
std::string do_to_string() const final;
567+
VariantShapeConstPtr get_variant_this() const final;
517568

518569
double radius_{};
519570
};

geometry/test/shape_specification_test.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include "drake/common/find_resource.h"
99
#include "drake/common/fmt_eigen.h"
10+
#include "drake/common/overloaded.h"
1011
#include "drake/common/test_utilities/eigen_matrix_compare.h"
1112
#include "drake/common/test_utilities/expect_no_throw.h"
1213
#include "drake/common/test_utilities/expect_throws_message.h"
@@ -253,6 +254,49 @@ TEST_F(ReifierTest, CloningShapes) {
253254
ASSERT_TRUE(sphere_made_);
254255
}
255256

257+
GTEST_TEST(VisitTest, ReturnTypeVoid) {
258+
const Box box(1.0, 2.0, 3.0);
259+
box.Visit(overloaded{
260+
[&](const Box& arg) {
261+
EXPECT_EQ(&arg, &box);
262+
},
263+
[](const auto&) {
264+
GTEST_FAIL();
265+
},
266+
});
267+
268+
const Sphere sphere(1.0);
269+
sphere.Visit(overloaded{
270+
[&](const Sphere& arg) {
271+
EXPECT_EQ(&arg, &sphere);
272+
},
273+
[](const auto&) {
274+
GTEST_FAIL();
275+
},
276+
});
277+
}
278+
279+
GTEST_TEST(VisitTest, ReturnTypeConversion) {
280+
const Box box(1.0, 2.0, 3.0);
281+
const Sphere sphere(1.0);
282+
283+
auto get_size = overloaded{
284+
[](const Box& arg) {
285+
return arg.size();
286+
},
287+
[](const Sphere& arg) {
288+
return Vector1d(arg.radius());
289+
},
290+
[](const auto&) -> Eigen::VectorXd {
291+
DRAKE_UNREACHABLE();
292+
},
293+
};
294+
295+
Eigen::VectorXd dims;
296+
dims = box.Visit<Eigen::VectorXd>(get_size);
297+
dims = sphere.Visit<Eigen::VectorXd>(get_size);
298+
}
299+
256300
// Given the pose of a plane and its expected translation and z-axis, confirms
257301
// that the pose conforms to expectations. Also confirms that the rotational
258302
// component is orthonormal.

multibody/tree/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ drake_cc_library(
257257
hdrs = ["geometry_spatial_inertia.h"],
258258
deps = [
259259
":spatial_inertia",
260+
"//common:overloaded",
260261
"//geometry:shape_specification",
261262
"//geometry/proximity:make_mesh_from_vtk",
262263
"//geometry/proximity:obj_to_surface_mesh",

0 commit comments

Comments
 (0)