Add PT compileable support for flash_attn_with_kvcache #1592

Open · jataylo wants to merge 1 commit into main

Conversation

jataylo commented Apr 14, 2025

Continues #1139, adding a custom op for flash_attn_with_kvcache.
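
For context, the general shape of such a wrapper with the torch.library.custom_op API (available from PyTorch 2.4) looks roughly like this; the op name, reduced signature, and fake impl below are illustrative, not the PR's actual code:

```python
import torch
from flash_attn import flash_attn_with_kvcache

# Illustrative reduced signature; the real op takes many more arguments.
@torch.library.custom_op("flash_attn::_kvcache_sketch", mutates_args=("k_cache", "v_cache"))
def _flash_attn_with_kvcache_op(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    cache_seqlens: torch.Tensor,
) -> torch.Tensor:
    # Opaque to dynamo: torch.compile sees a single op instead of
    # graph-breaking on the underlying CUDA extension call.
    return flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens)

@_flash_attn_with_kvcache_op.register_fake
def _(q, k_cache, v_cache, cache_seqlens):
    # Fake/meta impl so inductor can propagate shapes without running the
    # kernel; the output has the same shape and dtype as q.
    return torch.empty_like(q)
```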

On a transformers model this improves performance by more than 2x by avoiding graph breaks. There is one gotcha: with this implementation, PyTorch 2.6 throws an error in user code when reshaping the flash-attention output:

```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: <weakref at 0x7f10e00494e0; to 'torch.storage.UntypedStorage' at 0x7f10e0049400>
```

This is not an issue on PyTorch 2.7, so I had to introduce conditional logic to work around it: clones of the output tensors are returned only when compile is in use and the PyTorch version is earlier than 2.7.
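
A minimal sketch of that gate, assuming a hypothetical `_maybe_clone_output` helper applied where the wrapper returns its results (the helper name and the `packaging` version check are illustrative):

```python
import torch
from packaging import version

# The clone is only needed on PyTorch < 2.7, where compiling through the
# custom op hits the UntypedStorage weakref error shown above.
_NEEDS_CLONE_WORKAROUND = version.parse(torch.__version__).release < (2, 7)

def _maybe_clone_output(out: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: clone only when running under torch.compile on an
    # affected version; eager mode and PyTorch >= 2.7 skip the extra copy.
    if _NEEDS_CLONE_WORKAROUND and torch.compiler.is_compiling():
        return out.clone()
    return out
```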

jataylo (Author) commented Apr 16, 2025

@tridao alternatively, if preferred, instead of conditionalising the clone for PyTorch < 2.7 we could simply disable compileable support for this op below 2.7, since the additional clone could cause regressions and increase memory usage.
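
That alternative would reduce to a registration-time check along these lines (`_register_compileable_kvcache_op` is a hypothetical stand-in for the PR's registration code):

```python
import torch
from packaging import version

# Only expose the compileable custom op on PyTorch >= 2.7; older versions keep
# the existing behavior (a graph break) rather than paying for the extra clone.
if version.parse(torch.__version__).release >= (2, 7):
    _register_compileable_kvcache_op()  # hypothetical registration hook
```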

tridao (Member) commented Apr 22, 2025

We will drop support for PyTorch < 2.4, so you can simplify the code.
I'll need to think more about the clone. Does it slow things down when running in eager?

jataylo (Author) commented May 9, 2025

@tridao sorry for the slow response here, I've been away.

I imagine the additional clone could add eager-mode overhead as well as potentially introduce OOM issues, which have been observed in some cases.

Perhaps we just need to restrict the compileable support for this method to PyTorch >= 2.7. Unless @drisspg has any thoughts here on why we see this weakref error before 2.7.

zou3519 commented May 13, 2025

@jataylo do you have the full stack trace?

jataylo (Author) commented May 14, 2025

@zou3519 gist of the stack trace here:
https://gist.github.com/jataylo/ef8b729d53a4415bc00c00c03e934950

Let me see if I can get a reproducer, as this is currently from full model code. It looks like we hit this when applying reshape to the output of the flash_attn kvcache call.
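
A minimal reproducer would presumably look something like the following (a sketch only; the shapes and dtypes are assumptions, and whether the error actually fires depends on this PR's custom op being registered and on PyTorch 2.6):

```python
import torch
from flash_attn import flash_attn_with_kvcache

def attn_then_reshape(q, k_cache, v_cache, cache_seqlens):
    out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens)
    # The reshape on the op's output is where the weakref error surfaced.
    return out.reshape(out.shape[0], -1)

compiled = torch.compile(attn_then_reshape)

# (batch, seqlen_q, nheads, headdim) query against a (batch, seqlen_cache, ...) cache.
q = torch.randn(2, 1, 8, 64, dtype=torch.float16, device="cuda")
k_cache = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
v_cache = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
cache_seqlens = torch.full((2,), 64, dtype=torch.int32, device="cuda")

compiled(q, k_cache, v_cache, cache_seqlens)
```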
