I tried generating samples in Colab, and everything works except that I had to change one line in /cm/unet.py so that factory_kwargs is an empty dict.
I'm not sure whether this is a bug or something I did wrong. This is how I ran it: https://github.com/JonathanFly/consistency_models_colab_notebook/blob/main/Consistency_Models_Make_Samples.ipynb
class QKVFlashAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        from einops import rearrange
        from flash_attn.flash_attention import FlashAttention

        assert batch_first
        # Original line, commented out:
        # factory_kwargs = {"device": device, "dtype": dtype}
        # Workaround: pass no device/dtype arguments at all
        factory_kwargs = {}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal
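Instead of hard-coding an empty dict, one option is to forward only the keyword arguments the target constructor actually accepts. This is a minimal sketch that assumes the failure comes from the installed FlashAttention class not taking `device`/`dtype`. That cause is not confirmed in this issue, so check the constructor signature of your installed flash_attn version. `supported_kwargs` and `AttentionWithoutFactoryArgs` are hypothetical names used only for illustration.

```python
import inspect


def supported_kwargs(fn, kwargs):
    """Return only the entries of `kwargs` that `fn` accepts by name.

    Hypothetical helper: if `fn` takes **kwargs, everything is passed through.
    """
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}


# Hypothetical stand-in for an attention constructor that does NOT accept
# device/dtype (an assumption for illustration, not flash_attn's real class).
class AttentionWithoutFactoryArgs:
    def __init__(self, softmax_scale=None, attention_dropout=0.0):
        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout


factory_kwargs = {"device": None, "dtype": None}
# device/dtype are silently dropped because the constructor does not name them.
inner = AttentionWithoutFactoryArgs(
    attention_dropout=0.0,
    **supported_kwargs(AttentionWithoutFactoryArgs, factory_kwargs),
)
```

With this approach, `factory_kwargs` can keep its original definition, and the call that consumes it (not shown in the snippet) would wrap it in `supported_kwargs(FlashAttention, factory_kwargs)`. The same code then runs whether or not the installed version accepts `device` and `dtype`.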