Store pre-cached norm weights on self to prevent GC during graph replay — root cause of all-zeros replay bug

This commit is contained in:
2026-06-06 07:29:33 +00:00
parent dcb2495a5b
commit 5a98cc6d90

View File

@@ -283,6 +283,11 @@ class CUDAGraphDecoder:
if kvn is not None:
kv_norm_dev[li] = kvn.to(dev, torch.float32) if kvn.device != torch.device(dev) or kvn.dtype != torch.float32 else kvn
self.attn_norm_dev = attn_norm_dev
self.ffn_norm_dev = ffn_norm_dev
self.q_norm_dev = q_norm_dev
self.kv_norm_dev = kv_norm_dev
# Verify all MoE/SE buffers are allocated (swizzled buffers must exist before capture)
for li in range(self.n_layers):
moe = moe_runners.get(li)