Reduce WAN VAE VRAM, Save use cases for OOM/Tiler #13014

Merged
comfyanonymous merged 7 commits into Comfy-Org:master from rattus128:prs/wan-vae on Mar 17, 2026

Conversation

@rattus128
Contributor

Reduce VAE VRAM for WAN by using the same recursion strategy as done for LTX. Get the chunk size down to a consistent 2 throughout all caching to avoid having to cat, slice and clone.

The primary VRAM savings come from reducing the size of the big convolutions from 6 frames (2 cache + 4 input) to 4 (2 + 2) and from scrapping some unneeded clone() logic that took straight copies of tensors.
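In sketch form, the decode loop walks the latent one frame first, then strictly two frames at a time, against a single shared feature cache. This is only an illustration: decode_in_pairs, decoder, and num_cache_layers are hypothetical stand-ins rather than the real vae.py API, and decoder is assumed to return a list of chunk tensors (as discussed later in this thread); the slice arithmetic matches the 2-frame stepping in the reviewed diff.

import torch

def decode_in_pairs(decoder, z, num_cache_layers):
    # One shared per-layer cache, reused across chunks instead of cloned.
    feat_cache = [None] * num_cache_layers
    chunks = []
    t = z.shape[2]  # latent layout (B, C, T, H, W)
    # First frame alone, then 2 frames per step: each conv sees at most
    # 2 cached + 2 new frames instead of 2 + 4.
    chunks += decoder(z[:, :, :1], feat_cache=feat_cache)
    for i in range(1, (t - 1) // 2 + 1):
        s = 1 + 2 * (i - 1)
        chunks += decoder(z[:, :, s:s + 2], feat_cache=feat_cache)
    return torch.cat(chunks, 2)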

Example Test Conditions:

Linux, RTX5090
WAN 2.2 VAE encode + decode 1024x1024x81f


Before:

Requested to load WanVAE
Model WanVAE prepared for dynamic VRAM loading. 242MB Staged. 0 patches attached. Force pre-loaded 52 weights: 28 KB.
Model WanVAE prepared for dynamic VRAM loading. 242MB Staged. 0 patches attached. Force pre-loaded 52 weights: 28 KB.
Prompt executed in 15.66 seconds

After:

Requested to load WanVAE
Model WanVAE prepared for dynamic VRAM loading. 242MB Staged. 0 patches attached. Force pre-loaded 52 weights: 28 KB.
Model WanVAE prepared for dynamic VRAM loading. 242MB Staged. 0 patches attached. Force pre-loaded 52 weights: 28 KB.
Prompt executed in 15.63 seconds

Regression Tests:
Linux, 5090, WAN 2.2 I2V Template ✅
Windows, 5060, WAN 2.2 Template ✅
Linux, 5090, qwen 2048x1328 ✅
Linux, 5090, WAN VAE encode with t=3 ✅ (correctly truncates)

If a downsample only gives you a single frame, save it to the feature
cache and return nothing to the top level. This improves cache
efficiency, but also prepares support for going two-by-two
rather than four-by-four on the frames.
The loopers are now responsible for ensuring that non-final frames are
processed at least two-by-two, eliminating the need for this cat case.
Avoid having to clone off slices of 4-frame chunks and reduce the size
of the big 6-frame convolutions down to 4. Save the VRAMs.
Reduce VRAM usage greatly by encoding frames 2 at a time rather than
4.
The loopers now control the chunking such that there is never more than
2 frames, so just cache these slices directly and avoid the clone
allocations completely.
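A hedged sketch of that last point (cached_causal_conv and the cache layout are hypothetical, not the real vae.py helpers): because a chunk is never more than 2 frames, the whole incoming chunk can be stored in the cache as-is, rather than cloning a 2-frame slice out of a larger buffer.

import torch

def cached_causal_conv(conv, chunk, feat_cache, idx):
    # chunk holds at most 2 frames, so it can be cached directly: no
    # slicing of a larger buffer, hence no clone() needed to detach the
    # slice and let that buffer be freed.
    prev = feat_cache[idx]
    feat_cache[idx] = chunk
    if prev is not None:
        chunk = torch.cat([prev, chunk], dim=2)  # 2 cached + 2 new = 4 frames
    return conv(chunk)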
@coderabbitai

coderabbitai bot commented Mar 17, 2026

📝 Walkthrough

This change modifies the VAE implementation in comfy/ldm/wan/vae.py to support intermediate tensor caching and staged processing. Forward method signatures in Resample, ResidualBlock, AttentionBlock, Encoder3d, and Decoder3d now accept optional feat_cache, feat_idx, and final parameters. The temporal splitting logic is adjusted from 4-frame to 2-frame chunks. Intermediate computation results are cached and reused across calls rather than cloned. The helper function count_conv3d is renamed to count_cache_layers with updated counting logic. Call sites throughout the module are updated to pass and utilize these new parameters across the caching and upsampling orchestration flows.
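A simplified stand-in for the shared signature pattern (CacheAwareBlock is hypothetical, not one of the actual vae.py classes; the feat_idx=[0] default mirrors the code under review and is itself flagged in the nitpicks below):

import torch

class CacheAwareBlock(torch.nn.Module):
    def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
        if feat_cache is not None:
            idx = feat_idx[0]
            prev = feat_cache[idx]
            if not final:
                # Chunks are at most 2 frames, so cache the input whole.
                feat_cache[idx] = x
            if prev is not None:
                x = torch.cat([prev, x], dim=2)
            feat_idx[0] += 1  # advance the shared cursor for the next layer
        return x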

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title directly relates to the main objective: reducing WAN VAE VRAM consumption and improving handling of OOM/Tiler scenarios.
  • Description check: ✅ Passed. The description clearly explains the optimization strategy (chunking with size 2, reducing convolutions from 6 to 4 frames) and includes test results and regression tests validating the changes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
comfy/ldm/wan/vae.py (2)

199-199: Unused final parameter.

The final parameter is added but not used in AttentionBlock.forward. This appears intentional for API consistency across block types. Consider adding a brief comment or using _final to indicate it's intentionally unused.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/ldm/wan/vae.py` at line 199, The AttentionBlock.forward signature has
an unused parameter final; to signal it's intentionally unused (for API
consistency with other block types) update the parameter name to _final or add
an inline comment in the AttentionBlock.forward definition stating the parameter
is intentionally unused, keeping the signature otherwise unchanged so callers
remain compatible; locate the AttentionBlock.forward method and either rename
final -> _final or add a one-line comment like "# final unused: kept for API
compatibility" immediately after the def to silence linters and clarify intent.

102-102: Mutable default argument feat_idx=[0] is a Python anti-pattern.

Using a mutable list as a default argument can cause subtle bugs if the function is ever called without explicitly passing feat_idx. The list is created once at definition time and shared across calls. While current call sites always pass explicit values, this is a latent risk.

Consider using None as the default and initializing inside the function:

♻️ Suggested fix
-    def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
+    def forward(self, x, feat_cache=None, feat_idx=None, final=False):
+        if feat_idx is None:
+            feat_idx = [0]

Note: This pattern appears in multiple forward methods (Resample, ResidualBlock, AttentionBlock, Encoder3d, Decoder3d). If you agree with the fix, apply it consistently across all of them.
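As a standalone illustration of the hazard (generic Python, unrelated to vae.py):

def remember(item, seen=[]):  # the default list is created once, at def time
    seen.append(item)
    return seen

print(remember("a"))  # ['a']
print(remember("b"))  # ['a', 'b'] - state leaked across calls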

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/ldm/wan/vae.py` at line 102, The forward method currently uses a
mutable default argument feat_idx=[0] which can be shared across calls; change
the signature of VAE.forward (and similarly for Resample.forward,
ResidualBlock.forward, AttentionBlock.forward, Encoder3d.forward,
Decoder3d.forward) to use feat_idx=None and inside the method set feat_idx = [0]
if feat_idx is None so a new list is created per call; update any internal usage
that assumes a list accordingly and keep the external API the same.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b3c7cbce-8786-4d52-b03c-50edfd4d0802

📥 Commits

Reviewing files that changed from the base of the PR and between 379fbd1 and 0f5621e.

📒 Files selected for processing (1)
  • comfy/ldm/wan/vae.py

Comment on lines +505 to +509

      x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
      feat_cache=feat_map,
      feat_idx=conv_idx)
-     out = torch.cat([out, out_], 2)
-     return out
+     out += out_
+     return torch.cat(out, 2)

@coderabbitai coderabbitai bot Mar 17, 2026


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if there's input validation for z tensor shape upstream
rg -n "def decode" comfy/ -A10 | head -50

# Check for any shape validation before decode is called
rg -n "\.shape\[2\].*0|empty.*tensor" comfy/ldm/wan/

Repository: Comfy-Org/ComfyUI

Length of output: 4276


🏁 Script executed:

# First, let's look at the specific decode method in comfy/ldm/wan/vae.py
cat -n comfy/ldm/wan/vae.py | sed -n '485,515p'

Repository: Comfy-Org/ComfyUI

Length of output: 1031


🏁 Script executed:

# Check the function signature and any input validation at the start of the decode method
cat -n comfy/ldm/wan/vae.py | sed -n '459,520p'

Repository: Comfy-Org/ComfyUI

Length of output: 2125


🏁 Script executed:

# Search for where this decode method is called to see if there's upstream validation
rg -n "\.decode\(" comfy/ldm/wan/ -B2 -A2

Repository: Comfy-Org/ComfyUI

Length of output: 43


🏁 Script executed:

# Check if there's any initialization or validation that ensures z.shape[2] > 0
rg -n "shape\[2\]|z\.shape" comfy/ldm/wan/vae.py -B2 -A2

Repository: Comfy-Org/ComfyUI

Length of output: 1725


🏁 Script executed:

# Check if VAE.decode is actually called/used in the codebase
rg -n "\.decode\s*\(" comfy/ -t py | grep -i vae | head -20

Repository: Comfy-Org/ComfyUI

Length of output: 1478


🏁 Script executed:

# Also check if there are tests for this
rg -n "decode" comfy/ldm/wan/vae.py -B5 -A5 | grep -i test

Repository: Comfy-Org/ComfyUI

Length of output: 43


🏁 Script executed:

# Compare the structure - in encode, out is built as tensors that get concatenated
# Let's verify the exact difference between encode and decode patterns
cat -n comfy/ldm/wan/vae.py | sed -n '469,485p'

Repository: Comfy-Org/ComfyUI

Length of output: 769


Fix torch.cat call in decode method - it's being called with a tensor instead of a list.

Line 509 calls torch.cat(out, 2) where out is a tensor. The torch.cat() function requires a sequence (list/tuple) of tensors as its first argument. This will fail at runtime with a TypeError for any input.

The encode method (line 484) correctly uses torch.cat([out, out_], 2) with a list. The decode method should either wrap the tensor: torch.cat([out], 2) or accumulate into a list like encode does.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/ldm/wan/vae.py` around lines 505 - 509, In the decode method, torch.cat
is being called with a tensor variable out instead of a sequence; modify the
code so torch.cat receives a list/tuple of tensors (e.g., collect outputs into a
list like encode does or call torch.cat([out], 2)), ensuring the variable names
out and out_ in decode are accumulated into a list before concatenation to match
the encode behavior.

@rattus128
Contributor Author


The forward pass of self.decoder should now consistently return a list. This means out and out_ are lists. Please reread and confirm or deny.

@coderabbitai

@rattus128 You are correct, and my original comment was wrong — I apologize for the noise.

Re-reading Decoder3d.forward carefully:

out_chunks = []

def run_up(...):
    ...
    out_chunks.append(x)  # appends tensors into the list
    ...

run_up(0, [x], feat_idx)
return out_chunks          # returns a list of tensors

self.decoder(...) always returns out_chunks, a list of tensors. So in WanVAE.decode:

  • out = self.decoder(...) → out is a list of tensors.
  • out_ = self.decoder(...) → out_ is also a list of tensors.
  • out += out_ → Python list += extends the list correctly.
  • torch.cat(out, 2) → receives a list of tensors, which is exactly what torch.cat expects. ✅

The code is correct as written. I'll retract my original concern.
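A tiny self-contained check of the pattern described above (shapes are arbitrary):

import torch

out = [torch.zeros(1, 4, 1, 8, 8)]   # chunks from the first decoder call
out_ = [torch.zeros(1, 4, 2, 8, 8)]  # chunks from the next call
out += out_                          # list extension, not tensor addition
print(torch.cat(out, 2).shape)       # torch.Size([1, 4, 3, 8, 8])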


✏️ Learnings added
Learnt from: rattus128
Repo: Comfy-Org/ComfyUI PR: 13014
File: comfy/ldm/wan/vae.py:505-509
Timestamp: 2026-03-17T10:34:25.555Z
Learning: In comfy/ldm/wan/vae.py, `Decoder3d.forward` returns `out_chunks` which is a Python list of tensors (populated via `out_chunks.append(x)` inside the nested `run_up` helper). In `WanVAE.decode`, `out` and `out_` are therefore lists, `out += out_` is list extension, and `torch.cat(out, 2)` correctly receives a list of tensors. Do not flag this pattern as a bug.




@comfyanonymous comfyanonymous merged commit 035414e into Comfy-Org:master Mar 17, 2026
14 checks passed