Reduce WAN VAE VRAM, Save use cases for OOM/Tiler #13014
comfyanonymous merged 7 commits into Comfy-Org:master
Conversation
If a downsample only gives you a single frame, save it to the feature cache and return nothing to the top level. This increases the efficiency of cacheability, but also prepares support for going two by two rather than four by four on the frames.
The loopers are now responsible for ensuring that non-final frames are processed at least two-by-two, eliminating the need for this cat case.
Avoid having to clone off slices of 4 frame chunks and reduce the size of the big 6 frame convolutions down to 4. Save the VRAMs.
Reduce VRAM usage greatly by encoding frames 2 at a time rather than 4.
The loopers now control the chunking such that there are never more than 2 frames, so just cache these slices directly and avoid the clone allocations completely.
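The two-frame chunking the commits describe can be sketched in isolation. This is a pure-Python illustration with a hypothetical helper name, not the actual looper code from the PR:

```python
def two_frame_chunks(num_frames):
    """Yield (start, end) slice bounds so no chunk exceeds 2 frames.

    Mirrors the idea that the loopers hand the conv stack at most
    2 frames at a time, so slices can be cached directly without
    clone() copies. Hypothetical sketch, not the actual PR code.
    """
    start = 0
    while start < num_frames:
        end = min(start + 2, num_frames)
        yield (start, end)
        start = end

# 5 frames -> two 2-frame chunks plus a final single frame.
chunks = list(two_frame_chunks(5))
# chunks == [(0, 2), (2, 4), (4, 5)]
```

A trailing single frame falls out naturally, matching the "downsample gives you a single frame" case the first commit caches rather than returning.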
📝 Walkthrough

This change modifies the VAE implementation in …

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (2)
comfy/ldm/wan/vae.py (2)
199-199: Unused `final` parameter.

The `final` parameter is added but not used in `AttentionBlock.forward`. This appears intentional for API consistency across block types. Consider adding a brief comment or using `_final` to indicate it's intentionally unused.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@comfy/ldm/wan/vae.py` at line 199, The AttentionBlock.forward signature has an unused parameter final; to signal it's intentionally unused (for API consistency with other block types) update the parameter name to _final or add an inline comment in the AttentionBlock.forward definition stating the parameter is intentionally unused, keeping the signature otherwise unchanged so callers remain compatible; locate the AttentionBlock.forward method and either rename final -> _final or add a one-line comment like "# final unused: kept for API compatibility" immediately after the def to silence linters and clarify intent.
102-102: Mutable default argument `feat_idx=[0]` is a Python anti-pattern.

Using a mutable list as a default argument can cause subtle bugs if the function is ever called without explicitly passing `feat_idx`. The list is created once at definition time and shared across calls. While current call sites always pass explicit values, this is a latent risk. Consider using `None` as the default and initializing inside the function:

♻️ Suggested fix

```diff
-    def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
+    def forward(self, x, feat_cache=None, feat_idx=None, final=False):
+        if feat_idx is None:
+            feat_idx = [0]
```

Note: This pattern appears in multiple forward methods (Resample, ResidualBlock, AttentionBlock, Encoder3d, Decoder3d). If you agree with the fix, apply it consistently across all of them.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@comfy/ldm/wan/vae.py` at line 102, The forward method currently uses a mutable default argument feat_idx=[0] which can be shared across calls; change the signature of VAE.forward (and similarly for Resample.forward, ResidualBlock.forward, AttentionBlock.forward, Encoder3d.forward, Decoder3d.forward) to use feat_idx=None and inside the method set feat_idx = [0] if feat_idx is None so a new list is created per call; update any internal usage that assumes a list accordingly and keep the external API the same.
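The shared-default pitfall the review describes can be demonstrated in a few lines (a generic illustration, not the actual `forward` methods):

```python
def append_bad(item, acc=[]):  # default list created once, at definition time
    acc.append(item)
    return acc

def append_good(item, acc=None):  # pattern the review recommends
    if acc is None:
        acc = []  # fresh list per call
    acc.append(item)
    return acc

# The "bad" version leaks state between unrelated calls:
assert append_bad(1) == [1]
assert append_bad(2) == [1, 2]   # surprise: previous call's state persists

# The "good" version starts clean every call:
assert append_good(1) == [1]
assert append_good(2) == [2]
```

In the VAE's case the shared list would silently carry a stale feature-cache index across calls, which is exactly the kind of latent bug the nitpick flags.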
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b3c7cbce-8786-4d52-b03c-50edfd4d0802
📒 Files selected for processing (1)
comfy/ldm/wan/vae.py
```diff
             x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
             feat_cache=feat_map,
             feat_idx=conv_idx)
-        out = torch.cat([out, out_], 2)
-        return out
+        out += out_
+        return torch.cat(out, 2)
```
🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Check if there's input validation for z tensor shape upstream
rg -n "def decode" comfy/ -A10 | head -50
# Check for any shape validation before decode is called
rg -n "\.shape\[2\].*0|empty.*tensor" comfy/ldm/wan/
```

🏁 Script executed:

```bash
# First, let's look at the specific decode method in comfy/ldm/wan/vae.py
cat -n comfy/ldm/wan/vae.py | sed -n '485,515p'
```

🏁 Script executed:

```bash
# Check the function signature and any input validation at the start of the decode method
cat -n comfy/ldm/wan/vae.py | sed -n '459,520p'
```

🏁 Script executed:

```bash
# Search for where this decode method is called to see if there's upstream validation
rg -n "\.decode\(" comfy/ldm/wan/ -B2 -A2
```

🏁 Script executed:

```bash
# Check if there's any initialization or validation that ensures z.shape[2] > 0
rg -n "shape\[2\]|z\.shape" comfy/ldm/wan/vae.py -B2 -A2
```

🏁 Script executed:

```bash
# Check if VAE.decode is actually called/used in the codebase
rg -n "\.decode\s*\(" comfy/ -t py | grep -i vae | head -20
```

🏁 Script executed:

```bash
# Also check if there are tests for this
rg -n "decode" comfy/ldm/wan/vae.py -B5 -A5 | grep -i test
```

🏁 Script executed:

```bash
# Compare the structure - in encode, out is built as tensors that get concatenated
# Let's verify the exact difference between encode and decode patterns
cat -n comfy/ldm/wan/vae.py | sed -n '469,485p'
```
Fix torch.cat call in decode method - it's being called with a tensor instead of a list.
Line 509 calls torch.cat(out, 2) where out is a tensor. The torch.cat() function requires a sequence (list/tuple) of tensors as its first argument. This will fail at runtime with a TypeError for any input.
The encode method (line 484) correctly uses torch.cat([out, out_], 2) with a list. The decode method should either wrap the tensor: torch.cat([out], 2) or accumulate into a list like encode does.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@comfy/ldm/wan/vae.py` around lines 505 - 509, In the decode method, torch.cat
is being called with a tensor variable out instead of a sequence; modify the
code so torch.cat receives a list/tuple of tensors (e.g., collect outputs into a
list like encode does or call torch.cat([out], 2)), ensuring the variable names
out and out_ in decode are accumulated into a list before concatenation to match
the encode behavior.
The forward pass of self.decoder should now consistently return a list. This means out and out_ are lists. Please reread and confirm or deny.
@rattus128 You are correct, and my original comment was wrong — I apologize for the noise.
Re-reading Decoder3d.forward carefully:
```python
out_chunks = []
def run_up(...):
    ...
    out_chunks.append(x)  # appends tensors into the list
...
run_up(0, [x], feat_idx)
return out_chunks  # returns a list of tensors
```

`self.decoder(...)` always returns `out_chunks`, a list of tensors. So in `WanVAE.decode`:

- `out = self.decoder(...)` → `out` is a list of tensors.
- `out_ = self.decoder(...)` → `out_` is also a list of tensors.
- `out += out_` → Python list `+=` extends the list correctly.
- `torch.cat(out, 2)` → receives a list of tensors, which is exactly what `torch.cat` expects. ✅

The code is correct as written. I'll retract my original concern.
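The confirmed pattern — accumulating per-chunk lists and concatenating once at the end — can be reproduced with small tensors. This assumes PyTorch is available; the shapes are illustrative `(B, C, T, H, W)` placeholders, not the actual model's:

```python
import torch

# Decoder-style: each call returns a *list* of (B, C, T, H, W) chunks.
out = [torch.zeros(1, 3, 2, 4, 4)]   # first decoder call's chunks
out_ = [torch.zeros(1, 3, 2, 4, 4)]  # subsequent decoder call's chunks

out += out_                  # list extension, not tensor addition
result = torch.cat(out, 2)   # single concatenation over the temporal dim

assert isinstance(out, list) and len(out) == 2
assert result.shape == (1, 3, 4, 4, 4)  # 2 + 2 frames concatenated
```

Deferring the `torch.cat` to one call at the end is also what avoids repeatedly allocating growing intermediate tensors inside the loop.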
✏️ Learnings added
Learnt from: rattus128
Repo: Comfy-Org/ComfyUI PR: 13014
File: comfy/ldm/wan/vae.py:505-509
Timestamp: 2026-03-17T10:34:25.555Z
Learning: In comfy/ldm/wan/vae.py, `Decoder3d.forward` returns `out_chunks` which is a Python list of tensors (populated via `out_chunks.append(x)` inside the nested `run_up` helper). In `WanVAE.decode`, `out` and `out_` are therefore lists, `out += out_` is list extension, and `torch.cat(out, 2)` correctly receives a list of tensors. Do not flag this pattern as a bug.
🧠 Learnings used
Learnt from: pythongosssss
Repo: Comfy-Org/ComfyUI PR: 12555
File: comfy_extras/nodes_glsl.py:719-724
Timestamp: 2026-02-21T14:01:41.482Z
Learning: In PyOpenGL, bare Python scalars can be accepted for 1-element array parameters by NumberHandler. This means you can pass an int/float directly to OpenGL texture deletion (e.g., glDeleteTextures(tex)) without wrapping in a list. Verify function-specific expectations and ensure types match what the OpenGL call expects; use explicit lists only when the API requires an array.
Reduce VAE VRAM for WAN by using the same recursion strategy as was done for LTX. Get the chunk size down to a consistent 2 throughout all caching to avoid having to cat, slice and clone.
The primary VRAM saving comes from reducing the size of the big convolutions from 6 frames (2 cache + 4 input) to 4 (2 + 2) and scrapping some unneeded clone() logic that took straight copies of tensors.
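The 2 + 2 window arithmetic can be sketched without the model: a rolling 2-frame cache concatenated with a 2-frame input chunk gives each convolution a 4-frame window, versus the previous 2 + 4 = 6. Pure-Python sketch with hypothetical names, not the PR's code:

```python
def windows(frames, cache_len=2, chunk_len=2):
    """Yield the frame window each conv call would see.

    Each window is up to `cache_len` previously-seen frames plus the
    next `chunk_len` input frames. Hypothetical sketch of the chunking
    arithmetic only, not the actual WAN VAE implementation.
    """
    cache = []
    for i in range(0, len(frames), chunk_len):
        chunk = frames[i:i + chunk_len]
        window = cache + chunk          # what the conv operates on
        yield window
        cache = window[-cache_len:]     # roll the cache forward

frames = list(range(8))
new_sizes = [len(w) for w in windows(frames)]               # chunks of 2
old_sizes = [len(w) for w in windows(frames, chunk_len=4)]  # chunks of 4
assert max(new_sizes) == 4  # new scheme: at most 2 cache + 2 input
assert max(old_sizes) == 6  # old scheme: up to 2 cache + 4 input
```

Since peak activation memory scales with the largest window the convolution sees, shrinking it from 6 to 4 frames is where the bulk of the claimed VRAM saving comes from.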
Example Test Conditions:
Linux, RTX5090
WAN 2.2 VAE encode + Decode 1024x1024x81f
Before:
After:
Regression Tests:
Linux, 5090, WAN 2.2 I2V Template ✅
Windows, 5060, WAN 2.2 Template ✅
Linux, 5090, qwen 2048x1328 ✅
Linux, 5090, WAN VAE encode with t=3 ✅ (correctly truncates)