Add rectified flow noise scheduler for accelerated diffusion model #8374

Merged
Changes from 45 commits (57 commits in total)

Commits
eef70a7  yiheng-wang-nv  8274 Relax gpu load check (#8282)
f650feb  pooya-mohammadi  bug: Fix PatchMerging duplicate merging (#8285)
5da95c8  yiheng-wang-nv  Fix test load image issue (#8297)
d14b6bf  KumoLiu  Using LocalStore in Zarr v3 (#8299)
e516098  advcu987  8267 fix normalize intensity (#8286)
26ff1b6  KumoLiu  Fix bundle download error from ngc source (#8307)
8f4bdcf  KumoLiu  Fix deprecated usage in zarr (#8313)
106a3c8  yiheng-wang-nv  update pydicom reader to enable gpu load (#8283)
621fc5f  ericspod  Zarr compression tests only with versions before 3.0 (#8319)
3b83a56  Can-Zhao  add rectified flow noise scheduler to monai
dff1a4a  ericspod  Changing utils.py to test_utils.py (#8335)
2c63f5a  garciadias  8185 - Refactor test (#8231)
2016d20  KumoLiu  Recursive Item Mapping for Nested Lists in Compose (#8187)
e8b500b  jamesobutler  Bump min torch to 1.13.1 to mitigate CVE-2022-45907 unsafe usage of e…
749693b  virginiafdez  Inferer modification - save_intermediates clashes with latent shape a…
599f8a9  nkaenzig  Fix `packaging` imports in version comparison logic (#8347)
87a6c4c  nkaenzig  Removed outdated `torch` version checks from transform functions (#8359)
17440c8  bartosz-grabowski  Fix CommonKeys docstring (#8342)
90dd2cc  thibaultdvx  Add Average Precision to metrics (#8089)
ab46efc  garciadias  Solves path problem in test_bundle_trt_export.py (#8357)
a9a7082  garciadias  8354 fix path at test onnx trt export (#8361)
cf9fb59  virginiafdez  Modify ControlNet inferer so that it takes in context when the diffus…
4b4d92c  yiheng-wang-nv  Update monaihosting download method (#8364)
092978c  jamesobutler  Bump torch minimum to mitigate CVE-2024-31580 & CVE-2024-31583 and en…
784b19f  Can-Zhao  add rectified flow for accelerated diffusion model
28c3d68  Can-Zhao  reformat
dc7b8a6  Can-Zhao  reformat
0bbc0dd  Can-Zhao  reformat
c070581  Can-Zhao  reformat
b036450  pre-commit-ci[bot]  [pre-commit.ci] auto fixes from pre-commit.com hooks
81663db  Can-Zhao  add prev_original
c314dbf  Can-Zhao  black
e7bb70d  Can-Zhao  add doc
b24af70  Can-Zhao  add doc
4499780  Can-Zhao  add doc
74e0a9b  Can-Zhao  update doc
fd8d7f5  Can-Zhao  Update autoencoderkl_maisi.py
ecdb812  pre-commit-ci[bot]  [pre-commit.ci] auto fixes from pre-commit.com hooks
9909859  Can-Zhao  Update autoencoderkl_maisi.py
6726747  Can-Zhao  DCO Remediation Commit for Can Zhao <[email protected].…
0ff3034  pre-commit-ci[bot]  [pre-commit.ci] auto fixes from pre-commit.com hooks
454496f  monai-bot  Auto3DSeg algo_template hash update (#8378)
2df4637  Can-Zhao  rm redundant line
e428c38  ericspod  Enable Pytorch 2.6 (#8309)
eaa803f  Can-Zhao  conflict
8555b67  Can-Zhao  make it 2D/3D compartible, rm a outdated comment
14664e8  Can-Zhao  make it 2D/3D compartible, rm a outdated comment
20aa7fd  Can-Zhao  make it 2D/3D compartible, rm a outdated comment
3144c8a  Can-Zhao  make it 2D/3D compartible
0bf0041  Can-Zhao  add more test
acb5a5c  pre-commit-ci[bot]  [pre-commit.ci] auto fixes from pre-commit.com hooks
e320ecc  Can-Zhao  reformat
80a298d  Can-Zhao  reformat
c2e3cb5  Can-Zhao  add more test
40be2a6  Can-Zhao  reformat
b9ceccf  Can-Zhao  reformat
9685e9f  Can-Zhao  reformat
@@ -0,0 +1,296 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# =========================================================================
# Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
# which has the following license:
# https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

from __future__ import annotations

from typing import Union

import numpy as np
import torch
from torch.distributions import LogisticNormal

from .scheduler import Scheduler


def timestep_transform(
    t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3
):
    """
    Applies a transformation to the timestep based on image resolution scaling.

    Args:
        t (torch.Tensor): The original timestep(s).
        input_img_size_numel (torch.Tensor): The input image's size (H * W * D).
        base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel.
        scale (float): Scaling factor for the transformation.
        num_train_timesteps (int): Total number of training timesteps.
        spatial_dim (int): Number of spatial dimensions in the image.

    Returns:
        torch.Tensor: Transformed timestep(s).
    """
    t = t / num_train_timesteps
    ratio_space = (input_img_size_numel / base_img_size_numel).pow(1.0 / spatial_dim)

    ratio = ratio_space * scale
    new_t = ratio * t / (1 + (ratio - 1) * t)

    new_t = new_t * num_train_timesteps
    return new_t

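# Illustrative check of timestep_transform (assuming the defaults above): a
# 64**3 input is 8x the 32**3 base volume, so ratio_space = 8 ** (1 / 3) = 2
# and mid-chain timesteps are pushed toward the noisy end of the schedule:
#
#     >>> t = torch.tensor([250.0, 500.0, 750.0])
#     >>> timestep_transform(t, input_img_size_numel=torch.tensor(64 * 64 * 64))
#     tensor([400.0000, 666.6667, 857.1429])  # approximately
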
class RFlowScheduler(Scheduler):
    """
    A rectified flow scheduler for guiding the diffusion process in a generative model.

    Supports uniform and logit-normal sampling methods, timestep transformation for
    different resolutions, and noise addition during diffusion.

    Args:
        num_train_timesteps (int): Total number of training timesteps.
        use_discrete_timesteps (bool): Whether to use discrete timesteps.
        sample_method (str): Training timestep sampling method ('uniform' or 'logit-normal').
        loc (float): Location parameter for the logit-normal distribution, used only if sample_method='logit-normal'.
        scale (float): Scale parameter for the logit-normal distribution, used only if sample_method='logit-normal'.
        use_timestep_transform (bool): Whether to apply timestep transformation.
            If True, more inference timesteps are placed at the early (noisy) stages for larger image volumes.
        transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
        steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True.
        base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.

    Example:

        .. code-block:: python

            # define a scheduler
            noise_scheduler = RFlowScheduler(
                num_train_timesteps=1000,
                use_discrete_timesteps=True,
                sample_method='logit-normal',
                use_timestep_transform=True,
                base_img_size_numel=32 * 32 * 32
            )

            # during training
            inputs = torch.ones(2, 4, 64, 64, 32)
            noise = torch.randn_like(inputs)
            timesteps = noise_scheduler.sample_timesteps(inputs)
            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
            predicted_velocity = diffusion_unet(
                x=noisy_inputs,
                timesteps=timesteps
            )
            loss = loss_l1(predicted_velocity, (inputs - noise))

            # during inference
            noisy_inputs = torch.randn(2, 4, 64, 64, 32)
            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:]))
            noise_scheduler.set_timesteps(
                num_inference_steps=30, input_img_size_numel=input_img_size_numel
            )
            all_next_timesteps = torch.cat(
                (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
            )
            for t, next_t in tqdm(
                zip(noise_scheduler.timesteps, all_next_timesteps),
                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
            ):
                predicted_velocity = diffusion_unet(
                    x=noisy_inputs,
                    timesteps=t
                )
                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)
            final_output = noisy_inputs
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        use_discrete_timesteps: bool = True,
        sample_method: str = "uniform",
        loc: float = 0.0,
        scale: float = 1.0,
        use_timestep_transform: bool = False,
        transform_scale: float = 1.0,
        steps_offset: int = 0,
        base_img_size_numel: int = 32 * 32 * 32,
    ):
        self.num_train_timesteps = num_train_timesteps
        self.use_discrete_timesteps = use_discrete_timesteps
        self.base_img_size_numel = base_img_size_numel

        # sample method
        if sample_method not in ["uniform", "logit-normal"]:
            raise ValueError(
                f"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal']."
            )
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)

        # timestep transform
        self.use_timestep_transform = use_timestep_transform
        self.transform_scale = transform_scale
        self.steps_offset = steps_offset

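    # Note on the logit-normal branch above (torch.distributions.LogisticNormal):
    # with loc/scale of shape (1,), each draw lies on the 2-simplex, i.e.
    # distribution.sample((n,)) has shape (n, 2) with rows summing to 1, so
    # indexing [:, 0] yields n timepoints strictly inside (0, 1):
    #
    #     >>> dist = LogisticNormal(torch.tensor([0.0]), torch.tensor([1.0]))
    #     >>> dist.sample((4,)).shape
    #     torch.Size([4, 2])
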
    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Add noise to the original samples.

        Args:
            original_samples: original samples
            noise: noise to add to samples
            timesteps: timesteps tensor indicating the timestep to be computed for each sample.

        Returns:
            noisy_samples: sample with added noise
        """
        timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps
        timepoints = 1 - timepoints  # invert: timestep 0 -> 1 (clean), timestep num_train_timesteps -> 0 (pure noise)

        # timepoints: (bsz,); noise: (bsz, channel, *spatial), e.g. (bsz, 4, frame, w, h)

        # expand timepoints to the noise shape
        # note: this expansion assumes 5D input (3D volumes); later commits in this PR make it 2D/3D compatible
        timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
        timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])

        noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise

        return noisy_samples

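    # Sanity check of the interpolation (illustration, assuming the class above):
    # timestep num_train_timesteps maps to timepoint 0 (pure noise) and
    # timestep 0 maps to timepoint 1 (the original sample):
    #
    #     >>> sched = RFlowScheduler(num_train_timesteps=1000)
    #     >>> x0 = torch.ones(1, 1, 8, 8, 8)
    #     >>> eps = torch.randn_like(x0)
    #     >>> torch.equal(sched.add_noise(x0, eps, torch.tensor([1000])), eps)
    #     True
    #     >>> torch.equal(sched.add_noise(x0, eps, torch.tensor([0])), x0)
    #     True
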
    def set_timesteps(
        self,
        num_inference_steps: int,
        device: str | torch.device | None = None,
        input_img_size_numel: int | None = None,
    ) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.
            input_img_size_numel: int, H*W*D of the image, used when self.use_timestep_transform is True.
        """
        if num_inference_steps > self.num_train_timesteps or num_inference_steps < 1:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} should be at least 1, "
                "and cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        # prepare timesteps
        timesteps = [
            (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)
        ]
        if self.use_discrete_timesteps:
            timesteps = [int(round(t)) for t in timesteps]
        if self.use_timestep_transform:
            timesteps = [
                timestep_transform(
                    t,
                    input_img_size_numel=input_img_size_numel,
                    base_img_size_numel=self.base_img_size_numel,
                    num_train_timesteps=self.num_train_timesteps,
                )
                for t in timesteps
            ]
        timesteps_np = np.array(timesteps).astype(np.float16)
        if self.use_discrete_timesteps:
            timesteps_np = timesteps_np.astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps_np).to(device)
        self.timesteps += self.steps_offset

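    # Illustration: with the defaults (discrete timesteps, no transform), five
    # inference steps run from the noisy end of the chain toward zero:
    #
    #     >>> sched = RFlowScheduler(num_train_timesteps=1000)
    #     >>> sched.set_timesteps(num_inference_steps=5)
    #     >>> sched.timesteps
    #     tensor([1000,  800,  600,  400,  200])
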
    def sample_timesteps(self, x_start):
        """
        Randomly samples training timesteps using the chosen sampling method.

        Args:
            x_start (torch.Tensor): The input tensor for sampling.

        Returns:
            torch.Tensor: Sampled timesteps.
        """
        if self.sample_method == "uniform":
            t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps
        elif self.sample_method == "logit-normal":
            t = self.sample_t(x_start) * self.num_train_timesteps

        if self.use_discrete_timesteps:
            t = t.long()

        if self.use_timestep_transform:
            input_img_size_numel = torch.prod(torch.tensor(x_start.shape[-3:]))
            t = timestep_transform(
                t,
                input_img_size_numel=input_img_size_numel,
                base_img_size_numel=self.base_img_size_numel,
                num_train_timesteps=self.num_train_timesteps,
            )

        return t

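    # Illustration: sampled training timesteps follow the chosen method (a 5D
    # batch of eight volumes is assumed):
    #
    #     >>> x = torch.zeros(8, 1, 16, 16, 16)
    #     >>> RFlowScheduler(sample_method="uniform").sample_timesteps(x)
    #     # eight integers drawn uniformly from [0, 1000)
    #     >>> RFlowScheduler(sample_method="logit-normal").sample_timesteps(x)
    #     # eight integers concentrated around the middle of the chain
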
    def step(
        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts the next sample in the diffusion process.

        Args:
            model_output (torch.Tensor): Output from the trained diffusion model.
            timestep (int): Current timestep in the diffusion chain.
            sample (torch.Tensor): Current sample in the process.
            next_timestep (Union[int, None]): Optional next timestep.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Predicted sample at the next timestep and
                the predicted original (fully denoised) sample.
        """
        # Ensure num_inference_steps exists and is a valid integer
        if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int):
            raise AttributeError(
                "num_inference_steps is missing or not an integer in the class. "
                "Please run self.set_timesteps(num_inference_steps, device, input_img_size_numel) to set it."
            )

        v_pred = model_output

        if next_timestep is not None:
            next_timestep = int(next_timestep)
            dt: float = (
                float(timestep - next_timestep) / self.num_train_timesteps
            )  # now next_timestep is guaranteed to be an int
        else:
            dt = (
                1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0
            )  # avoid division by zero

        pred_post_sample = sample + v_pred * dt
        pred_original_sample = sample + v_pred * timestep / self.num_train_timesteps

        return pred_post_sample, pred_original_sample
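
A quick way to sanity-check the scheduler end to end: add_noise interpolates x_t = (1 - t/T) * x0 + (t/T) * noise, and step() integrates the predicted velocity, so a model that outputs the exact velocity x0 - noise recovers the original sample in pred_original_sample in a single call. A minimal sketch, assuming RFlowScheduler as defined in the diff above (the shapes and the oracle velocity are illustrative):

import torch

sched = RFlowScheduler(num_train_timesteps=1000)
sched.set_timesteps(num_inference_steps=10)  # step() requires this to be set

x0 = torch.randn(2, 1, 16, 16, 16)  # clean samples (bsz, C, spatial dims)
noise = torch.randn_like(x0)
t = torch.tensor([600, 600])        # same timestep for both batch items

xt = sched.add_noise(original_samples=x0, noise=noise, timesteps=t)  # 0.4 * x0 + 0.6 * noise
v_exact = x0 - noise                # the oracle velocity target

# one Euler step toward t=540, plus the predicted original sample
pred_next, pred_x0 = sched.step(v_exact, timestep=600, sample=xt, next_timestep=540)
assert torch.allclose(pred_x0, x0, atol=1e-5)  # exact velocity recovers x0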