Skip to content

Commit 6bdc8c6

Browse files
authored
Compose len (#1622)
* Compose len Signed-off-by: Richard Brown <[email protected]> * Compose.flatten() Signed-off-by: Richard Brown <[email protected]>
1 parent a894adc commit 6bdc8c6

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

monai/transforms/compose.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,26 @@ def randomize(self, data: Optional[Any] = None) -> None:
231231
f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning
232232
)
233233

234+
def flatten(self):
235+
"""Return a Composition with a simple list of transforms, as opposed to any nested Compositions.
236+
237+
e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()`
238+
will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`.
239+
240+
"""
241+
new_transforms = []
242+
for t in self.transforms:
243+
if isinstance(t, Compose):
244+
new_transforms += t.flatten().transforms
245+
else:
246+
new_transforms.append(t)
247+
248+
return Compose(new_transforms)
249+
250+
def __len__(self):
251+
"""Return number of transformations."""
252+
return len(self.flatten().transforms)
253+
234254
def __call__(self, input_):
235255
for _transform in self.transforms:
236256
input_ = apply_transform(_transform, input_)

tests/test_compose.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,17 @@ def test_data_loader_2(self):
156156
self.assertAlmostEqual(out_1.cpu().item(), 0.131966779)
157157
set_determinism(None)
158158

159+
def test_flatten_and_len(self):
160+
x = AddChannel()
161+
t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])])
162+
163+
t2 = t1.flatten()
164+
for t in t2.transforms:
165+
self.assertNotIsInstance(t, Compose)
166+
167+
# test len
168+
self.assertEqual(len(t1), 8)
169+
159170

160171
if __name__ == "__main__":
161172
unittest.main()

0 commit comments

Comments
 (0)