Support InstantStyle #7668

Merged
14 commits, merged Apr 22, 2024
68 changes: 58 additions & 10 deletions src/diffusers/loaders/ip_adapter.py
@@ -27,6 +27,7 @@
is_transformers_available,
logging,
)
from .unet_loader_utils import _maybe_expand_lora_scales


if is_transformers_available():
@@ -228,27 +229,74 @@ def load_ip_adapter(
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)

def set_ip_adapter_scale(self, scale):
def set_ip_adapter_scale(self, scale_configs: Union[float, Dict, List[Union[float, Dict]]], default_scale=0.0):
"""
Sets the conditioning scale between text and image.
Set IP-Adapter scales per transformer block. `scale_configs` can be a single config or a list of configs for
granular control over each IP-Adapter's behavior. A config can be a float or a dictionary.

Example:

```py
pipeline.set_ip_adapter_scale(0.5)
# To use original IP-Adapter
scale_configs = 1.0
pipeline.set_ip_adapter_scale(scale_configs)

# To use style block only
scale_configs = {
"up": {
"block_0": [0.0, 1.0, 0.0]
},
}
pipeline.set_ip_adapter_scale(scale_configs)

# To use style+layout blocks
scale_configs = {
"down": {
"block_2": [0.0, 1.0]
},
"up": {
"block_0": [0.0, 1.0, 0.0]
},
}
pipeline.set_ip_adapter_scale(scale_configs)

# To use style and layout from 2 reference images
scale_configs = [
{
"down": {
"block_2": [0.0, 1.0]
}
},
{
"up": {
"block_0": [0.0, 1.0, 0.0]
}
}
]
pipeline.set_ip_adapter_scale(scale_configs)
```
"""
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values():
if not isinstance(scale_configs, list):
scale_configs = [scale_configs]
scale_configs = _maybe_expand_lora_scales(unet, scale_configs, default_scale=default_scale)

for attn_name, attn_processor in unet.attn_processors.items():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
if not isinstance(scale, list):
scale = [scale] * len(attn_processor.scale)
if len(attn_processor.scale) != len(scale):
if len(scale_configs) > 1 and len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"`scale` should be a list of same length as the number if ip-adapters "
f"Expected {len(attn_processor.scale)} but got {len(scale)}."
f"Cannot assign {len(scale_configs)} scale_configs to "
f"{len(attn_processor.scale)} IP-Adapter."
)
attn_processor.scale = scale
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
if isinstance(scale_config, dict):
for key, scale in scale_config.items():
if attn_name.startswith(key):
attn_processor.scale[i] = scale
else:
attn_processor.scale[i] = scale_config

def unload_ip_adapter(self):
"""
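For context on how the new per-block API is meant to be used, here is a minimal style-only example assembled from the docstring above. It is a sketch, not part of this diff: the SDXL checkpoint, adapter weight name, prompts, and sampler settings follow the usual IP-Adapter setup and are only illustrative, and the reference image URL is reused from the test script quoted later in this thread.

```py
import torch

from diffusers import StableDiffusionXLPipeline
from diffusers.utils import load_image

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter(
    "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin"
)

# Style-only transfer: activate the style block (up.block_0) and let every
# block that is not listed fall back to default_scale=0.0.
pipeline.set_ip_adapter_scale({"up": {"block_0": [0.0, 1.0, 0.0]}})

style_image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125625.png"
)
image = pipeline(
    prompt="a cat, masterpiece, best quality, high quality",
    negative_prompt="lowres, bad anatomy, worst quality, low quality",
    ip_adapter_image=style_image,
    guidance_scale=5.0,
    num_inference_steps=30,
).images[0]
```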
17 changes: 12 additions & 5 deletions src/diffusers/loaders/unet_loader_utils.py
@@ -38,7 +38,7 @@ def _translate_into_actual_layer_name(name):
return ".".join((updown, block, attn))


def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]]):
def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0):
blocks_with_transformer = {
"down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
"up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
@@ -47,7 +47,7 @@ def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[

expanded_weight_scales = [
_maybe_expand_lora_scales_for_one_adapter(
weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict()
weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict(), default_scale=default_scale
)
for weight_for_adapter in weight_scales
]
@@ -60,6 +60,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
state_dict: None,
default_scale: float = 1.0,
):
"""
Expands the inputs into a more granular dictionary. See the example below for more details.
@@ -108,21 +109,27 @@
scales = copy.deepcopy(scales)

if "mid" not in scales:
scales["mid"] = 1
scales["mid"] = default_scale

for updown in ["up", "down"]:
if updown not in scales:
scales[updown] = 1
scales[updown] = default_scale

# eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}
if not isinstance(scales[updown], dict):
scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}

# eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}}
# eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
# set not assigned blocks to default scale
if block not in scales[updown]:
scales[updown][block] = default_scale
if not isinstance(scales[updown][block], list):
scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
else:
assert len(scales[updown][block]) == transformer_per_block[updown], \
f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."

# eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
for i in blocks_with_transformer[updown]:
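To make the expansion step concrete, here is a small standalone sketch of what `_maybe_expand_lora_scales_for_one_adapter` produces. The block layout (attention in down blocks 1-2 with two transformer layers each, and in up blocks 0-1 with three each) is hardcoded here to mimic SDXL for illustration; the real helper reads it from the UNet and also validates list lengths.

```py
from typing import Dict, Union


def expand_scales(scales: Union[float, Dict], default_scale: float = 0.0) -> Dict:
    # Toy stand-in for an SDXL-like UNet: which blocks have attention, and how many layers each.
    blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
    transformer_per_block = {"down": 2, "up": 3}

    if isinstance(scales, dict):
        scales = dict(scales)  # shallow copy so the caller's config is not mutated
    else:
        scales = {"down": scales, "mid": scales, "up": scales}

    expanded = {"mid": scales.get("mid", default_scale)}
    for updown in ["down", "up"]:
        value = scales.get(updown, default_scale)
        # eg {"down": 1} -> {"down": {"block_1": 1, "block_2": 1}}
        if not isinstance(value, dict):
            value = {f"block_{i}": value for i in blocks_with_transformer[updown]}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            per_block = value.get(block, default_scale)  # blocks not assigned get the default scale
            # eg {"block_1": 1} -> {"block_1": [1, 1]}
            if not isinstance(per_block, list):
                per_block = [per_block] * transformer_per_block[updown]
            expanded[f"{updown}.{block}"] = per_block
    return expanded


# Style-only config from the docstring: only the middle layer of up.block_0 stays active.
print(expand_scales({"up": {"block_0": [0.0, 1.0, 0.0]}}, default_scale=0.0))
# {'mid': 0.0, 'down.block_1': [0.0, 0.0], 'down.block_2': [0.0, 0.0],
#  'up.block_0': [0.0, 1.0, 0.0], 'up.block_1': [0.0, 0.0, 0.0]}
```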
140 changes: 71 additions & 69 deletions src/diffusers/models/attention_processor.py
@@ -2229,44 +2229,45 @@ def __call__(
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
if mask is not None:
if not isinstance(scale, list):
scale = [scale]
if scale > 0:
Contributor:

scale can be a list when using image masking, no?

Contributor Author:

Thanks for your careful inspection! This issue is resolved by adding an additional conditional statement.

if mask is not None:
if not isinstance(scale, list):
scale = [scale]

current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
_current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
_current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

hidden_states = hidden_states + scale * current_ip_hidden_states

# linear proj
hidden_states = attn.to_out[0](hidden_states)
@@ -2439,57 +2440,58 @@ def __call__(
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
if mask is not None:
if not isinstance(scale, list):
scale = [scale]
if scale > 0:
Contributor:

as before: when using image masking, the user can provide a list of masks for each IP Adapter; each mask will have its own scale

code would break here, as you cannot compare a (possible) list with an integer
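One way to make that guard safe when `scale` can be a list is to reduce it to a single skip decision first. A minimal sketch of the idea; the guard actually adopted in the PR may be written differently.

```py
def should_skip(scale) -> bool:
    """`scale` is a float, or a list with one entry per masked image when masking is used.
    Skip the adapter only when every entry is zero."""
    if isinstance(scale, list):
        return all(s == 0 for s in scale)
    return scale == 0


assert should_skip(0.0) and should_skip([0.0, 0.0])
assert not should_skip([0.0, 0.6])  # a partially active adapter must still run
```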

Member:

I didn't test this, this part doesn't have a unit test? If you need help to test this, just tell me and I can help too.

To be honest I haven't tested any of the multiple images with multiple masks and scales per image but I plan to do it soon to make another guide with them.

Contributor:

there is just a slow test (test_ip_adapter_multiple_masks_one_adapter in tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py) for the combination multiple images + masks + scales for IP Adapter, unfortunately there are no fast tests

Collaborator:

is this the test case we are talking about?

import torch
from transformers import CLIPVisionModelWithProjection

from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image, logging
from diffusers.utils.logging import set_verbosity


set_verbosity(logging.ERROR)  # to not show cross_attention_kwargs...by AttnProcessor2_0 warnings


# load & process masks
composition_mask = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/1024_whole_mask.png"
)
female_mask = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125641_mask.png"
)
male_mask = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125344_mask.png"
)
background_mask = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_6_20240321130722_mask.png"
)
processor = IPAdapterMaskProcessor()
masks1 = processor.preprocess([composition_mask], height=1024, width=1024)
masks2 = processor.preprocess([female_mask, male_mask, background_mask], height=1024, width=1024)
masks2 = masks2.reshape(1, masks2.shape[0], masks2.shape[2], masks2.shape[3])  # output -> (1, 3, 1024, 1024)
masks = [masks1, masks2]

# load images
ip_composition_image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125152.png"
)
ip_female_style = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125625.png"
)
ip_male_style = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125329.png"
)
ip_background = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321130643.png"
)


image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
).to("cuda")

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, image_encoder=image_encoder, variant="fp16"
).to("cuda")

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler.config.use_karras_sigmas = True

pipeline.load_ip_adapter(
    ["ostris/ip-composition-adapter", "h94/IP-Adapter"],
    subfolder=["", "sdxl_models"],
    weight_name=[
        "ip_plus_composition_sdxl.safetensors",
        "ip-adapter_sdxl_vit-h.safetensors",
    ],
    image_encoder_folder=None,
)
pipeline.set_ip_adapter_scale([1.0, [0.75, 0.75, 0.3]])



prompt = "high quality, cinematic photo, cinemascope, 35mm, film grain, highly detailed"
negative_prompt = "anime, cartoon"


image = pipeline(
    prompt=prompt,
    negative_prompt="",
    ip_adapter_image=[ip_composition_image, [ip_female_style, ip_male_style, ip_background]],
    cross_attention_kwargs={"ip_adapter_masks": masks},
    guidance_scale=6.5,
    num_inference_steps=25,
).images[0]

image.save("yiyi_test_mask_multi_out.png")

Contributor Author:

Thanks for your test case. I will test with it and update the test code once the result is as expected.

Member:

Won't it be better to have a fast test for this?

Contributor:

> is this the test case we are talking about?

exactly

Contributor:

> Won't it be better to have a fast test for this?

I can create a new PR to add a fast test and update the documentation.

Member:

That should work. Thanks as always.

if mask is not None:
if not isinstance(scale, list):
scale = [scale]

current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
_current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)

_current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)

current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
_current_ip_hidden_states = F.scaled_dot_product_attention(
current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)

_current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)

mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)

current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

hidden_states = hidden_states + scale * current_ip_hidden_states

# linear proj
hidden_states = attn.to_out[0](hidden_states)