CUDA graph: Use per-GPU position/token buffers for graph capture
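Cross-GPU `.to()` calls made inside CUDA graph capture trigger the error 'dependency on uncaptured work in another stream'. Fix: pass `dec_pos_per_gpu`/`dec_tid32_per_gpu` to `capture()` in place of the single `dec_pos_buf`/`dec_tid32_buf`, and index them by GPU, so each layer's graph uses buffers on its own GPU.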
This commit is contained in:
@@ -188,7 +188,7 @@ class CUDAGraphDecoder:
|
||||
def capture(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms,
|
||||
kv_caches, compressors, indexers, moe_runners, se_runners,
|
||||
routers, prod_lins, layer_w, rope_caches, hc_head,
|
||||
final_norm_w, lm_w, dec_pos_buf, dec_tid32_buf, comp_rope_caches=None):
|
||||
final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu, comp_rope_caches=None):
|
||||
"""Capture CUDA graphs for all layers + lm_head.
|
||||
|
||||
Must be called after one warmup step so that:
|
||||
@@ -211,7 +211,7 @@ class CUDAGraphDecoder:
|
||||
self.x_in_bufs[li], layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li),
|
||||
attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], dec_pos_buf, dec_tid32_buf,
|
||||
kv_caches[li], dec_pos_per_gpu[gpu], dec_tid32_per_gpu[gpu],
|
||||
compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li),
|
||||
@@ -1785,7 +1785,7 @@ def main():
|
||||
cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms,
|
||||
kv_caches, compressors, indexers, moe_runners, se_runners,
|
||||
routers, prod_lins, layer_w, rope_caches, hc_head,
|
||||
final_norm_w, lm_w, dec_pos_buf, dec_tid32_buf,
|
||||
final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu,
|
||||
comp_rope_caches=comp_rope_caches,
|
||||
)
|
||||
print(f" CUDA graphs captured. Graph replay starts on step 1.", flush=True)
|
||||
|
||||
Reference in New Issue
Block a user