diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md
index dc64b2548529..e3c4178a1507 100644
--- a/docs/source/en/using-diffusers/ip_adapter.md
+++ b/docs/source/en/using-diffusers/ip_adapter.md
@@ -640,3 +640,87 @@ image
   
+
+### Style & layout control
+
+[InstantStyle](https://arxiv.org/abs/2404.02733) is a plug-and-play method built on top of IP-Adapter that disentangles style and layout from the image prompt to control image generation. It works by inserting IP-Adapters only into specific parts of the model.
+
+By default, IP-Adapters are inserted into all layers of the model. Use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method with a dictionary to assign scales to the IP-Adapter at different layers.
+
+```py
+from diffusers import AutoPipelineForText2Image
+from diffusers.utils import load_image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+
+scale = {
+    "down": {"block_2": [0.0, 1.0]},
+    "up": {"block_0": [0.0, 1.0, 0.0]},
+}
+pipeline.set_ip_adapter_scale(scale)
+```
+
+This activates the IP-Adapter at the second layer of the model's down-part block 2 and up-part block 0. The former is the layer where the IP-Adapter injects layout information, and the latter is where it injects style. Inserting the IP-Adapter into only these two layers lets you generate images that follow the style and layout of the image prompt while keeping the content more closely aligned with the text prompt.
+
+```py
+style_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")
+
+generator = torch.Generator(device="cpu").manual_seed(42)
+image = pipeline(
+    prompt="a cat, masterpiece, best quality, high quality",
+    ip_adapter_image=style_image,
+    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
+    guidance_scale=5,
+    num_inference_steps=30,
+    generator=generator,
+).images[0]
+image
+```
+
+<!-- figure: IP-Adapter image | generated image -->
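
As a quick sanity check, you can list the UNet's IP-Adapter attention processors to see exactly which layers a key like `"down": {"block_2": ...}` or `"up": {"block_0": ...}` expands to, and what scale each layer ended up with. This is a minimal inspection sketch, assuming the pipeline from the snippet above is already loaded:

```py
# Print every cross-attention layer that carries an IP-Adapter and the per-layer
# scale it received after `set_ip_adapter_scale`.
from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0

for name, processor in pipeline.unet.attn_processors.items():
    if isinstance(processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
        # e.g. "up_blocks.0.attentions.1.transformer_blocks.0.attn2.processor" [1.0]
        print(name, processor.scale)
```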
+
+In contrast, inserting the IP-Adapter into all layers often produces images that focus too heavily on the image prompt and lack diversity.
+
+Activate the IP-Adapter only in the style layer and then call the pipeline again.
+
+```py
+scale = {
+    "up": {"block_0": [0.0, 1.0, 0.0]},
+}
+pipeline.set_ip_adapter_scale(scale)
+
+generator = torch.Generator(device="cpu").manual_seed(42)
+image = pipeline(
+    prompt="a cat, masterpiece, best quality, high quality",
+    ip_adapter_image=style_image,
+    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
+    guidance_scale=5,
+    num_inference_steps=30,
+    generator=generator,
+).images[0]
+image
+```
+
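
For an "all layers" baseline like the one in the comparison below, a plain float applies the same scale to every IP-Adapter attention layer (this is the original, non-InstantStyle behavior described in the `set_ip_adapter_scale` docstring). A sketch reusing the pipeline and `style_image` from above:

```py
# Restore the IP-Adapter in every layer at full strength.
pipeline.set_ip_adapter_scale(1.0)

generator = torch.Generator(device="cpu").manual_seed(42)
all_layers_image = pipeline(
    prompt="a cat, masterpiece, best quality, high quality",
    ip_adapter_image=style_image,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    guidance_scale=5,
    num_inference_steps=30,
    generator=generator,
).images[0]
```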
+<!-- figure: IP-Adapter only in style layer | IP-Adapter in all layers -->
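
The updated `set_ip_adapter_scale` docstring in this diff also shows passing one config per loaded IP-Adapter, which lets one adapter contribute layout and another contribute style from two different reference images. The sketch below illustrates that pattern under stated assumptions: loading the same SDXL adapter twice and reusing the rabbit image for both references are only placeholders, so substitute your own layout and style images.

```py
# Load two IP-Adapters and give each its own per-block scale config:
# the first is restricted to the layout block, the second to the style block.
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder=["sdxl_models", "sdxl_models"],
    weight_name=["ip-adapter_sdxl.bin", "ip-adapter_sdxl.bin"],
)

# One config per adapter, in loading order: layout blocks for the first, style block for the second.
scales = [
    {"down": {"block_2": [0.0, 1.0]}},
    {"up": {"block_0": [0.0, 1.0, 0.0]}},
]
pipeline.set_ip_adapter_scale(scales)

# Stand-in references: replace with a real layout image and a real style image.
layout_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")
style_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")

generator = torch.Generator(device="cpu").manual_seed(42)
image = pipeline(
    prompt="a cat, masterpiece, best quality, high quality",
    ip_adapter_image=[layout_image, style_image],  # one image (or list of images) per adapter
    guidance_scale=5,
    num_inference_steps=30,
    generator=generator,
).images[0]
```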
+ +Note that you don't have to specify all layers in the dictionary. Those not included in the dictionary will be set to scale 0 which means disable IP-Adapter by default. diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index fdddc382212f..cb158a4bc194 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -28,6 +28,7 @@ is_transformers_available, logging, ) +from .unet_loader_utils import _maybe_expand_lora_scales if is_transformers_available(): @@ -243,25 +244,55 @@ def load_ip_adapter( def set_ip_adapter_scale(self, scale): """ - Sets the conditioning scale between text and image. + Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for + granular control over each IP-Adapter behavior. A config can be a float or a dictionary. Example: ```py - pipeline.set_ip_adapter_scale(0.5) + # To use original IP-Adapter + scale = 1.0 + pipeline.set_ip_adapter_scale(scale) + + # To use style block only + scale = { + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style+layout blocks + scale = { + "down": {"block_2": [0.0, 1.0]}, + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style and layout from 2 reference images + scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}] + pipeline.set_ip_adapter_scale(scales) ``` """ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - for attn_processor in unet.attn_processors.values(): + if not isinstance(scale, list): + scale = [scale] + scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0) + + for attn_name, attn_processor in unet.attn_processors.items(): if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): - if not isinstance(scale, list): - scale = [scale] * len(attn_processor.scale) - if len(attn_processor.scale) != len(scale): + if len(scale_configs) != len(attn_processor.scale): raise ValueError( - f"`scale` should be a list of same length as the number if ip-adapters " - f"Expected {len(attn_processor.scale)} but got {len(scale)}." + f"Cannot assign {len(scale_configs)} scale_configs to " + f"{len(attn_processor.scale)} IP-Adapter." 
) - attn_processor.scale = scale + elif len(scale_configs) == 1: + scale_configs = scale_configs * len(attn_processor.scale) + for i, scale_config in enumerate(scale_configs): + if isinstance(scale_config, dict): + for k, s in scale_config.items(): + if attn_name.startswith(k): + attn_processor.scale[i] = s + else: + attn_processor.scale[i] = scale_config def unload_ip_adapter(self): """ diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet_loader_utils.py index 3ee4a96fad0a..8f202ed4d44b 100644 --- a/src/diffusers/loaders/unet_loader_utils.py +++ b/src/diffusers/loaders/unet_loader_utils.py @@ -38,7 +38,9 @@ def _translate_into_actual_layer_name(name): return ".".join((updown, block, attn)) -def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]]): +def _maybe_expand_lora_scales( + unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0 +): blocks_with_transformer = { "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")], "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")], @@ -47,7 +49,11 @@ def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[ expanded_weight_scales = [ _maybe_expand_lora_scales_for_one_adapter( - weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict() + weight_for_adapter, + blocks_with_transformer, + transformer_per_block, + unet.state_dict(), + default_scale=default_scale, ) for weight_for_adapter in weight_scales ] @@ -60,6 +66,7 @@ def _maybe_expand_lora_scales_for_one_adapter( blocks_with_transformer: Dict[str, int], transformer_per_block: Dict[str, int], state_dict: None, + default_scale: float = 1.0, ): """ Expands the inputs into a more granular dictionary. See the example below for more details. @@ -108,21 +115,36 @@ def _maybe_expand_lora_scales_for_one_adapter( scales = copy.deepcopy(scales) if "mid" not in scales: - scales["mid"] = 1 + scales["mid"] = default_scale + elif isinstance(scales["mid"], list): + if len(scales["mid"]) == 1: + scales["mid"] = scales["mid"][0] + else: + raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.") for updown in ["up", "down"]: if updown not in scales: - scales[updown] = 1 + scales[updown] = default_scale # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}} if not isinstance(scales[updown], dict): - scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]} + scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]} - # eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}} + # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}} for i in blocks_with_transformer[updown]: block = f"block_{i}" + # set not assigned blocks to default scale + if block not in scales[updown]: + scales[updown][block] = default_scale if not isinstance(scales[updown][block], list): scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])] + elif len(scales[updown][block]) == 1: + # a list specifying scale to each masked IP input + scales[updown][block] = scales[updown][block] * transformer_per_block[updown] + elif len(scales[updown][block]) != transformer_per_block[updown]: + raise ValueError( + f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}." 
+ ) # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1} for i in blocks_with_transformer[updown]: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 237e8236caf4..429807989296 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2229,44 +2229,51 @@ def __call__( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): - if mask is not None: - if not isinstance(scale, list): - scale = [scale] + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - current_num_images = mask.shape[1] - for i in range(current_num_images): - ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) - ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) ip_key = attn.head_to_batch_dim(ip_key) ip_value = attn.head_to_batch_dim(ip_value) ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) - - mask_downsample = IPAdapterMaskProcessor.downsample( - mask[:, i, :, :], - batch_size, - _current_ip_hidden_states.shape[1], - _current_ip_hidden_states.shape[2], - ) - - mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - - hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) - else: - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) + current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) - - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) - - hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -2439,57 +2446,64 @@ def __call__( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, 
self.to_v_ip, ip_adapter_masks ): - if mask is not None: - if not isinstance(scale, list): - scale = [scale] + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + _current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) - current_num_images = mask.shape[1] - for i in range(current_num_images): - ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) - ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - _current_ip_hidden_states = F.scaled_dot_product_attention( + current_ip_hidden_states = F.scaled_dot_product_attention( query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) - _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) - _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) - - mask_downsample = IPAdapterMaskProcessor.downsample( - mask[:, i, :, :], - batch_size, - _current_ip_hidden_states.shape[1], - _current_ip_hidden_states.shape[2], - ) - - mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) - else: - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - 
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) - hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index ef70baa05f19..3a5ff03e564a 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -73,7 +73,9 @@ def get_image_processor(self, repo_id): image_processor = CLIPImageProcessor.from_pretrained(repo_id) return image_processor - def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False): + def get_dummy_inputs( + self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False, for_instant_style=False + ): image = load_image( "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png" ) @@ -126,6 +128,40 @@ def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_s } ) + elif for_instant_style: + composition_mask = load_image( + "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/1024_whole_mask.png" + ) + female_mask = load_image( + "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125641_mask.png" + ) + male_mask = load_image( + "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125344_mask.png" + ) + background_mask = load_image( + "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_6_20240321130722_mask.png" + ) + ip_composition_image = load_image( + "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125152.png" + ) + ip_female_style = load_image( + "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125625.png" + ) + ip_male_style = load_image( + "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125329.png" + ) + ip_background = load_image( + "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321130643.png" + ) + input_kwargs.update( + { + "ip_adapter_image": [ip_composition_image, [ip_female_style, ip_male_style, ip_background]], + "cross_attention_kwargs": { + "ip_adapter_masks": [[composition_mask], [female_mask, male_mask, background_mask]] + }, + } + ) + return input_kwargs @@ -575,6 +611,48 @@ def test_ip_adapter_multiple_masks(self): max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) assert max_diff < 5e-4 + def test_instant_style_multiple_masks(self): + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16 + ).to("cuda") + pipeline = StableDiffusionXLPipeline.from_pretrained( + "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, image_encoder=image_encoder, variant="fp16" + ).to("cuda") + pipeline.enable_model_cpu_offload() + + pipeline.load_ip_adapter( + ["ostris/ip-composition-adapter", "h94/IP-Adapter"], + subfolder=["", "sdxl_models"], + weight_name=[ + "ip_plus_composition_sdxl.safetensors", + "ip-adapter_sdxl_vit-h.safetensors", + ], + image_encoder_folder=None, + ) + 
scale_1 = { + "down": [[0.0, 0.0, 1.0]], + "mid": [[0.0, 0.0, 1.0]], + "up": {"block_0": [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]], "block_1": [[0.0, 0.0, 1.0]]}, + } + pipeline.set_ip_adapter_scale([1.0, scale_1]) + + inputs = self.get_dummy_inputs(for_instant_style=True) + processor = IPAdapterMaskProcessor() + masks1 = inputs["cross_attention_kwargs"]["ip_adapter_masks"][0] + masks2 = inputs["cross_attention_kwargs"]["ip_adapter_masks"][1] + masks1 = processor.preprocess(masks1, height=1024, width=1024) + masks2 = processor.preprocess(masks2, height=1024, width=1024) + masks2 = masks2.reshape(1, masks2.shape[0], masks2.shape[2], masks2.shape[3]) + inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks1, masks2] + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + expected_slice = np.array( + [0.23551631, 0.20476806, 0.14099443, 0.0, 0.07675594, 0.05672678, 0.0, 0.0, 0.02099729] + ) + + max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) + assert max_diff < 5e-4 + def test_ip_adapter_multiple_masks_one_adapter(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionXLPipeline.from_pretrained(