From 942e07524c3c6e5d6444f0131ed7fa2e6a7e6822 Mon Sep 17 00:00:00 2001 From: Lyndon Boone Date: Wed, 7 Jul 2021 17:26:37 -0400 Subject: [PATCH 1/7] Added OneOf class Signed-off-by: Lyndon Boone --- monai/transforms/__init__.py | 2 +- monai/transforms/compose.py | 63 +++++++++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index fb1ff25765..b2817c95db 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs -from .compose import Compose +from .compose import Compose, OneOf from .croppad.array import ( BorderPad, BoundingRect, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 7fa8d6600b..25762919e0 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -29,7 +29,7 @@ ) from monai.utils import MAX_SEED, ensure_tuple, get_seed -__all__ = ["Compose"] +__all__ = ["Compose", "OneOf"] class Compose(Randomizable, InvertibleTransform): @@ -168,3 +168,64 @@ def inverse(self, data): for t in reversed(invertible_transforms): data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) return data + + +class OneOf(Compose): + """ + ``OneOf`` provides the ability to radomly choose one transform out of a + list of callables with predfined probabilities for each. + + Args: + transforms: sequence of callables. + weights: probabilities corresponding to each callable in transforms. + Probabilities are normalized to sum to one. + + OneOf inherits from Compose and uses args map_items and unpack_items in + the same way. + """ + + def __init__( + self, + transforms: Optional[Sequence[Callable]] = None, + weights: Optional[Sequence[float]] = None, + map_items: bool = True, + unpack_items: bool = False, + ) -> None: + if transforms is None: + transforms = [] + self.transforms = ensure_tuple(transforms) + if weights is None: + if len(transforms) == 0: + weights = [] + else: + weights = [1.0 / len(transforms)] * len(transforms) + if len(weights) != len(transforms): + raise AssertionError("transforms and weights should be same size if both specified as sequences.") + self.weights = ensure_tuple(self._normalize_probabilities(weights)) + self.map_items = map_items + self.unpack_items = unpack_items + self.set_random_state(seed=get_seed()) + + def _normalize_probabilities(self, weights): + if len(weights) == 0: + return weights + else: + weights = np.array(weights) + if np.any(weights < 0): + raise AssertionError("Probabilities must be greater than or equal to zero.") + if np.all(weights == 0): + raise AssertionError("At least one probability must be greater than zero.") + weights = weights / weights.sum() + return list(weights) + + def __call__(self, input_): + if len(self.transforms) == 0: + return input_ + else: + index = self.R.multinomial(1, self.weights).argmax() + _transform = self.transforms[index] + input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) + return input_ + + def inverse(self, data): + raise NotImplementedError("inverse method not yet implemented for OneOf class.") From 2d535026da16d2a70c30dcfb99c45fdd6355d41d Mon Sep 17 00:00:00 2001 From: Lyndon Boone Date: Mon, 12 Jul 2021 11:55:45 -0400 Subject: [PATCH 2/7] Clean up OneOf constructor Signed-off-by: Lyndon Boone --- monai/transforms/compose.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 25762919e0..0c5f9e4d0d 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -186,25 +186,19 @@ class OneOf(Compose): def __init__( self, - transforms: Optional[Sequence[Callable]] = None, - weights: Optional[Sequence[float]] = None, + transforms: Optional[Union[Sequence[Callable], Callable]] = None, + weights: Optional[Union[Sequence[float], float]] = None, map_items: bool = True, unpack_items: bool = False, ) -> None: - if transforms is None: - transforms = [] - self.transforms = ensure_tuple(transforms) - if weights is None: - if len(transforms) == 0: - weights = [] - else: - weights = [1.0 / len(transforms)] * len(transforms) - if len(weights) != len(transforms): + super().__init__(transforms, map_items, unpack_items) + if len(self.transforms) == 0: + weights = [] + elif weights is None or isinstance(weights, float): + weights = [1.0 / len(self.transforms)] * len(self.transforms) + if len(weights) != len(self.transforms): raise AssertionError("transforms and weights should be same size if both specified as sequences.") self.weights = ensure_tuple(self._normalize_probabilities(weights)) - self.map_items = map_items - self.unpack_items = unpack_items - self.set_random_state(seed=get_seed()) def _normalize_probabilities(self, weights): if len(weights) == 0: From 4089cd460493f00ba51c43e8822908da265bf5fe Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 10 Aug 2021 17:41:23 +0100 Subject: [PATCH 3/7] add flatten, len and unit test Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 16 +++++++++++ tests/test_convert_data_type.py | 51 +++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 tests/test_convert_data_type.py diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 0c5f9e4d0d..7bfa9b03c9 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -223,3 +223,19 @@ def __call__(self, input_): def inverse(self, data): raise NotImplementedError("inverse method not yet implemented for OneOf class.") + + def flatten(self): + transforms = [] + weights = [] + for t, w in zip(self.transforms, self.weights): + # if nested, probability is the current weight multiplied by the nested weights, + # and so on recursively + if isinstance(t, OneOf): + tr = t.flatten() + for t_, w_ in zip(tr.transforms, tr.weights): + transforms.append(t_) + weights.append(w_ * w) + else: + transforms.append(t) + weights.append(w) + return OneOf(transforms, weights, self.map_items, self.unpack_items) diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py new file mode 100644 index 0000000000..06d6b36090 --- /dev/null +++ b/tests/test_convert_data_type.py @@ -0,0 +1,51 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from parameterized import parameterized + +from monai.transforms import OneOf, Transform + + +class X(Transform): + def __call__(self, x): + return x + + +class Y(Transform): + def __call__(self, x): + return x + + +TESTS = [ + ((X(), Y(), X()), (1, 2, 1), (0.25, 0.5, 0.25)), +] + + +class TestOneOf(unittest.TestCase): + @parameterized.expand(TESTS) + def test_one_of(self, transforms, input_weights, expected_weights): + tr = OneOf(transforms, input_weights) + self.assertTupleEqual(tr.weights, expected_weights) + + def test_len_and_flatten(self): + p1 = OneOf((X(), Y()), (1, 3)) # 0.25, 0.75 + p2 = OneOf((Y(), Y()), (2, 2)) # 0.5. 0.5 + p = OneOf((p1, p2, X()), (1, 2, 1)) # 0.25, 0.5, 0.25 + expected_order = (X, Y, Y, Y, X) + expected_weights = (0.25 * 0.25, 0.25 * 0.75, 0.5 * 0.5, 0.5 * 0.5, 0.25) + self.assertEqual(len(p), len(expected_order)) + self.assertTupleEqual(p.flatten().weights, expected_weights) + + +if __name__ == "__main__": + unittest.main() From bcdde1bc634017adafea08cb9b0c6176219aabe9 Mon Sep 17 00:00:00 2001 From: Lyndon Boone Date: Tue, 10 Aug 2021 16:14:45 -0400 Subject: [PATCH 4/7] Added unit tests and inverse method Signed-off-by: Lyndon Boone --- monai/transforms/compose.py | 32 +++++++++++------- tests/test_convert_data_type.py | 58 +++++++++++++++++++++++++++++++-- 2 files changed, 76 insertions(+), 14 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index c385c443db..df737487d4 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -199,6 +199,7 @@ def __init__( if len(weights) != len(self.transforms): raise AssertionError("transforms and weights should be same size if both specified as sequences.") self.weights = ensure_tuple(self._normalize_probabilities(weights)) + self.index = None def _normalize_probabilities(self, weights): if len(weights) == 0: @@ -212,18 +213,6 @@ def _normalize_probabilities(self, weights): weights = weights / weights.sum() return list(weights) - def __call__(self, input_): - if len(self.transforms) == 0: - return input_ - else: - index = self.R.multinomial(1, self.weights).argmax() - _transform = self.transforms[index] - input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) - return input_ - - def inverse(self, data): - raise NotImplementedError("inverse method not yet implemented for OneOf class.") - def flatten(self): transforms = [] weights = [] @@ -239,3 +228,22 @@ def flatten(self): transforms.append(t) weights.append(w) return OneOf(transforms, weights, self.map_items, self.unpack_items) + + def __call__(self, input_): + if len(self.transforms) == 0: + return input_ + else: + index = self.R.multinomial(1, self.weights).argmax() + self.index = index + _transform = self.transforms[index] + input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) + return input_ + + def inverse(self, data): + index = self.index + if index is None: + return data + t = self.transforms[index] + if isinstance(t, InvertibleTransform): + data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) + return data diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py index 06d6b36090..082e94ca1f 100644 --- a/tests/test_convert_data_type.py +++ b/tests/test_convert_data_type.py @@ -13,7 +13,7 @@ from parameterized import parameterized -from monai.transforms import OneOf, Transform +from monai.transforms import InvertibleTransform, OneOf, Randomizable, Transform class X(Transform): @@ -26,6 +26,34 @@ def __call__(self, x): return x +class A(Transform): + def __call__(self, x): + return x + 1 + + +class B(Transform): + def __call__(self, x): + return x + 2 + + +class C(Transform): + def __call__(self, x): + return x + 3 + + +class Inv(InvertibleTransform): + def __call__(self, x): + return x + 1 + + def inverse(self, x): + return x - 1 + + +class NonInv(Randomizable): + def __call__(self, x): + return x + self.R.uniform(-1, 1) + + TESTS = [ ((X(), Y(), X()), (1, 2, 1), (0.25, 0.5, 0.25)), ] @@ -33,10 +61,15 @@ def __call__(self, x): class TestOneOf(unittest.TestCase): @parameterized.expand(TESTS) - def test_one_of(self, transforms, input_weights, expected_weights): + def test_normalize_weights(self, transforms, input_weights, expected_weights): tr = OneOf(transforms, input_weights) self.assertTupleEqual(tr.weights, expected_weights) + def test_no_weights_arg(self): + p = OneOf((X(), Y(), X(), Y())) + expected_weights = (0.25,) * 4 + self.assertTupleEqual(p.weights, expected_weights) + def test_len_and_flatten(self): p1 = OneOf((X(), Y()), (1, 3)) # 0.25, 0.75 p2 = OneOf((Y(), Y()), (2, 2)) # 0.5. 0.5 @@ -46,6 +79,27 @@ def test_len_and_flatten(self): self.assertEqual(len(p), len(expected_order)) self.assertTupleEqual(p.flatten().weights, expected_weights) + def test_inverse(self): + p = OneOf((OneOf((Inv(), NonInv())), Inv(), NonInv())) + for _i in range(20): + out = p(2.0) + inverted = p.inverse(out) + if p.index == 0 and p.transforms[0].index == 0 or p.index == 1: + self.assertEqual(out, 3.0) + self.assertEqual(inverted, 2.0) + else: + self.assertEqual(inverted, out) + + def test_one_of(self): + p = OneOf((A(), B(), C()), (1, 2, 1)) + counts = [0] * 3 + for _i in range(10000): + out = p(1.0) + counts[int(out - 2)] += 1 + self.assertAlmostEqual(counts[0] / 10000, 0.25, delta=1.0) + self.assertAlmostEqual(counts[1] / 10000, 0.50, delta=1.0) + self.assertAlmostEqual(counts[2] / 10000, 0.25, delta=1.0) + if __name__ == "__main__": unittest.main() From 204895f7f30f43b5da92bb009033dced7a6731ba Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 11 Aug 2021 10:39:43 +0100 Subject: [PATCH 5/7] rename test Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/{test_convert_data_type.py => test_one_of.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_convert_data_type.py => test_one_of.py} (100%) diff --git a/tests/test_convert_data_type.py b/tests/test_one_of.py similarity index 100% rename from tests/test_convert_data_type.py rename to tests/test_one_of.py From a4c33a88733d295a912a42a7a88556cbdd2ff4b7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 11 Aug 2021 11:53:39 +0100 Subject: [PATCH 6/7] flatten tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 2 +- tests/test_one_of.py | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index df737487d4..149a7bacad 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -143,7 +143,7 @@ def flatten(self): """ new_transforms = [] for t in self.transforms: - if isinstance(t, Compose): + if isinstance(t, Compose) and not isinstance(t, OneOf): new_transforms += t.flatten().transforms else: new_transforms.append(t) diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 082e94ca1f..79e27948c1 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -14,6 +14,7 @@ from parameterized import parameterized from monai.transforms import InvertibleTransform, OneOf, Randomizable, Transform +from monai.transforms.compose import Compose class X(Transform): @@ -79,9 +80,23 @@ def test_len_and_flatten(self): self.assertEqual(len(p), len(expected_order)) self.assertTupleEqual(p.flatten().weights, expected_weights) + def test_compose_flatten_does_not_affect_one_of(self): + p = Compose([A(), B(), OneOf([C(), Inv(), Compose([X(), Y()])])]) + f = p.flatten() + # in this case the flattened transform should be the same. + + def _match(a, b): + self.assertEqual(type(a), type(b)) + for a_, b_ in zip(a.transforms, b.transforms): + self.assertEqual(type(a_), type(b_)) + if isinstance(a_, (Compose, OneOf)): + _match(a_, b_) + + _match(p, f) + def test_inverse(self): p = OneOf((OneOf((Inv(), NonInv())), Inv(), NonInv())) - for _i in range(20): + for _ in range(20): out = p(2.0) inverted = p.inverse(out) if p.index == 0 and p.transforms[0].index == 0 or p.index == 1: From 77b3bed7e5c6bb65a1e581b02dc1c666786f612d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 12 Aug 2021 13:48:34 +0100 Subject: [PATCH 7/7] add inverse Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 46 +++++++++++----- tests/test_one_of.py | 101 +++++++++++++++++++++++++++++------- 2 files changed, 114 insertions(+), 33 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 149a7bacad..8737abd0fa 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -13,7 +13,7 @@ """ import warnings -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Union import numpy as np @@ -28,6 +28,7 @@ apply_transform, ) from monai.utils import MAX_SEED, ensure_tuple, get_seed +from monai.utils.enums import InverseKeys __all__ = ["Compose", "OneOf"] @@ -199,7 +200,6 @@ def __init__( if len(weights) != len(self.transforms): raise AssertionError("transforms and weights should be same size if both specified as sequences.") self.weights = ensure_tuple(self._normalize_probabilities(weights)) - self.index = None def _normalize_probabilities(self, weights): if len(weights) == 0: @@ -229,21 +229,41 @@ def flatten(self): weights.append(w) return OneOf(transforms, weights, self.map_items, self.unpack_items) - def __call__(self, input_): + def __call__(self, data): if len(self.transforms) == 0: - return input_ + return data else: index = self.R.multinomial(1, self.weights).argmax() - self.index = index _transform = self.transforms[index] - input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) - return input_ + data = apply_transform(_transform, data, self.map_items, self.unpack_items) + # if the data is a mapping (dictionary), append the OneOf transform to the end + if isinstance(data, Mapping): + for key in data.keys(): + if key + InverseKeys.KEY_SUFFIX in data: + self.push_transform(data, key, extra_info={"index": index}) + return data def inverse(self, data): - index = self.index - if index is None: + if len(self.transforms) == 0: return data - t = self.transforms[index] - if isinstance(t, InvertibleTransform): - data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) - return data + if not isinstance(data, Mapping): + raise RuntimeError("Inverse only implemented for Mapping (dictionary) data") + + # loop until we get an index and then break (since they'll all be the same) + index = None + for key in data.keys(): + if key + InverseKeys.KEY_SUFFIX in data: + # get the index of the applied OneOf transform + index = self.get_most_recent_transform(data, key)[InverseKeys.EXTRA_INFO]["index"] + # and then remove the OneOf transform + self.pop_transform(data, key) + if index is None: + raise RuntimeError("No invertible transforms have been applied") + + # if applied transform is not InvertibleTransform, throw error + _transform = self.transforms[index] + if not isinstance(_transform, InvertibleTransform): + raise RuntimeError(f"Applied OneOf transform is not invertible (applied index: {index}).") + + # apply the inverse + return _transform.inverse(data) diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 79e27948c1..d45d0f3f61 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -10,11 +10,14 @@ # limitations under the License. import unittest +from copy import deepcopy from parameterized import parameterized -from monai.transforms import InvertibleTransform, OneOf, Randomizable, Transform +from monai.transforms import InvertibleTransform, OneOf, Transform from monai.transforms.compose import Compose +from monai.transforms.transform import MapTransform +from monai.utils.enums import InverseKeys class X(Transform): @@ -42,23 +45,66 @@ def __call__(self, x): return x + 3 -class Inv(InvertibleTransform): - def __call__(self, x): - return x + 1 +class MapBase(MapTransform): + def __init__(self, keys): + super().__init__(keys) + self.fwd_fn, self.inv_fn = None, None - def inverse(self, x): - return x - 1 + def __call__(self, data): + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + d[key] = self.fwd_fn(d[key]) + return d -class NonInv(Randomizable): - def __call__(self, x): - return x + self.R.uniform(-1, 1) +class NonInv(MapBase): + def __init__(self, keys): + super().__init__(keys) + self.fwd_fn = lambda x: x * 2 + + +class Inv(MapBase, InvertibleTransform): + def __call__(self, data): + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + d[key] = self.fwd_fn(d[key]) + self.push_transform(d, key) + return d + + def inverse(self, data): + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + d[key] = self.inv_fn(d[key]) + self.pop_transform(d, key) + return d + + +class InvA(Inv): + def __init__(self, keys): + super().__init__(keys) + self.fwd_fn = lambda x: x + 1 + self.inv_fn = lambda x: x - 1 + + +class InvB(Inv): + def __init__(self, keys): + super().__init__(keys) + self.fwd_fn = lambda x: x + 100 + self.inv_fn = lambda x: x - 100 TESTS = [ ((X(), Y(), X()), (1, 2, 1), (0.25, 0.5, 0.25)), ] +KEYS = ["x", "y"] +TEST_INVERSES = [ + (OneOf((InvA(KEYS), InvB(KEYS))), True), + (OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True), + (OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True), + (OneOf((NonInv(KEYS), NonInv(KEYS))), False), +] + class TestOneOf(unittest.TestCase): @parameterized.expand(TESTS) @@ -81,7 +127,7 @@ def test_len_and_flatten(self): self.assertTupleEqual(p.flatten().weights, expected_weights) def test_compose_flatten_does_not_affect_one_of(self): - p = Compose([A(), B(), OneOf([C(), Inv(), Compose([X(), Y()])])]) + p = Compose([A(), B(), OneOf([C(), Inv(KEYS), Compose([X(), Y()])])]) f = p.flatten() # in this case the flattened transform should be the same. @@ -94,16 +140,31 @@ def _match(a, b): _match(p, f) - def test_inverse(self): - p = OneOf((OneOf((Inv(), NonInv())), Inv(), NonInv())) - for _ in range(20): - out = p(2.0) - inverted = p.inverse(out) - if p.index == 0 and p.transforms[0].index == 0 or p.index == 1: - self.assertEqual(out, 3.0) - self.assertEqual(inverted, 2.0) - else: - self.assertEqual(inverted, out) + @parameterized.expand(TEST_INVERSES) + def test_inverse(self, transform, should_be_ok): + data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)} + fwd_data = transform(data) + if not should_be_ok: + with self.assertRaises(RuntimeError): + transform.inverse(fwd_data) + return + + for k in KEYS: + t = fwd_data[k + InverseKeys.KEY_SUFFIX][-1] + # make sure the OneOf index was stored + self.assertEqual(t[InverseKeys.CLASS_NAME], OneOf.__name__) + # make sure index exists and is in bounds + self.assertTrue(0 <= t[InverseKeys.EXTRA_INFO]["index"] < len(transform)) + + # call the inverse + fwd_inv_data = transform.inverse(fwd_data) + + for k in KEYS: + # check transform was removed + self.assertTrue(len(fwd_inv_data[k + InverseKeys.KEY_SUFFIX]) < len(fwd_data[k + InverseKeys.KEY_SUFFIX])) + # check data is same as original (and different from forward) + self.assertEqual(fwd_inv_data[k], data[k]) + self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) def test_one_of(self): p = OneOf((A(), B(), C()), (1, 2, 1))