fix: Hunyuan3D 2.1 batch size crashes in attention and forward pass#13699
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
✅ Files skipped from review due to trivial changes (1)
📝 WalkthroughWalkthroughThis pull request corrects batch-dimension handling in attention layers and adds guards for classifier-free guidance operations. 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Review rate limit: 6/8 reviews remaining, refill in 9 minutes and 59 seconds.Comment |
- CrossAttention.forward: hardcoded `1` in kv.view() replaced with actual batch size `b` - Attention.forward: hardcoded `1` in qkv_combined.view() replaced with actual batch size `B` - HunYuanDiTPlain.forward: context.chunk(2) and output.chunk(2) now guarded with shape[0] >= 2 check to avoid crash when running without negative conditioning Fixes Comfy-Org#10142
cc98e29 to
43b0dab
Compare
|
Can confirm that this fixes issues I was having with running the Hunyuan3D 2.1 template. Went from errors or invalid results to correct behavior, using MPS on a Mac Studio. |
|
Thanks @Kivylius do you have a way to reproduce the issue before fixing it? It would be helpful for me to test. |
|
@alexisrolland very mutch like @Alanaktion
More details in #10142 Im the Macbook Pro M2 |
That's what I tried before posting my previous comment, the template worked fine for me ;) |
Im not sure about you configuration, but for me, it was just fresh install, no plugins, no extesnions, not even changed any varables at all in the template. I'm not sure what different from your to mine, but i recomend to try it on virtual machene from scatch if possible. I'm on MacOS 26.4.1 and Python 3.13.5 personaly but not sure if that makes mutch of difference. In that thread it seems like several other people are having the same issue, but im not sure of there setups or versions, plugins ect, but all seem to be the exact same error. One interesting observation;, sometimes its get further then other attemps, to me that screams of invalid outputs that not handled properly, or more so that that specific erroneous output is more prominent is specific setups over others else this would have already been fixed. If there any other logs that are missing that not in that thread already, let me know and ill try my best to provide it here. |
|
I tested with batch size 1, 2, 3, and different resolutions 1024, 2048. I could not reproduce the issue either before or after the fix. Since multiple people have reported this fixes it for them, I am merging it. |
The previous gate (len(cond_or_uncond) == 2 and set == {0, 1}) was
intended to skip the cond/uncond swap when only one half was present
under MultiGPU CFG Split, but it was too restrictive: it also skipped
batch_size > 1 + CFG (cond_or_uncond like [0, 0, 1, 1] or [0,0,0,0,
1,1,1,1]), where chunk(2) still splits the batch cleanly into a cond
half and an uncond half and the swap is still required.
Switch to context.shape[0] >= 2, matching the parallel fix landed on
master in #13699. The swap is a permutation-invariant no-op when the
two halves don't form a CFG pair (since the output swap_cfg_halves
block immediately undoes the permutation), so the only thing the gate
actually needs to do is guard against chunk(2) on a batch of one.
Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
CrossAttention.kv.view and Attention.qkv_combined.view both hardcoded batch=1 in the reshape, crashing or silently mis-shaping whenever the actual batch dimension was greater than 1. These were fixed on master in #13699 as part of the same patch that gated the chunk(2) swap, but worksplit-multigpu only picked up the chunk(2) gate. Bring the two view() fixes over so we have parity with master. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp <amp@ampcode.com>
Brings in 18 commits from master so worksplit-multigpu does not regress fixes that landed on main since the last sync: - #13699 Hunyuan 3D 2.1 batch-size fixes (overlap with our own backport; conflict resolved in favor of the shape>=2 gate that binds swap_cfg_halves once and reuses it for the output swap-back) - #14031 ModelPatcherDynamic lora reshape / backup restore fix - #13802 Multi-threaded model load (memory_management / pinned_memory / model_management / aimdo plumbing) - #12679 lanczos single-channel tensor fix - #14010 Stable Audio 3 support - assorted partner-node, openapi, workflow-template, and tooling updates Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp <amp@ampcode.com>
1in kv.view() replaced with actual batch sizeb1in qkv_combined.view() replaced with actual batch sizeBFixes #10142