diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 452160ae0c..e660a1abb6 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -34,9 +34,22 @@ import torch from torch.distributions import LogisticNormal +from monai.utils import StrEnum + +from .ddpm import DDPMPredictionType from .scheduler import Scheduler +class RFlowPredictionType(StrEnum): + """ + Set of valid prediction type names for the RFlow scheduler's `prediction_type` attribute. + + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + V_PREDICTION = DDPMPredictionType.V_PREDICTION + + def timestep_transform( t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3 ): @@ -143,6 +156,9 @@ def __init__( base_img_size_numel: int = 32 * 32 * 32, spatial_dim: int = 3, ): + # rectified flow only accepts velocity prediction + self.prediction_type = RFlowPredictionType.V_PREDICTION + self.num_train_timesteps = num_train_timesteps self.use_discrete_timesteps = use_discrete_timesteps self.base_img_size_numel = base_img_size_numel