diff --git a/single_shot_inference.py b/single_shot_inference.py index a2718ecd..7d7dee61 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -188,7 +188,7 @@ class CUDAGraphDecoder: def capture(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms, kv_caches, compressors, indexers, moe_runners, se_runners, routers, prod_lins, layer_w, rope_caches, hc_head, - final_norm_w, lm_w, dec_pos_buf, dec_tid32_buf, comp_rope_caches=None): + final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu, comp_rope_caches=None): """Capture CUDA graphs for all layers + lm_head. Must be called after one warmup step so that: @@ -211,7 +211,7 @@ class CUDAGraphDecoder: self.x_in_bufs[li], layer_w[li], li, cfg, *rope_caches[gpu], attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li), - kv_caches[li], dec_pos_buf, dec_tid32_buf, + kv_caches[li], dec_pos_per_gpu[gpu], dec_tid32_per_gpu[gpu], compressors.get(li), indexers.get(li), moe_runners.get(li), se_runners.get(li), routers.get(li), prod_lin=prod_lins.get(li), @@ -1785,7 +1785,7 @@ def main(): cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms, kv_caches, compressors, indexers, moe_runners, se_runners, routers, prod_lins, layer_w, rope_caches, hc_head, - final_norm_w, lm_w, dec_pos_buf, dec_tid32_buf, + final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu, comp_rope_caches=comp_rope_caches, ) print(f" CUDA graphs captured. Graph replay starts on step 1.", flush=True)