From 56b816a54f44526e79e2dbfa449a53bb59b6b98f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 22:56:20 +0000 Subject: [PATCH] CUDA graph: Use per-GPU position/token buffers for graph capture Cross-GPU .to() calls inside graph capture cause 'dependency on uncaptured work in another stream'. Fix: pass dec_pos_per_gpu/dec_tid32_per_gpu to capture() so each layer's graph uses buffers on its own GPU. --- single_shot_inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index a2718ecd..7d7dee61 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -188,7 +188,7 @@ class CUDAGraphDecoder: def capture(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms, kv_caches, compressors, indexers, moe_runners, se_runners, routers, prod_lins, layer_w, rope_caches, hc_head, - final_norm_w, lm_w, dec_pos_buf, dec_tid32_buf, comp_rope_caches=None): + final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu, comp_rope_caches=None): """Capture CUDA graphs for all layers + lm_head. Must be called after one warmup step so that: @@ -211,7 +211,7 @@ class CUDAGraphDecoder: self.x_in_bufs[li], layer_w[li], li, cfg, *rope_caches[gpu], attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li), - kv_caches[li], dec_pos_buf, dec_tid32_buf, + kv_caches[li], dec_pos_per_gpu[gpu], dec_tid32_per_gpu[gpu], compressors.get(li), indexers.get(li), moe_runners.get(li), se_runners.get(li), routers.get(li), prod_lin=prod_lins.get(li), @@ -1785,7 +1785,7 @@ def main(): cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms, kv_caches, compressors, indexers, moe_runners, se_runners, routers, prod_lins, layer_w, rope_caches, hc_head, - final_norm_w, lm_w, dec_pos_buf, dec_tid32_buf, + final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu, comp_rope_caches=comp_rope_caches, ) print(f" CUDA graphs captured. Graph replay starts on step 1.", flush=True)