diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md
index dc64b2548529..e3c4178a1507 100644
--- a/docs/source/en/using-diffusers/ip_adapter.md
+++ b/docs/source/en/using-diffusers/ip_adapter.md
@@ -640,3 +640,87 @@ image
+
+### Style & layout control
+
+[InstantStyle](https://arxiv.org/abs/2404.02733) is a plug-and-play method on top of IP-Adapter, which disentangles style and layout from image prompt to control image generation. This is achieved by only inserting IP-Adapters to some specific part of the model.
+
+By default IP-Adapters are inserted to all layers of the model. Use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method with a dictionary to assign scales to IP-Adapter at different layers.
+
+```py
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import load_image
+import torch
+
+pipeline = AutoPipelineForImage2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+
+scale = {
+ "down": {"block_2": [0.0, 1.0]},
+ "up": {"block_0": [0.0, 1.0, 0.0]},
+}
+pipeline.set_ip_adapter_scale(scale)
+```
+
+This will activate IP-Adapter at the second layer in the model's down-part block 2 and up-part block 0. The former is the layer where IP-Adapter injects layout information and the latter injects style. Inserting IP-Adapter to these two layers you can generate images following the style and layout of image prompt, but with contents more aligned to text prompt.
+
+```py
+style_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")
+
+generator = torch.Generator(device="cpu").manual_seed(42)
+image = pipeline(
+ prompt="a cat, masterpiece, best quality, high quality",
+ image=style_image,
+ negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
+ guidance_scale=5,
+ num_inference_steps=30,
+ generator=generator,
+).images[0]
+image
+```
+
+
+
+

+
IP-Adapter image
+
+
+

+
generated image
+
+
+
+In contrast, inserting IP-Adapter to all layers will often generate images that overly focus on image prompt and diminish diversity.
+
+Activate IP-Adapter only in the style layer and then call the pipeline again.
+
+```py
+scale = {
+ "up": {"block_0": [0.0, 1.0, 0.0]},
+}
+pipeline.set_ip_adapter_scale(scale)
+
+generator = torch.Generator(device="cpu").manual_seed(42)
+image = pipeline(
+ prompt="a cat, masterpiece, best quality, high quality",
+ image=style_image,
+ negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
+ guidance_scale=5,
+ num_inference_steps=30,
+ generator=generator,
+).images[0]
+image
+```
+
+
+
+

+
IP-Adapter only in style layer
+
+
+

+
IP-Adapter in all layers
+
+
+
+Note that you don't have to specify all layers in the dictionary. Those not included in the dictionary will be set to scale 0 which means disable IP-Adapter by default.
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index fdddc382212f..cb158a4bc194 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -28,6 +28,7 @@
is_transformers_available,
logging,
)
+from .unet_loader_utils import _maybe_expand_lora_scales
if is_transformers_available():
@@ -243,25 +244,55 @@ def load_ip_adapter(
def set_ip_adapter_scale(self, scale):
"""
- Sets the conditioning scale between text and image.
+ Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+ granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
Example:
```py
- pipeline.set_ip_adapter_scale(0.5)
+ # To use original IP-Adapter
+ scale = 1.0
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style block only
+ scale = {
+ "up": {"block_0": [0.0, 1.0, 0.0]},
+ }
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style+layout blocks
+ scale = {
+ "down": {"block_2": [0.0, 1.0]},
+ "up": {"block_0": [0.0, 1.0, 0.0]},
+ }
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style and layout from 2 reference images
+ scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
+ pipeline.set_ip_adapter_scale(scales)
```
"""
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
- for attn_processor in unet.attn_processors.values():
+ if not isinstance(scale, list):
+ scale = [scale]
+ scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
+
+ for attn_name, attn_processor in unet.attn_processors.items():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
- if not isinstance(scale, list):
- scale = [scale] * len(attn_processor.scale)
- if len(attn_processor.scale) != len(scale):
+ if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
- f"`scale` should be a list of same length as the number if ip-adapters "
- f"Expected {len(attn_processor.scale)} but got {len(scale)}."
+ f"Cannot assign {len(scale_configs)} scale_configs to "
+ f"{len(attn_processor.scale)} IP-Adapter."
)
- attn_processor.scale = scale
+ elif len(scale_configs) == 1:
+ scale_configs = scale_configs * len(attn_processor.scale)
+ for i, scale_config in enumerate(scale_configs):
+ if isinstance(scale_config, dict):
+ for k, s in scale_config.items():
+ if attn_name.startswith(k):
+ attn_processor.scale[i] = s
+ else:
+ attn_processor.scale[i] = scale_config
def unload_ip_adapter(self):
"""
diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet_loader_utils.py
index 3ee4a96fad0a..8f202ed4d44b 100644
--- a/src/diffusers/loaders/unet_loader_utils.py
+++ b/src/diffusers/loaders/unet_loader_utils.py
@@ -38,7 +38,9 @@ def _translate_into_actual_layer_name(name):
return ".".join((updown, block, attn))
-def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]]):
+def _maybe_expand_lora_scales(
+ unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
+):
blocks_with_transformer = {
"down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
"up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
@@ -47,7 +49,11 @@ def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[
expanded_weight_scales = [
_maybe_expand_lora_scales_for_one_adapter(
- weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict()
+ weight_for_adapter,
+ blocks_with_transformer,
+ transformer_per_block,
+ unet.state_dict(),
+ default_scale=default_scale,
)
for weight_for_adapter in weight_scales
]
@@ -60,6 +66,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
state_dict: None,
+ default_scale: float = 1.0,
):
"""
Expands the inputs into a more granular dictionary. See the example below for more details.
@@ -108,21 +115,36 @@ def _maybe_expand_lora_scales_for_one_adapter(
scales = copy.deepcopy(scales)
if "mid" not in scales:
- scales["mid"] = 1
+ scales["mid"] = default_scale
+ elif isinstance(scales["mid"], list):
+ if len(scales["mid"]) == 1:
+ scales["mid"] = scales["mid"][0]
+ else:
+ raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")
for updown in ["up", "down"]:
if updown not in scales:
- scales[updown] = 1
+ scales[updown] = default_scale
# eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
if not isinstance(scales[updown], dict):
- scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}
+ scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}
- # eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}}
+ # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
+ # set not assigned blocks to default scale
+ if block not in scales[updown]:
+ scales[updown][block] = default_scale
if not isinstance(scales[updown][block], list):
scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
+ elif len(scales[updown][block]) == 1:
+ # a list specifying scale to each masked IP input
+ scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
+ elif len(scales[updown][block]) != transformer_per_block[updown]:
+ raise ValueError(
+ f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
+ )
# eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
for i in blocks_with_transformer[updown]:
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 237e8236caf4..429807989296 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2229,44 +2229,51 @@ def __call__(
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
- if mask is not None:
- if not isinstance(scale, list):
- scale = [scale]
+ skip = False
+ if isinstance(scale, list):
+ if all(s == 0 for s in scale):
+ skip = True
+ elif scale == 0:
+ skip = True
+ if not skip:
+ if mask is not None:
+ if not isinstance(scale, list):
+ scale = [scale] * mask.shape[1]
+
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
- current_num_images = mask.shape[1]
- for i in range(current_num_images):
- ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
- ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
-
- mask_downsample = IPAdapterMaskProcessor.downsample(
- mask[:, i, :, :],
- batch_size,
- _current_ip_hidden_states.shape[1],
- _current_ip_hidden_states.shape[2],
- )
-
- mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-
- hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
- else:
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
+ current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
- ip_key = attn.head_to_batch_dim(ip_key)
- ip_value = attn.head_to_batch_dim(ip_value)
-
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
-
- hidden_states = hidden_states + scale * current_ip_hidden_states
+ hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@@ -2439,57 +2446,64 @@ def __call__(
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
- if mask is not None:
- if not isinstance(scale, list):
- scale = [scale]
+ skip = False
+ if isinstance(scale, list):
+ if all(s == 0 for s in scale):
+ skip = True
+ elif scale == 0:
+ skip = True
+ if not skip:
+ if mask is not None:
+ if not isinstance(scale, list):
+ scale = [scale] * mask.shape[1]
+
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ _current_ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
- current_num_images = mask.shape[1]
- for i in range(current_num_images):
- ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
- ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
- _current_ip_hidden_states = F.scaled_dot_product_attention(
+ current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
- _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+ current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
- _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
-
- mask_downsample = IPAdapterMaskProcessor.downsample(
- mask[:, i, :, :],
- batch_size,
- _current_ip_hidden_states.shape[1],
- _current_ip_hidden_states.shape[2],
- )
-
- mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
- hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
- else:
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- current_ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
- batch_size, -1, attn.heads * head_dim
- )
- current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
- hidden_states = hidden_states + scale * current_ip_hidden_states
+ hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index ef70baa05f19..3a5ff03e564a 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -73,7 +73,9 @@ def get_image_processor(self, repo_id):
image_processor = CLIPImageProcessor.from_pretrained(repo_id)
return image_processor
- def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False):
+ def get_dummy_inputs(
+ self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False, for_instant_style=False
+ ):
image = load_image(
"https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png"
)
@@ -126,6 +128,40 @@ def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_s
}
)
+ elif for_instant_style:
+ composition_mask = load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/1024_whole_mask.png"
+ )
+ female_mask = load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125641_mask.png"
+ )
+ male_mask = load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125344_mask.png"
+ )
+ background_mask = load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_6_20240321130722_mask.png"
+ )
+ ip_composition_image = load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125152.png"
+ )
+ ip_female_style = load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125625.png"
+ )
+ ip_male_style = load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125329.png"
+ )
+ ip_background = load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321130643.png"
+ )
+ input_kwargs.update(
+ {
+ "ip_adapter_image": [ip_composition_image, [ip_female_style, ip_male_style, ip_background]],
+ "cross_attention_kwargs": {
+ "ip_adapter_masks": [[composition_mask], [female_mask, male_mask, background_mask]]
+ },
+ }
+ )
+
return input_kwargs
@@ -575,6 +611,48 @@ def test_ip_adapter_multiple_masks(self):
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4
+ def test_instant_style_multiple_masks(self):
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, image_encoder=image_encoder, variant="fp16"
+ ).to("cuda")
+ pipeline.enable_model_cpu_offload()
+
+ pipeline.load_ip_adapter(
+ ["ostris/ip-composition-adapter", "h94/IP-Adapter"],
+ subfolder=["", "sdxl_models"],
+ weight_name=[
+ "ip_plus_composition_sdxl.safetensors",
+ "ip-adapter_sdxl_vit-h.safetensors",
+ ],
+ image_encoder_folder=None,
+ )
+ scale_1 = {
+ "down": [[0.0, 0.0, 1.0]],
+ "mid": [[0.0, 0.0, 1.0]],
+ "up": {"block_0": [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]], "block_1": [[0.0, 0.0, 1.0]]},
+ }
+ pipeline.set_ip_adapter_scale([1.0, scale_1])
+
+ inputs = self.get_dummy_inputs(for_instant_style=True)
+ processor = IPAdapterMaskProcessor()
+ masks1 = inputs["cross_attention_kwargs"]["ip_adapter_masks"][0]
+ masks2 = inputs["cross_attention_kwargs"]["ip_adapter_masks"][1]
+ masks1 = processor.preprocess(masks1, height=1024, width=1024)
+ masks2 = processor.preprocess(masks2, height=1024, width=1024)
+ masks2 = masks2.reshape(1, masks2.shape[0], masks2.shape[2], masks2.shape[3])
+ inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks1, masks2]
+ images = pipeline(**inputs).images
+ image_slice = images[0, :3, :3, -1].flatten()
+ expected_slice = np.array(
+ [0.23551631, 0.20476806, 0.14099443, 0.0, 0.07675594, 0.05672678, 0.0, 0.0, 0.02099729]
+ )
+
+ max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+ assert max_diff < 5e-4
+
def test_ip_adapter_multiple_masks_one_adapter(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionXLPipeline.from_pretrained(