From 07735eacf6d2d7c212786d7d49743523e5a21fbd Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 11 Mar 2025 23:50:05 +0000 Subject: [PATCH 1/3] add prediction type Signed-off-by: Can-Zhao --- monai/networks/schedulers/rectified_flow.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 452160ae0c..522c86b8d1 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -34,9 +34,20 @@ import torch from torch.distributions import LogisticNormal +from monai.utils import StrEnum + from .scheduler import Scheduler +from .ddpm import DDPMPredictionType + +class RFlowPredictionType(StrEnum): + """ + Set of valid prediction type names for the RFlow scheduler's `prediction_type` argument. + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + V_PREDICTION = DDPMPredictionType.V_PREDICTION + def timestep_transform( t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3 ): @@ -143,6 +154,9 @@ def __init__( base_img_size_numel: int = 32 * 32 * 32, spatial_dim: int = 3, ): + # rectified flow only accepts velocity prediction + self.prediction_type = RFlowPredictionType.V_PREDICTION + self.num_train_timesteps = num_train_timesteps self.use_discrete_timesteps = use_discrete_timesteps self.base_img_size_numel = base_img_size_numel From 5d40b5fcc3c77655c8dd77a75580f2a80972d3c4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Mar 2025 23:51:04 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/schedulers/rectified_flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 522c86b8d1..74673d6d04 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -47,7 +47,7 @@ class RFlowPredictionType(StrEnum): v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf """ V_PREDICTION = DDPMPredictionType.V_PREDICTION - + def timestep_transform( t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3 ): @@ -156,7 +156,7 @@ def __init__( ): # rectified flow only accepts velocity prediction self.prediction_type = RFlowPredictionType.V_PREDICTION - + self.num_train_timesteps = num_train_timesteps self.use_discrete_timesteps = use_discrete_timesteps self.base_img_size_numel = base_img_size_numel From 08ea31f55e80fea9f8685accc6c9661268583725 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 12 Mar 2025 00:06:14 +0000 Subject: [PATCH 3/3] reformat Signed-off-by: Can-Zhao --- monai/networks/schedulers/rectified_flow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 74673d6d04..e660a1abb6 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -36,8 +36,8 @@ from monai.utils import StrEnum -from .scheduler import Scheduler from .ddpm import DDPMPredictionType +from .scheduler import Scheduler class RFlowPredictionType(StrEnum): @@ -46,8 +46,10 @@ class RFlowPredictionType(StrEnum): v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf """ + V_PREDICTION = DDPMPredictionType.V_PREDICTION + def timestep_transform( t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3 ):