From 0aac2048bf3a7e60eaddf1ebcb4165ed777eb8ff Mon Sep 17 00:00:00 2001 From: haosdent Date: Fri, 27 Mar 2026 00:13:39 +0800 Subject: [PATCH] [Bugfix] Restore CUDA graph persistent buffers for FP8 FlashMLA decode (#35175) Signed-off-by: haosdent Co-authored-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashmla.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 76f32a54f..df54b865a 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -172,6 +172,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): num_q_tokens_per_head_k, 1, # MQA for the decode path ) + + # Copy FP8 metadata into persistent CUDA graph buffers + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + assert self.cg_buf_tile_scheduler_metadata is not None + assert self.cg_buf_num_splits is not None + n = tile_scheduler_metadata.size(0) + assert n <= self.cg_buf_tile_scheduler_metadata.size(0) + self.cg_buf_tile_scheduler_metadata[:n].copy_(tile_scheduler_metadata) + tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[:n] + + n = num_splits.size(0) + assert n <= self.cg_buf_num_splits.size(0) + self.cg_buf_num_splits[:n].copy_(num_splits) + num_splits = self.cg_buf_num_splits[:n] + scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata scheduler_metadata.num_splits = num_splits