# Fine-Tuning FLUX.1-dev with QLoRA

In our previous post, "[Exploring Quantization Backends in Diffusers](https://huggingface.co/blog/diffusers-quantization)", we dove into how various quantization techniques can shrink diffusion models like FLUX.1-dev, making them significantly more accessible for *inference* without drastically compromising performance. We saw how `bitsandbytes`, `torchao`, and others reduce memory footprints for generating images.

Now, we tackle **efficiently *fine-tuning* these models.** This post will guide you through fine-tuning FLUX.1-dev using QLoRA with the Hugging Face `diffusers` library. We'll showcase results from an NVIDIA RTX 4090.

## Why Not Just Full Fine-Tuning?

`black-forest-labs/FLUX.1-dev`, for instance, requires over 31GB of VRAM in BF16 for inference alone, and training needs considerably more memory on top of that.
DerekLiu35 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||

**Full Fine-Tuning:** This traditional method updates all model weights.
* **Pros:** Potential for the highest task-specific quality.
* **Cons:** For FLUX.1-dev, this would demand immense VRAM (multiple high-end GPUs), putting it out of reach for most individual users.

**LoRA (Low-Rank Adaptation):** LoRA freezes the pre-trained weights and injects small, trainable "adapter" layers.
* **Pros:** Massively reduces trainable parameters, saving VRAM during training and resulting in small adapter checkpoints.
* **Cons (for base model memory):** The full-precision base model still needs to be loaded, which for FLUX.1-dev remains a hefty VRAM requirement even though far fewer parameters are being updated.

**QLoRA: The Efficiency Powerhouse:** [QLoRA](https://arxiv.org/abs/2305.14314) enhances LoRA by:
1. Loading the pre-trained base model in a quantized format (typically 4-bit via `bitsandbytes`), drastically cutting the base model's memory footprint.
2. Training LoRA adapters (usually in FP16/BF16) on top of this quantized base.

This allows fine-tuning of very large models on consumer-grade hardware or more accessible cloud GPUs.

## Dataset

We aimed to fine-tune `black-forest-labs/FLUX.1-dev` to adopt the artistic style of Alphonse Mucha, using a small [dataset](https://huggingface.co/datasets/derekl35/alphonse-mucha-style).
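
If you want to take a look at the data first, it can be loaded like any other Hub dataset. The split name and column layout below are assumptions; check the dataset card for the exact schema:

```python
from datasets import load_dataset

# Pull the Alphonse Mucha style dataset from the Hub and inspect it.
dataset = load_dataset("derekl35/alphonse-mucha-style", split="train")
print(dataset)            # number of rows and column names
print(dataset[0].keys())  # fields of a single example
```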

<!-- (maybe use different dataset) -->

## QLoRA Fine-tuning FLUX.1-dev with `diffusers`

We used a `diffusers` training script (very slightly modified from the [`flux_lora_quantization` research example](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/flux_lora_quantization)) designed for DreamBooth-style LoRA fine-tuning of FLUX models. In this setup only the transformer is trained; the text encoders and VAE stay frozen. Let's examine the crucial parts for QLoRA and memory efficiency:

**Understanding the Key Optimization Techniques:**

**LoRA (Low-Rank Adaptation) Deep Dive:**
LoRA works by decomposing weight updates into low-rank matrices. Instead of updating the full weight matrix $$W$$, LoRA learns two smaller matrices $$A$$ and $$B$$ such that the update is $$\Delta W = BA$$, where $$A \in \mathbb{R}^{r \times k}$$ and $$B \in \mathbb{R}^{d \times r}$$. The rank $$r$$ is typically much smaller than the original dimensions, drastically reducing the number of trainable parameters. The LoRA scaling factor $$\alpha$$ scales the adapter's contribution and is often set equal to $$r$$ or a multiple of it; it balances the influence of the pre-trained model and the LoRA adapter.
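
To make the decomposition concrete, here is a minimal, hand-rolled sketch of the idea (the script itself relies on `peft`'s `LoraConfig`, shown later, rather than anything like this):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal illustration of LoRA on a linear layer (not peft's implementation)."""
    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # freeze the pre-trained weight W

        # A in R^{r x k}, B in R^{d x r}; B starts at zero so the update BA is zero at init
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x W^T + scaling * x A^T B^T, i.e. the effective update is delta W = BA
        return self.base(x) + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)

layer = LoRALinear(nn.Linear(3072, 3072), rank=4)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(f"trainable params in this layer: {trainable:,}")  # 2 * 4 * 3072 = 24,576
```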

**8-bit Optimizer (AdamW):**
The standard AdamW optimizer maintains first- and second-moment estimates for each parameter in FP32, which consumes significant memory. 8-bit AdamW uses block-wise quantization to store the optimizer states in 8-bit precision while maintaining training stability, reducing optimizer memory usage by roughly 75% compared to standard FP32 AdamW.
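
In the script this is enabled through the `use_8bit_adam` option; roughly, it amounts to the following swap (the learning rate and weight decay are illustrative, not necessarily the values we trained with, and `model` stands in for the network being fine-tuned):

```python
import bitsandbytes as bnb

# Only parameters that require gradients (the LoRA weights) are optimized.
params_to_optimize = [p for p in model.parameters() if p.requires_grad]

# Drop-in replacement for torch.optim.AdamW that stores its moment estimates
# in block-wise quantized 8-bit instead of FP32.
optimizer = bnb.optim.AdamW8bit(
    params_to_optimize,
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-4,
)
```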

**Gradient Checkpointing:**
During the forward pass, intermediate activations are normally kept in memory for the backward pass's gradient computation. Gradient checkpointing trades computation for memory by storing only certain "checkpoint" activations and recomputing the rest during backpropagation.
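
In `diffusers` models this is a one-liner, `transformer.enable_gradient_checkpointing()`, which the script calls when the corresponding flag is set. Conceptually it boils down to wrapping blocks with `torch.utils.checkpoint`, roughly like this generic example:

```python
import torch
from torch.utils.checkpoint import checkpoint

# A generic block standing in for one transformer layer.
block = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
)

x = torch.randn(8, 1024, requires_grad=True)

# Instead of keeping the block's intermediate activations, only its inputs are
# stored; the activations are recomputed during the backward pass.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```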

**Cache Latents:**
The script can also pre-encode the training images with the VAE and cache the resulting latents before training starts (the text embeddings are precomputed in the same way). Batches are then drawn from these cached latents instead of re-encoding images on every step, so the VAE and text encoders no longer need to sit in GPU memory during the training loop.
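
A rough sketch of what latent caching amounts to (the variable names are placeholders rather than the script's exact code, and training code typically also applies the VAE's scaling and shift factors):

```python
import torch

# Placeholder names: `vae` is the frozen autoencoder, `train_dataloader` yields
# preprocessed image batches. Each training image is encoded exactly once.
latents_cache = []
with torch.no_grad():
    for batch in train_dataloader:
        pixel_values = batch["pixel_values"].to(vae.device, dtype=vae.dtype)
        latents_cache.append(vae.encode(pixel_values).latent_dist.sample().cpu())

# Once everything is cached, the VAE can be dropped from GPU memory entirely.
del vae
torch.cuda.empty_cache()
```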

**Setting up 4-bit Quantization (`BitsAndBytesConfig`):**

This section demonstrates the QLoRA configuration for the base model:
```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
from peft import prepare_model_for_kbit_training

# Determine compute dtype based on mixed precision
bnb_4bit_compute_dtype = torch.float32
if args.mixed_precision == "fp16":
    bnb_4bit_compute_dtype = torch.float16
elif args.mixed_precision == "bf16":
    bnb_4bit_compute_dtype = torch.bfloat16

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
)

transformer = FluxTransformer2DModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=bnb_4bit_compute_dtype,
)

# Prepare the quantized model for k-bit (QLoRA-style) training
transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
# Gradient checkpointing is enabled later via transformer.enable_gradient_checkpointing() if the arg is set
```

**Defining LoRA Configuration (`LoraConfig`):**
Adapters are added to the quantized transformer:
```python
from peft import LoraConfig

transformer_lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.rank,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # FLUX attention blocks
)
transformer.add_adapter(transformer_lora_config)
```
Only these LoRA parameters become trainable.
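
It's worth printing the parameter counts at this point; `num_parameters` comes from the `diffusers` model class, and the exact figures will depend on the rank and targeted modules:

```python
# Sanity check: with rank 4 on the attention projections, only a tiny
# fraction of the transformer's weights actually receives gradients.
total_params = transformer.num_parameters()
trainable_params = transformer.num_parameters(only_trainable=True)
print(f"trainable: {trainable_params:,} / {total_params:,} "
      f"({100 * trainable_params / total_params:.3f}%)")
```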

### Setup & Results

For this demonstration, we leveraged an NVIDIA RTX 4090 (24GB VRAM) to explore its performance.

**Configuration for RTX 4090:**
On our RTX 4090, we used a `train_batch_size` of 1, `gradient_accumulation_steps` of 4, `mixed_precision="fp16"`, `gradient_checkpointing=True`, `use_8bit_adam=True`, a LoRA `rank` of 4, and a resolution of 512x768. Latents were cached with `cache_latents=True`.

**Memory Footprint (RTX 4090):**
* **QLoRA:** Peak VRAM usage for QLoRA fine-tuning was approximately 9GB.
* **FP16 LoRA:** Running standard LoRA (with the base FLUX.1-dev in FP16) on the same setup consumed about 26GB of VRAM.
* **FP16 full fine-tuning:** Without memory optimizations, we estimate this would require roughly 120GB of VRAM.

**Training Time (RTX 4090):**
Fine-tuning for 700 steps on the Alphonse Mucha dataset took approximately 41 minutes on the RTX 4090 with the configuration above (batch size 1, gradient accumulation 4, 512x768 resolution).

**Output Quality:**
The ultimate measure is the generated art. Here are samples from our QLoRA fine-tuned model on the `derekl35/alphonse-mucha-style` dataset:

Base model:
 

QLoRA fine-tuned:
 

*Prompts (left to right):*

*Serene raven-haired woman, moonlit lilies, swirling botanicals, alphonse mucha style*

*a puppy in a pond, alphonse mucha style*

*Ornate fox with a collar of autumn leaves and berries, amidst a tapestry of forest foliage, alphonse mucha style*

The fine-tuned model nicely captured Mucha's iconic Art Nouveau style, evident in the decorative motifs and distinct color palette. The QLoRA process maintained excellent fidelity while learning the new style.
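
**Using the Trained LoRA for Inference:**
Once training finishes, the adapter can be loaded on top of the base pipeline. The sketch below uses standard `diffusers` APIs; the adapter path is a placeholder for wherever your run saved (or pushed) the LoRA weights, and the sampling settings are just typical FLUX.1-dev values:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("path/to/alphonse-mucha-lora")  # placeholder: local dir or Hub repo id
pipe.enable_model_cpu_offload()  # keeps peak VRAM manageable on consumer GPUs

image = pipe(
    "a puppy in a pond, alphonse mucha style",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("mucha_puppy.png")
```

Whether to merge the adapter is a trade-off: keeping it separate makes it easy to swap between different LoRAs, while `pipe.fuse_lora()` folds the update into the base weights so there is no separate adapter overhead at inference time.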

**Colab Adaptability:**
The same script and settings can be adapted to Google Colab. Since peak VRAM for the QLoRA run stays around 9GB, it fits within the roughly 16GB of a free-tier T4, although training there will be considerably slower than on an RTX 4090; an L4 or A100 runtime is a more comfortable fit.

## Conclusion

QLoRA, coupled with the `diffusers` library, significantly democratizes the ability to customize state-of-the-art models like FLUX.1-dev. As demonstrated on an RTX 4090, efficient fine-tuning is well within reach, yielding high-quality stylistic adaptations. Importantly, these techniques are adaptable, paving the way for users on more constrained hardware, like Google Colab, to also participate. If you train your own adapter, we'd love to see it shared on the Hugging Face Hub.

<!-- [Maybe add a link to trained LoRA adapter on Hugging Face Hub.] -->