Store pre-cached norm weights on self to prevent GC during graph replay — root cause of all-zeros replay bug
This commit is contained in:
@@ -283,6 +283,11 @@ class CUDAGraphDecoder:
|
||||
if kvn is not None:
|
||||
kv_norm_dev[li] = kvn.to(dev, torch.float32) if kvn.device != torch.device(dev) or kvn.dtype != torch.float32 else kvn
|
||||
|
||||
self.attn_norm_dev = attn_norm_dev
|
||||
self.ffn_norm_dev = ffn_norm_dev
|
||||
self.q_norm_dev = q_norm_dev
|
||||
self.kv_norm_dev = kv_norm_dev
|
||||
|
||||
# Verify all MoE/SE buffers are allocated (swizzled buffers must exist before capture)
|
||||
for li in range(self.n_layers):
|
||||
moe = moe_runners.get(li)
|
||||
|
||||
Reference in New Issue
Block a user