We provide an environment file for installation via pip:

```shell
pip install -r requirements.txt
```

The pre-trained checkpoints can be downloaded through the links below:
| Name | Specification | Link |
|---|---|---|
| 512 Base | Dual-diffusion pretrained, can do generation and caption | model |
| 512 SFT | SFT on LLaVA data, can do generation and vqa | model |
After downloading the checkpoints, see the Jupyter notebook `notebooks/demo.ipynb` for example usage.

Minimal working example:

```python
import torch

from sd3_modules.dual_diff_pipeline import DualDiffSD3Pipeline

dual_diff_pipe = DualDiffSD3Pipeline.from_pretrained(
    "./pretrained_models/dual_diff_sd3_512_base",
    torch_dtype=torch.bfloat16,
).to("cuda")

imgs = dual_diff_pipe(
    prompt="A gourmet hamburger set on a rustic wooden table. The burger is made with a perfectly grilled, juicy beef patty topped with melted gourmet cheese, crispy bacon, fresh lettuce, ripe tomatoes, and caramelized onions.",
    height=512,
    width=512,
    num_images_per_prompt=1,
)
```

We support two kinds of image-text data: wrapped webdataset (`.tar`) data and unwrapped data. For unwrapped data, a JSON file storing data information is needed. The metadata should be a list of dictionaries like the one below:
```json
[
  {
    "image_path": "images/img1.jpg",
    "ratio": 1.33,
    "height": 600,
    "width": 800,
    "caption": "A sunny day in the park.",
    "re_caption": "A bright, lively park scene."
  },
  {
    "image_path": "images/img2.jpg",
    "ratio": 0.75,
    "height": 400,
    "width": 300,
    "caption": "A night sky full of stars.",
    "re_caption": "The starry night illuminates the scene."
  }
]
```

The following datasets are used in our project:
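Before launching training on unwrapped data, it can be useful to sanity-check the metadata file. The sketch below is illustrative, not part of the codebase; it assumes the convention (consistent with the example above) that `ratio` equals `width / height`, and the helper name `validate_metadata` is hypothetical.

```python
# Keys expected in each unwrapped-data entry (matching the example above).
REQUIRED_KEYS = {"image_path", "ratio", "height", "width", "caption", "re_caption"}

def validate_metadata(entries, tol=0.01):
    """Sanity-check unwrapped-data metadata: every entry has the expected
    keys and `ratio` agrees with width / height (assumed convention)."""
    for i, entry in enumerate(entries):
        missing = REQUIRED_KEYS - entry.keys()
        if missing:
            raise ValueError(f"entry {i} is missing keys: {sorted(missing)}")
        expected = entry["width"] / entry["height"]
        if abs(entry["ratio"] - expected) > tol:
            raise ValueError(f"entry {i}: ratio {entry['ratio']} != {expected:.2f}")
    return True

# Example entry, taken from the metadata format shown above.
sample = [
    {"image_path": "images/img1.jpg", "ratio": 1.33, "height": 600,
     "width": 800, "caption": "A sunny day in the park.",
     "re_caption": "A bright, lively park scene."},
]
assert validate_metadata(sample)
```

In practice one would `json.load` the metadata file and call `validate_metadata` on the resulting list before training.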
| Name | Usage | Link |
|---|---|---|
| Datacomp-recap | Base pretraining | data |
| ShareGPT4V pretrain | T5 embedding alignment, text diffusion training | data |
| LAION aesthetic | Image diffusion training | data |
| MidJourney 1.1M | Image diffusion training | data |
| LLaVA 1.5 | Text SFT | data |
To train the model, an SD3-medium checkpoint is needed, which can be downloaded from here. In addition, we also provide an aligned embedding that corresponds to the "mask token" in T5's vocabulary here.
Example configurations for training are provided under the `configs` directory.
We use 32 H100 GPUs for pretraining and 16 A100 GPUs for SFT.
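As a rough sketch (the serialization format and path are assumptions, not confirmed by the repository), the downloaded mask-token embedding can be inspected with `torch.load` before its path is passed to the SFT script's `--resume_t5` flag:

```python
import torch

# Hypothetical path; substitute the downloaded mask-token embedding file.
t5_mask_emb_pth = "pretrained_models/t5_mask_emb.pth"

def load_mask_embedding(path):
    """Load the aligned T5 mask-token embedding (assumed here to be a
    torch tensor checkpoint) and report its shape before training."""
    emb = torch.load(path, map_location="cpu")
    print(f"mask-token embedding: shape={tuple(emb.shape)}, dtype={emb.dtype}")
    return emb
```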
- Dual-diffusion training on image-text data (fill in the torchrun arguments with your machine's settings):

```shell
export OMP_NUM_THREADS=8
precision=bf16

torchrun --nnodes $WORKER_NUM \
    --node_rank $ID \
    --nproc_per_node $WORKER_GPU \
    --master_addr $WORKER_0_HOST \
    --master_port $port \
    train_dual_diffusion_sd3.py \
    --config configs/dual_diff_pretrain.py \
    --results_dir results/ \
    --model_parallel_size 1 \
    --data_parallel h_sdp \
    --precision ${precision} --grad_precision fp32
```

- Supervised fine-tuning (text diffusion with prompt + image diffusion) on a visual-instruction dataset (we used LLaVA 1.5's):
```shell
export OMP_NUM_THREADS=8
precision=bf16

torchrun --nnodes $WORKER_NUM \
    --node_rank $ID \
    --nproc_per_node $WORKER_GPU \
    --master_addr $WORKER_0_HOST \
    --master_port $port \
    train_dual_diffusion_sd3_sft.py \
    --config configs/dual_diff_sft.py \
    --results_dir results/ \
    --model_parallel_size 1 \
    --data_parallel h_sdp \
    --precision ${precision} --grad_precision fp32 \
    --resume_t5 ${t5_mask_emb_pth}
```

The implementation of this project is inspired by the great codebases of PixArt, MDLM, and Lumina-Next.
Our DiT backbone is finetuned from SD3-medium.
If you find this project useful, please kindly consider citing our work:
```bibtex
@misc{li2024dualdiffusionunifiedimage,
    title={Dual Diffusion for Unified Image Generation and Understanding},
    author={Zijie Li and Henry Li and Yichun Shi and Amir Barati Farimani and Yuval Kluger and Linjie Yang and Peng Wang},
    year={2024},
    eprint={2501.00289},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/2501.00289},
}
```