Commit b5c8b55

Move IP Adapter Face ID to core (#7186)

* Switch to peft and multi proj layers
* Move Face ID loading and inference to core

Co-authored-by: Sayak Paul <[email protected]>

1 parent e23c27e

10 files changed: +592 −375 lines changed

docs/source/en/using-diffusers/ip_adapter.md

Lines changed: 52 additions & 4 deletions
@@ -362,14 +362,12 @@ IP-Adapter's image prompting and compatibility with other adapters and models ma

### Face model

- Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces:
+ Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces from the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository:

* [ip-adapter-full-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-full-face_sd15.safetensors) is conditioned with images of cropped faces and removed backgrounds
* [ip-adapter-plus-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-plus-face_sd15.safetensors) uses patch embeddings and is conditioned with images of cropped faces

- > [!TIP]
- >
- > [IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) is a face-specific IP-Adapter trained with face ID embeddings instead of CLIP image embeddings, allowing you to generate more consistent faces in different contexts and styles. Try out this popular [community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#ip-adapter-face-id) and see how it compares to the other face IP-Adapters.
+ Additionally, Diffusers supports all IP-Adapter checkpoints trained with face embeddings extracted by `insightface` face models. Supported models are from the [h94/IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) repository.

For face models, use the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) checkpoint. It is also recommended to use [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models.

@@ -411,6 +409,56 @@ image

</div>
</div>

To use IP-Adapter FaceID models, first extract face embeddings with `insightface`. Then pass the list of tensors to the pipeline as `ip_adapter_image_embeds`.

```py
import cv2
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image
from insightface.app import FaceAnalysis

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin", image_encoder_folder=None)
pipeline.set_ip_adapter_scale(0.6)

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")

# Detect the face and extract its normalized ID embedding with insightface
ref_images_embeds = []
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
faces = app.get(image)
image = torch.from_numpy(faces[0].normed_embedding)
ref_images_embeds.append(image.unsqueeze(0))
ref_images_embeds = torch.stack(ref_images_embeds, dim=0).unsqueeze(0)

# Pair the embedding with a zeroed negative embedding for classifier-free guidance
neg_ref_images_embeds = torch.zeros_like(ref_images_embeds)
id_embeds = torch.cat([neg_ref_images_embeds, ref_images_embeds]).to(dtype=torch.float16, device="cuda")

generator = torch.Generator(device="cpu").manual_seed(42)

images = pipeline(
    prompt="A photo of a girl",
    ip_adapter_image_embeds=[id_embeds],
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=20, num_images_per_prompt=1,
    generator=generator
).images
```

Both the IP-Adapter FaceID Plus and Plus v2 models require CLIP image embeddings in addition to the face embeddings. Prepare the face embeddings as shown previously, then extract the CLIP embeddings and pass them to the hidden image projection layers.

```py
# ip_adapter_images: the face image(s) for the CLIP branch; num_images: images generated per prompt
clip_embeds = pipeline.prepare_ip_adapter_image_embeds([ip_adapter_images], None, torch.device("cuda"), num_images, True)[0]

pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = clip_embeds.to(dtype=torch.float16)
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = False  # set to True for the Plus v2 model
```
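The snippet above assumes `ip_adapter_images` and `num_images` are already defined in the surrounding context. A minimal sketch of what they might look like (hypothetical values, reusing the face image from the earlier example):

```py
from diffusers.utils import load_image

# Hypothetical definitions for the snippet above (not part of the original example):
# the face image(s) fed to the CLIP branch, and the number of images per prompt.
ip_adapter_images = [load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")]
num_images = 1
```
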
### Multi IP-Adapter
More than one IP-Adapter can be used at the same time to generate specific images in more diverse styles. For example, you can use IP-Adapter-Face to generate consistent faces and characters, and IP-Adapter Plus to generate those faces in a specific style.
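As a sketch of that pattern (checkpoint names taken from the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository; treat it as a starting point rather than the guide's own example), multiple adapters are loaded by passing lists to `load_ip_adapter` and `set_ip_adapter_scale`:

```py
import torch
from diffusers import AutoPipelineForText2Image
from transformers import CLIPVisionModelWithProjection

# The ViT-H image encoder is shared by both "vit-h" adapter checkpoints
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
).to("cuda")

# One weight name per adapter; the scales are matched to the adapters positionally
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"],
)
pipeline.set_ip_adapter_scale([0.7, 0.3])
```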

docs/source/en/using-diffusers/loading_adapters.md

Lines changed: 37 additions & 0 deletions
@@ -320,3 +320,40 @@ pipeline = AutoPipelineForText2Image.from_pretrained(

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
```

### IP-Adapter Face ID models

The IP-Adapter FaceID models are experimental IP Adapters that use image embeddings generated by `insightface` instead of CLIP image embeddings. Some of these models also use LoRA to improve ID consistency.
You need to install `insightface` and all its requirements to use these models.

<Tip warning={true}>
As InsightFace pretrained models are available for non-commercial research purposes, IP-Adapter-FaceID models are released exclusively for research purposes and are not intended for commercial use.
</Tip>

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sdxl.bin", image_encoder_folder=None)
```
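
Once loaded, generation follows the same pattern as the FaceID example in the IP-Adapter guide. A minimal sketch, assuming `id_embeds` holds the stacked negative/positive `insightface` embeddings prepared as shown there:

```py
pipeline.set_ip_adapter_scale(0.6)

# id_embeds: insightface face embeddings prepared as in the IP-Adapter guide (assumed here)
image = pipeline(
    prompt="a photo of a woman",
    ip_adapter_image_embeds=[id_embeds],
    num_inference_steps=30,
).images[0]
```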

If you want to use one of the two IP-Adapter FaceID Plus models, you must also load the CLIP image encoder, as these models use both `insightface` and CLIP image embeddings to achieve better photorealism.

```py
import torch
from diffusers import AutoPipelineForText2Image
from transformers import CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    image_encoder=image_encoder,
    torch_dtype=torch.float16
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid-plus_sd15.bin")
```

At inference, the CLIP embeddings must also be passed to the hidden image projection layers, as shown in the IP-Adapter guide.

examples/community/README.md

Lines changed: 0 additions & 2 deletions
@@ -3819,12 +3819,10 @@ export_to_gif(frames, "animation.gif")

IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
- You have to disable PEFT BACKEND in order to load weights.
You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).

```py
import diffusers
- diffusers.utils.USE_PEFT_BACKEND = False
import torch
from diffusers.utils import load_image
import cv2
```
