CUDA graph: Use per-GPU position/token buffers for graph capture

Cross-GPU .to() calls inside graph capture cause 'dependency on uncaptured
work in another stream'. Fix: pass dec_pos_per_gpu/dec_tid32_per_gpu to
capture() so each layer's graph uses buffers on its own GPU.
This commit is contained in:
2026-06-03 22:56:20 +00:00
parent f57de06eb5
commit 56b816a54f

View File

@@ -188,7 +188,7 @@ class CUDAGraphDecoder:
def capture(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms,
kv_caches, compressors, indexers, moe_runners, se_runners,
routers, prod_lins, layer_w, rope_caches, hc_head,
final_norm_w, lm_w, dec_pos_buf, dec_tid32_buf, comp_rope_caches=None):
final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu, comp_rope_caches=None):
"""Capture CUDA graphs for all layers + lm_head.
Must be called after one warmup step so that:
@@ -211,7 +211,7 @@ class CUDAGraphDecoder:
self.x_in_bufs[li], layer_w[li], li, cfg, *rope_caches[gpu],
attn_mhcs.get(li), ffn_mhcs.get(li),
attn_norms.get(li), ffn_norms.get(li),
kv_caches[li], dec_pos_buf, dec_tid32_buf,
kv_caches[li], dec_pos_per_gpu[gpu], dec_tid32_per_gpu[gpu],
compressors.get(li), indexers.get(li),
moe_runners.get(li), se_runners.get(li), routers.get(li),
prod_lin=prod_lins.get(li),
@@ -1785,7 +1785,7 @@ def main():
cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms,
kv_caches, compressors, indexers, moe_runners, se_runners,
routers, prod_lins, layer_w, rope_caches, hc_head,
final_norm_w, lm_w, dec_pos_buf, dec_tid32_buf,
final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu,
comp_rope_caches=comp_rope_caches,
)
print(f" CUDA graphs captured. Graph replay starts on step 1.", flush=True)