diff --git a/single_shot_inference.py b/single_shot_inference.py index 2cc733e8..efe9fe06 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -283,6 +283,11 @@ class CUDAGraphDecoder: if kvn is not None: kv_norm_dev[li] = kvn.to(dev, torch.float32) if kvn.device != torch.device(dev) or kvn.dtype != torch.float32 else kvn + self.attn_norm_dev = attn_norm_dev + self.ffn_norm_dev = ffn_norm_dev + self.q_norm_dev = q_norm_dev + self.kv_norm_dev = kv_norm_dev + # Verify all MoE/SE buffers are allocated (swizzled buffers must exist before capture) for li in range(self.n_layers): moe = moe_runners.get(li)