diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 76f32a54f..df54b865a 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -172,6 +172,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): num_q_tokens_per_head_k, 1, # MQA for the decode path ) + + # Copy FP8 metadata into persistent CUDA graph buffers + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + assert self.cg_buf_tile_scheduler_metadata is not None + assert self.cg_buf_num_splits is not None + n = tile_scheduler_metadata.size(0) + assert n <= self.cg_buf_tile_scheduler_metadata.size(0) + self.cg_buf_tile_scheduler_metadata[:n].copy_(tile_scheduler_metadata) + tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[:n] + + n = num_splits.size(0) + assert n <= self.cg_buf_num_splits.size(0) + self.cg_buf_num_splits[:n].copy_(num_splits) + num_splits = self.cg_buf_num_splits[:n] + scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata scheduler_metadata.num_splits = num_splits