
Conversation

gumran commented Sep 17, 2025

A small implementation of MeanFlows based on the original paper.

The notebook is an adjusted version of 2d_flow_matching.ipynb. We double the hidden-layer size, the number of hidden layers in the MLP, and the number of training iterations, and reduce the learning rate to 3e-4.

The training algorithm follows Algorithm 1 from the paper, using torch.func.jvp to predict the average velocity. We use adaptive loss weighting as the authors do, with p = 1 (see their Table 1d), and we set r = t for 75% of the (r, t) pairs; for those samples the objective reduces to vanilla flow matching (see Table 1a). We also add gradient clipping.
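
For concreteness, here is a minimal sketch of that training step. The `model(z, r, t)` signature (returning the predicted average velocity), the `eps` constant, and the r <= t ordering are my assumptions for illustration, not copied verbatim from the notebook:

```python
import torch
from torch.func import jvp

def meanflow_loss(model, x_0, eps=1e-3, p=1.0):
    """Loss for one training step, following Algorithm 1 of the MeanFlow paper."""
    batch_size, device = x_0.shape[0], x_0.device
    x_1 = torch.randn_like(x_0)  # paper convention: x_0 is data, x_1 is noise

    # two logit-normal time samples (mean -0.4, var 1), ordered so that r <= t
    r = torch.sigmoid(torch.randn(batch_size, device=device) - 0.4)
    t = torch.sigmoid(torch.randn(batch_size, device=device) - 0.4)
    r, t = torch.minimum(r, t), torch.maximum(r, t)

    # set r = t for 75% of the batch; those samples reduce to flow matching
    idx = torch.randperm(batch_size, device=device)[: int(batch_size * 0.75)]
    r[idx] = t[idx]

    t_col = t.view(-1, 1)
    z = (1 - t_col) * x_0 + t_col * x_1  # point on the straight path
    v = x_1 - x_0                        # instantaneous (conditional) velocity

    # u and its total derivative du/dt along (dz/dt, dr/dt, dt/dt) = (v, 0, 1)
    u, dudt = jvp(lambda z, r, t: model(z, r, t),
                  (z, r, t),
                  (v, torch.zeros_like(r), torch.ones_like(t)))

    # MeanFlow identity: u(z, r, t) = v - (t - r) * du/dt, with a stop-gradient target
    u_tgt = (v - (t - r).view(-1, 1) * dudt).detach()
    err = (u - u_tgt).pow(2).mean(dim=1)

    # adaptive weighting with p = 1: w = 1 / (||err||^2 + eps)^p, detached
    w = (err.detach() + eps).pow(-p)
    return (w * err).mean()
```

In the training loop this would be followed by, e.g., torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) before the optimizer step, matching the gradient clipping mentioned above.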

We use the same convention as the paper: x_0 is an image and x_1 is noise. To sample with MeanFlows, we start from noise and subtract the average velocity from 0 to 1 predicted by the neural net.
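
One-step sampling under that convention looks like this (a sketch; `n_samples` and `dim` are hypothetical names, and `model(z, r, t)` is the assumed signature from above):

```python
import torch

@torch.no_grad()
def sample_one_step(model, n_samples, dim, device):
    # one-step sampling: x_0 = x_1 - u(x_1, r=0, t=1)
    x_1 = torch.randn(n_samples, dim, device=device)  # start from pure noise
    r = torch.zeros(n_samples, device=device)
    t = torch.ones(n_samples, device=device)
    return x_1 - model(x_1, r, t)  # subtract predicted average velocity over [0, 1]
```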

The one-step sampling performance shown in the notebook seems to be about as good as it gets with the given MLP architecture. Doubling the aforementioned hyperparameters again might help further, but so far that has been impractical under my compute constraints.

meta-cla bot commented Sep 17, 2025

Hi @gumran!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

meta-cla bot commented Sep 17, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

meta-cla bot added the cla signed label Sep 17, 2025
"# flow_matching\n",
"from flow_matching.path.scheduler import CondOTScheduler\n",
"from flow_matching.path import AffineProbPath\n",
"from flow_matching.solver import Solver, ODESolver\n",


These imports seem to be unused.

gumran (Author) commented Sep 18, 2025

Yes, the solver imports are unused.

" x_1 = torch.randn_like(x_0).to(device)\n",
"\n",
" # sample two time points from a logit-normal distribution\n",
" r = torch.sigmoid(torch.randn(x_1.shape[0], device=device) - 0.4)\n",


Couldn't these x_1.shape[0] references for r and t be replaced with batch_size?

gumran (Author) replied:

Here, I echoed the original 2d_flow_matching.ipynb notebook, but you're right.
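
For reference, the suggested change would read:

```python
r = torch.sigmoid(torch.randn(batch_size, device=device) - 0.4)
t = torch.sigmoid(torch.randn(batch_size, device=device) - 0.4)  # mean -0.4, var 1
```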

itaigat requested a review from lambertae September 19, 2025 11:22
itaigat (Contributor) commented Sep 19, 2025

Thanks @gumran! Added Mingyang (@lambertae) to review.

" r = torch.sigmoid(torch.randn(x_1.shape[0], device=device) - 0.4)\n",
" t = torch.sigmoid(torch.randn(x_1.shape[0], device=device) - 0.4) # mean -0.4, var 1\n",
"\n",
" # set r = t for 75% of the batch\n",
amorehead commented Oct 21, 2025

I believe the MeanFlow paper says "The ratio of sampling r != t is 75%". In that case, shouldn't r = t hold only 25% of the time?


Oddly, even with this wording, the official JAX implementation of Mean Flows uses 0.75 in the corresponding code: https://github.com/Gsunshine/meanflow/blob/8304ad42c9d955de6c9fb3b5a1f67bffefa632b7/meanflow.py#L72.

gumran (Author) replied:

Please see Table 1a in the paper.

" t = torch.sigmoid(torch.randn(x_1.shape[0], device=device) - 0.4) # mean -0.4, var 1\n",
"\n",
" # set r = t for 75% of the batch\n",
" mask = torch.randperm(batch_size)[int(batch_size * 0.75)]\n",
amorehead commented Oct 21, 2025

Shouldn't this be [:int(batch_size * 0.75)], so that more than a single (randomly selected) batch element is masked? Maybe this is related to the performance issues you noticed in your example.

gumran (Author) replied:

Sorry for the late response.

Yes, thank you. Just corrected it, and it does visibly improve performance somewhat.
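
For anyone following along, the corrected masking looks like this (with r, t, and batch_size as in the notebook):

```python
# select a random 75% of the batch and set r = t there
mask = torch.randperm(batch_size)[: int(batch_size * 0.75)]
r[mask] = t[mask]
```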
