MeanFlows implementation. #82
base: main
Conversation
Hi @gumran! Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA Signed.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
| "# flow_matching\n", | ||
| "from flow_matching.path.scheduler import CondOTScheduler\n", | ||
| "from flow_matching.path import AffineProbPath\n", | ||
| "from flow_matching.solver import Solver, ODESolver\n", |
These imports seem to be unused.
Yes, the solver imports are unused.
| " x_1 = torch.randn_like(x_0).to(device)\n", | ||
| "\n", | ||
| " # sample two time points from a logit-normal distribution\n", | ||
| " r = torch.sigmoid(torch.randn(x_1.shape[0], device=device) - 0.4)\n", |
Couldn't these x_1.shape[0] references for r and t be replaced with batch_size?
Here, I echoed the original 2d_flow_matching.ipynb notebook, but you're right.
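For reference, a minimal sketch of the suggested change, with hypothetical `batch_size` and `device` values standing in for the notebook's training-loop state:

```python
import torch

batch_size = 256  # hypothetical value
device = "cpu"    # hypothetical value

# sample two time points from a logit-normal distribution (mean -0.4, var 1),
# indexing by batch_size rather than x_1.shape[0]
r = torch.sigmoid(torch.randn(batch_size, device=device) - 0.4)
t = torch.sigmoid(torch.randn(batch_size, device=device) - 0.4)
```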
Thanks @gumran! Added Mingyang (@lambertae) to review.
| " r = torch.sigmoid(torch.randn(x_1.shape[0], device=device) - 0.4)\n", | ||
| " t = torch.sigmoid(torch.randn(x_1.shape[0], device=device) - 0.4) # mean -0.4, var 1\n", | ||
| "\n", | ||
| " # set r = t for 75% of the batch\n", |
I believe the Mean Flows paper says "The ratio of sampling r != t is 75%". In which case, shouldn't r = t 25% of the time?
Oddly, even with this wording, the official JAX implementation of Mean Flows uses 0.75 in the corresponding code: https://github.com/Gsunshine/meanflow/blob/8304ad42c9d955de6c9fb3b5a1f67bffefa632b7/meanflow.py#L72.
But then their PyTorch code uses 1 - 0.75 instead: https://github.com/Gsunshine/py-meanflow/blob/1f6d72d94247c8fdeb89489acca1a8007a6baf6c/meanflow/models/time_sampler.py#L53
Please see Table 1a in the paper.
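For illustration, a minimal sketch of the two readings discussed here, with a hypothetical `batch_size`; the notebook follows the first one:

```python
import torch

batch_size = 256  # hypothetical value
r = torch.rand(batch_size)
t = torch.rand(batch_size)

# Reading used in the notebook: r = t for 75% of the batch, so those
# samples reduce to vanilla flow matching.
frac_equal = 0.75
# Under a literal reading of the paper's wording ("the ratio of sampling
# r != t is 75%"), this would instead be 1 - 0.75 = 0.25.

mask = torch.rand(batch_size) < frac_equal
r = torch.where(mask, t, r)
```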
examples/2d_mean_flow.ipynb
Outdated
| " t = torch.sigmoid(torch.randn(x_1.shape[0], device=device) - 0.4) # mean -0.4, var 1\n", | ||
| "\n", | ||
| " # set r = t for 75% of the batch\n", | ||
| " mask = torch.randperm(batch_size)[int(batch_size * 0.75)]\n", |
Shouldn't this be [:int(batch_size * 0.75)] to permute more than a single (random-selected) batch element? Maybe this is related to the performance issues you noticed in your example.
Sorry for the late response.
Yes, thank you. Just corrected it, and it does visibly improve performance somewhat.
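For reference, a sketch of the corrected line and why the original only touched one element, with a hypothetical `batch_size`:

```python
import torch

batch_size = 256  # hypothetical value
r = torch.rand(batch_size)
t = torch.rand(batch_size)

# Original: torch.randperm(batch_size)[int(batch_size * 0.75)] picks a single
# random index, so r = t for only one batch element.
# Corrected: slice the first 75% of a random permutation so that 75% of the
# batch has r set equal to t.
mask = torch.randperm(batch_size)[: int(batch_size * 0.75)]
r[mask] = t[mask]
```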
A small implementation of MeanFlows based on the original paper.

The notebook is an adjusted version of the 2d_flow_matching.ipynb notebook. We double the hidden layer size, the number of hidden layers in the MLP, and the number of training iterations. The learning rate is reduced to 3e-4.

The training algorithm is based on Algorithm 1 from the paper, using torch.func.jvp to predict the average velocity. We use adaptive loss weighting like the authors do, with p = 1 (see their Table 1d), and we set 75% of r and t pairs to be equal, which corresponds to vanilla flow matching (see Table 1a). We also add gradient clipping.

We use the same convention as in the paper: x_0 is an image and x_1 is noise. To sample using MeanFlows, we subtract from noise the average velocity from 0 to 1 as predicted by the neural net.

The one-step sampling performance shown in the notebook seems to be about as good as it gets with the given MLP architecture. We can further double the above-mentioned hyperparameters, but so far that has been problematic given my compute restrictions.
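For readers of this PR, here is a minimal, self-contained sketch of the training step and one-step sampler described above. It is not the notebook's exact code: the MLP architecture, hidden size, the eps constant in the adaptive weight, and all helper names are assumptions, and enforcing r <= t follows the paper rather than anything stated here.

```python
import torch
import torch.nn as nn

class MeanFlowMLP(nn.Module):
    """Hypothetical stand-in for the notebook's MLP: predicts the average
    velocity u(x, r, t) from the point x and the two time values."""
    def __init__(self, dim=2, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 2, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x, r, t):
        return self.net(torch.cat([x, r[:, None], t[:, None]], dim=-1))

def mean_flow_loss(model, x_0, eps=1e-3, p=1.0):
    """One training step in the spirit of Algorithm 1: x_0 is data, x_1 is noise."""
    batch_size, device = x_0.shape[0], x_0.device
    x_1 = torch.randn_like(x_0)

    # two logit-normal time samples (mean -0.4); enforce r <= t (paper convention)
    r = torch.sigmoid(torch.randn(batch_size, device=device) - 0.4)
    t = torch.sigmoid(torch.randn(batch_size, device=device) - 0.4)
    r, t = torch.minimum(r, t), torch.maximum(r, t)

    # set r = t for 75% of the batch (those samples are vanilla flow matching)
    idx = torch.randperm(batch_size, device=device)[: int(batch_size * 0.75)]
    r[idx] = t[idx]

    # linear path x_t = (1 - t) x_0 + t x_1, instantaneous velocity v = x_1 - x_0
    x_t = (1 - t[:, None]) * x_0 + t[:, None] * x_1
    v = x_1 - x_0

    # u and its total derivative du/dt along (dx/dt, dr/dt, dt/dt) = (v, 0, 1)
    u, dudt = torch.func.jvp(
        model, (x_t, r, t), (v, torch.zeros_like(r), torch.ones_like(t))
    )

    # MeanFlow identity: target = v - (t - r) * du/dt, with stop-gradient
    u_tgt = (v - (t - r)[:, None] * dudt).detach()
    err = ((u - u_tgt) ** 2).mean(dim=-1)

    # adaptive loss weighting with p = 1: w = 1 / (err + eps)^p, detached
    w = 1.0 / (err.detach() + eps) ** p
    return (w * err).mean()

@torch.no_grad()
def sample(model, n, dim=2, device="cpu"):
    # one-step sampling: subtract from noise the average velocity from 0 to 1
    x_1 = torch.randn(n, dim, device=device)
    zeros, ones = torch.zeros(n, device=device), torch.ones(n, device=device)
    return x_1 - model(x_1, zeros, ones)
```

In a real training loop, the gradient clipping mentioned above (e.g. torch.nn.utils.clip_grad_norm_) would be applied between loss.backward() and optimizer.step().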