Support InstantStyle #7668
Changes from 2 commits
The two hunks below, from the IP-Adapter attention processors (`IPAdapterAttnProcessor` and `IPAdapterAttnProcessor2_0`), guard the adapter computation behind each layer's scale so that layers whose scale is zero skip the IP-Adapter work entirely:

@@ -2229,44 +2229,45 @@ def __call__(

```diff
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            if mask is not None:
-                if not isinstance(scale, list):
-                    scale = [scale]
+            if scale > 0:
```

Author reply (on `if scale > 0:`): Thanks for your careful inspection! This issue is resolved by adding an additional conditional statement.
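The underlying problem, spelled out in the review thread on the second hunk below, is that `scale` may itself be a list of per-mask scales, so `scale > 0` raises a `TypeError` when comparing a list with an integer. A minimal sketch of a guard that tolerates both shapes — not necessarily the exact conditional that landed in the PR:

```python
# Sketch only: skip this IP Adapter's branch when every scale it carries is zero.
# `scale` is normally a float, or a list of floats when per-mask scales are used.
if isinstance(scale, list):
    skip = all(s == 0 for s in scale)
else:
    skip = scale == 0
if skip:
    continue  # this adapter contributes nothing; move on to the next one
```

The first hunk continues below.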
```diff
+                if mask is not None:
+                    if not isinstance(scale, list):
+                        scale = [scale]

-                current_num_images = mask.shape[1]
-                for i in range(current_num_images):
-                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
-                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
-
-                    ip_key = attn.head_to_batch_dim(ip_key)
-                    ip_value = attn.head_to_batch_dim(ip_value)
-
-                    ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-                    _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-                    _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
-
-                    mask_downsample = IPAdapterMaskProcessor.downsample(
-                        mask[:, i, :, :],
-                        batch_size,
-                        _current_ip_hidden_states.shape[1],
-                        _current_ip_hidden_states.shape[2],
-                    )
-
-                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-
-                    hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
-            else:
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = attn.head_to_batch_dim(ip_key)
-                ip_value = attn.head_to_batch_dim(ip_value)
-
-                ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-                current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-                current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
-
-                hidden_states = hidden_states + scale * current_ip_hidden_states
+                    current_num_images = mask.shape[1]
+                    for i in range(current_num_images):
+                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                        ip_key = attn.head_to_batch_dim(ip_key)
+                        ip_value = attn.head_to_batch_dim(ip_value)
+
+                        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                        _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                        _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+                        mask_downsample = IPAdapterMaskProcessor.downsample(
+                            mask[:, i, :, :],
+                            batch_size,
+                            _current_ip_hidden_states.shape[1],
+                            _current_ip_hidden_states.shape[2],
+                        )
+
+                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                else:
+                    ip_key = to_k_ip(current_ip_hidden_states)
+                    ip_value = to_v_ip(current_ip_hidden_states)
+
+                    ip_key = attn.head_to_batch_dim(ip_key)
+                    ip_value = attn.head_to_batch_dim(ip_value)
+
+                    ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                    current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                    current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                    hidden_states = hidden_states + scale * current_ip_hidden_states

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
```
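For context, the per-layer `scale` consumed by this processor is exactly what InstantStyle manipulates: the adapter is zeroed everywhere except a few attention blocks, and the `if scale > 0:` fast path lets all other layers skip the IP-Adapter computation. A usage sketch in the spirit of the diffusers IP-Adapter guide — the dict form of `set_ip_adapter_scale` and the block names (SDXL UNet layout) should be treated as assumptions here, since they were introduced by this line of work:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter(
    "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin"
)

# Style-only transfer: keep the adapter active in one up-block attention
# layer and leave every other layer's scale at 0.0, so those layers take
# the skip path in the processors above.
pipeline.set_ip_adapter_scale({"up": {"block_0": [0.0, 1.0, 0.0]}})
```

The same change is applied to the torch 2.0 `scaled_dot_product_attention` processor in the second hunk: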
@@ -2439,57 +2440,58 @@ def __call__(

```diff
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            if mask is not None:
-                if not isinstance(scale, list):
-                    scale = [scale]
+            if scale > 0:
```
Review comment (on `if scale > 0:`): As before: when using image masking, the user can provide a list of masks for each IP Adapter, and each mask will have its own scale. The code would break here, as you cannot compare a (possible) list with an integer.

Reply: I didn't test this; doesn't this part have a unit test? If you need help to test this, just tell me and I can help too. To be honest I haven't tested any of the…

Reply: There is just a slow test (`test_ip_adapter_multiple_masks_one_adapter` in `tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py`) for the combination.

Reply: Is this the test case we are talking about?

```python
import torch
from transformers import CLIPVisionModelWithProjection

from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image, logging
from diffusers.utils.logging import set_verbosity

set_verbosity(logging.ERROR)  # to not show cross_attention_kwargs...by AttnProcessor2_0 warnings

# load & process masks
composition_mask = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/1024_whole_mask.png"
)
female_mask = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125641_mask.png"
)
male_mask = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125344_mask.png"
)
background_mask = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_6_20240321130722_mask.png"
)

processor = IPAdapterMaskProcessor()
masks1 = processor.preprocess([composition_mask], height=1024, width=1024)
masks2 = processor.preprocess([female_mask, male_mask, background_mask], height=1024, width=1024)
masks2 = masks2.reshape(1, masks2.shape[0], masks2.shape[2], masks2.shape[3])  # output -> (1, 3, 1024, 1024)
masks = [masks1, masks2]

# load images
ip_composition_image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125152.png"
)
ip_female_style = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125625.png"
)
ip_male_style = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125329.png"
)
ip_background = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321130643.png"
)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
).to("cuda")

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, image_encoder=image_encoder, variant="fp16"
).to("cuda")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler.config.use_karras_sigmas = True

pipeline.load_ip_adapter(
    ["ostris/ip-composition-adapter", "h94/IP-Adapter"],
    subfolder=["", "sdxl_models"],
    weight_name=[
        "ip_plus_composition_sdxl.safetensors",
        "ip-adapter_sdxl_vit-h.safetensors",
    ],
    image_encoder_folder=None,
)
pipeline.set_ip_adapter_scale([1.0, [0.75, 0.75, 0.3]])

prompt = "high quality, cinematic photo, cinemascope, 35mm, film grain, highly detailed"
negative_prompt = "anime, cartoon"

image = pipeline(
    prompt=prompt,
    negative_prompt="",
    ip_adapter_image=[ip_composition_image, [ip_female_style, ip_male_style, ip_background]],
    cross_attention_kwargs={"ip_adapter_masks": masks},
    guidance_scale=6.5,
    num_inference_steps=25,
).images[0]
image.save("yiyi_test_mask_multi_out.png")
```

Author reply: Thanks for your test case. I will test with it and update the test code once the result is as expected.

Reply: Won't it be better to have a fast test for this?

Reply: Exactly.

Author reply: I can create a new PR to add a fast test and update the documentation.

Reply: That should work. Thanks as always.
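A fast test along the lines discussed here does not need model weights: for instance, the mask-downsampling half of the masked branch can be checked on CPU in milliseconds. A hypothetical sketch — the test name and the exact shape expectations are assumptions, not the test that was eventually added:

```python
import torch
from diffusers.image_processor import IPAdapterMaskProcessor

def test_ip_adapter_mask_downsample_broadcasts():
    # One 64x64 mask, gating an attention output with 16 query tokens
    # (a 4x4 latent grid), embedding dim 8, batch of 2.
    mask = torch.ones(1, 64, 64)
    mask_downsample = IPAdapterMaskProcessor.downsample(mask, 2, 16, 8)

    attn_out = torch.randn(2, 16, 8)
    gated = attn_out * mask_downsample  # must broadcast, as in the processor
    assert gated.shape == attn_out.shape
```

The second hunk then resumes: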
```diff
+                if mask is not None:
+                    if not isinstance(scale, list):
+                        scale = [scale]

-                current_num_images = mask.shape[1]
-                for i in range(current_num_images):
-                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
-                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
-
-                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                    # TODO: add support for attn.scale when we move to Torch 2.1
-                    _current_ip_hidden_states = F.scaled_dot_product_attention(
-                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                    )
-
-                    _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
-                        batch_size, -1, attn.heads * head_dim
-                    )
-                    _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
-
-                    mask_downsample = IPAdapterMaskProcessor.downsample(
-                        mask[:, i, :, :],
-                        batch_size,
-                        _current_ip_hidden_states.shape[1],
-                        _current_ip_hidden_states.shape[2],
-                    )
-
-                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-                    hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
-            else:
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
-
-                hidden_states = hidden_states + scale * current_ip_hidden_states
+                    current_num_images = mask.shape[1]
+                    for i in range(current_num_images):
+                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        _current_ip_hidden_states = F.scaled_dot_product_attention(
+                            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )
+
+                        _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
+                        _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+
+                        mask_downsample = IPAdapterMaskProcessor.downsample(
+                            mask[:, i, :, :],
+                            batch_size,
+                            _current_ip_hidden_states.shape[1],
+                            _current_ip_hidden_states.shape[2],
+                        )
+
+                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                else:
+                    ip_key = to_k_ip(current_ip_hidden_states)
+                    ip_value = to_v_ip(current_ip_hidden_states)
+
+                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    current_ip_hidden_states = F.scaled_dot_product_attention(
+                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                    )
+
+                    current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                        batch_size, -1, attn.heads * head_dim
+                    )
+                    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+                    hidden_states = hidden_states + scale * current_ip_hidden_states

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
```
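For readers following the masked branch above: `IPAdapterMaskProcessor.downsample(mask, batch_size, num_queries, value_embed_dim)` has to turn a 2-D mask into something that multiplies a `(batch, num_queries, value_embed_dim)` attention output. Roughly, and glossing over the padding the library does when the token count is not a clean grid, it behaves like this conceptual sketch (not the library code):

```python
import math
import torch
import torch.nn.functional as F

def downsample_sketch(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int) -> torch.Tensor:
    """Conceptual sketch of IPAdapterMaskProcessor.downsample."""
    # Pick a target grid (o_h, o_w) with the mask's aspect ratio so that
    # o_h * o_w matches the number of latent tokens (queries).
    _, h, w = mask.shape
    ratio = w / h
    o_h = int(math.sqrt(num_queries / ratio))
    o_w = num_queries // o_h  # assumes num_queries factors cleanly; the real code pads otherwise

    # Resize the mask to the latent resolution, then flatten the grid into the token axis.
    mask = F.interpolate(mask.unsqueeze(1).float(), size=(o_h, o_w), mode="bicubic").squeeze(1)
    mask = mask.reshape(mask.shape[0], o_h * o_w, 1)  # (b, num_queries, 1)

    # Repeat for classifier-free guidance etc., then broadcast over the embedding dim.
    if mask.shape[0] < batch_size:
        mask = mask.repeat(batch_size // mask.shape[0], 1, 1)
    return mask.expand(-1, -1, value_embed_dim)
```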