Skip to content

Commit ac6d2cd

Browse files
saran-tcopybara-github
authored andcommitted
Speed up warning checking in dm_control.
`MjData.warning.number` is a property, and accessing it incurs a roundtrip penalty to pybind11. Instead, we now grab a reference to the returned array once. Also revert to using `np.copyto` to copy warning states -- we used to do this before the migration to the new bindings, and profiling shows that it is indeed faster. PiperOrigin-RevId: 465284823 Change-Id: Idf0b2c26134db9fadd52d52c313ae3363c0da818
1 parent 58f0821 commit ac6d2cd

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

dm_control/mujoco/engine.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -324,11 +324,9 @@ def check_invalid_state(self):
324324
context is nested inside a `suppress_physics_errors` context, in which
325325
case a warning will be logged instead.
326326
"""
327-
self._warnings_before[:] = [w.number for w in self._warnings]
327+
np.copyto(self._warnings_before, self._warnings)
328328
yield
329-
np.greater([w.number for w in self._warnings],
330-
self._warnings_before,
331-
out=self._new_warnings)
329+
np.greater(self._warnings, self._warnings_before, out=self._new_warnings)
332330
if any(self._new_warnings):
333331
warning_names = np.compress(self._new_warnings,
334332
list(mujoco.mjtWarning.__members__))
@@ -379,7 +377,7 @@ def _reload_from_data(self, data):
379377

380378
# Performance optimization: pre-allocate numpy arrays used when checking for
381379
# MuJoCo warnings on each step.
382-
self._warnings = self.data.warning
380+
self._warnings = self.data.warning.number
383381
self._warnings_before = np.empty_like(self._warnings)
384382
self._new_warnings = np.empty(dtype=bool, shape=(len(self._warnings),))
385383

0 commit comments

Comments
 (0)