MeanFlow

October 16, 2025

😈 This repository offers an unofficial PyTorch implementation of the paper Mean Flows for One-step Generative Modeling, building upon Just-a-DiT and EzAudio.

💬 Contributions and feedback are very welcome. Feel free to open an issue or pull request if you spot something or have ideas!

🛠️ This codebase is kept as clean and minimal as possible for easier integration into your own projects, so frameworks like Wandb are intentionally excluded.

📢 News

Sorry, I’ve been busy with other projects lately and haven’t updated this repo to support more features.

Recently, rcm released JVP in Triton, which is insane: now you can use Flash Attention + MeanFlow.

Examples

MNIST -- 10k training steps, 1-step sample result:

MNIST

MNIST -- 6k training steps, 1-step CFG (w=2.0) sample result:

MNIST-cfg

CIFAR-10 -- 200k training steps, 1-step CFG (w=2.0) sample result:

CIFAR-10-cfg
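
For readers wondering what "1-step sample" means in practice, here is a minimal sketch of one-step sampling as described in the MeanFlow paper: draw noise at t=1 and subtract the predicted average velocity over the interval from r=0 to t=1. The network `TinyMeanFlowNet`, its argument order `(z, r, t, y)`, and the flat 16-dimensional "image" are hypothetical stand-ins for illustration, not this repository's actual model or API.

```python
import torch
import torch.nn as nn


class TinyMeanFlowNet(nn.Module):
    """Hypothetical toy network predicting the average velocity u(z, r, t | y)."""

    def __init__(self, dim=16, num_classes=10, hidden=64):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, hidden)
        self.net = nn.Sequential(
            nn.Linear(dim + 2 + hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, z, r, t, y):
        # z: (B, dim); r, t: (B,) times in [0, 1]; y: (B,) integer class labels
        h = torch.cat([z, r[:, None], t[:, None], self.label_emb(y)], dim=-1)
        return self.net(h)


@torch.no_grad()
def sample_one_step(model, labels, dim):
    """Map noise (t=1) to data (r=0) with a single network evaluation."""
    batch = labels.shape[0]
    e = torch.randn(batch, dim)   # pure noise at t = 1
    r = torch.zeros(batch)        # start of the interval
    t = torch.ones(batch)         # end of the interval
    return e - model(e, r, t, labels)


model = TinyMeanFlowNet()          # untrained, so the samples are meaningless here
labels = torch.arange(4)           # one sample each for classes 0..3
samples = sample_one_step(model, labels, dim=16)
```

With a guidance-trained model the call stays the same: the guidance is already baked into the predicted velocity, so no second (unconditional) forward pass is needed at sampling time.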

TODO

  • Implement basic training and inference
  • Enable multi-GPU training via 🤗 Accelerate
  • Add support for Classifier-Free Guidance (CFG)
  • Integrate latent image representation support
  • Add tricks like improved CFG mentioned in Appendix

Known Issues (PyTorch)

  • jvp is incompatible with Flash Attention and likely also with Triton, Mamba, and similar libraries.
  • jvp significantly increases GPU memory usage, even when using torch.utils.checkpoint.
  • CFG is implemented implicitly, leading to some limitations:
    • The CFG scale is fixed at training time and cannot be adjusted during inference.
    • Negative prompts, such as "noise" or "low quality" commonly used in text-to-image diffusion models, are not supported.
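
To make the issues above more concrete, the following is a rough sketch of a MeanFlow training loss with guidance folded into the target, following my reading of the paper rather than this repository's code. It uses `torch.func.jvp` to get the total time derivative of the network output, which is the step that clashes with fused attention kernels and costs extra memory. The guidance scale `omega` enters the regression target itself, which is why it cannot be changed after training. `TinyNet`, the one-hot "null" label, and uniform sampling of `(r, t)` are simplifying assumptions; label dropout, which trains the unconditional branch, is omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import jvp

NUM_CLASSES = 10  # index NUM_CLASSES acts as a hypothetical "null" (unconditional) label


class TinyNet(nn.Module):
    """Hypothetical toy network u(z, r, t | y) on flat vectors."""

    def __init__(self, dim=8, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 2 + NUM_CLASSES + 1, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, z, r, t, y_onehot):
        return self.net(torch.cat([z, r, t, y_onehot], dim=-1))


def meanflow_loss(model, x, y, omega=2.0):
    b = x.shape[0]
    # Assumption: uniform times with r <= t (the paper uses other samplers).
    a, c = torch.rand(b, 1), torch.rand(b, 1)
    t, r = torch.maximum(a, c), torch.minimum(a, c)

    e = torch.randn_like(x)
    z = (1 - t) * x + t * e   # linear interpolation between data and noise
    v = e - x                 # conditional (instantaneous) velocity

    y_cond = F.one_hot(y, NUM_CLASSES + 1).float()
    y_null = F.one_hot(torch.full_like(y, NUM_CLASSES), NUM_CLASSES + 1).float()

    # Guided velocity: omega is fixed here, inside the training target.
    with torch.no_grad():
        u_uncond = model(z, t, t, y_null)
    v_tilde = omega * v + (1 - omega) * u_uncond

    # One JVP call returns u and du/dt = v_tilde * d_z u + d_t u.
    fn = lambda z_, r_, t_: model(z_, r_, t_, y_cond)
    tangents = (v_tilde, torch.zeros_like(r), torch.ones_like(t))
    u, dudt = jvp(fn, (z, r, t), tangents)

    # MeanFlow identity, with stop-gradient on the target.
    u_tgt = (v_tilde - (t - r) * dudt).detach()
    return ((u - u_tgt) ** 2).mean()


torch.manual_seed(0)
model = TinyNet()
x = torch.randn(4, 8)
y = torch.randint(0, NUM_CLASSES, (4,))
loss = meanflow_loss(model, x, y, omega=2.0)
loss.backward()
```

Supporting an adjustable scale or negative prompts at inference would require a different formulation, for example conditioning the network on `omega`, which this sketch does not attempt.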

🌟 Like This Project?

If you find this repo helpful or interesting, consider dropping a ⭐. It really helps and means a lot!