[Bugfix] Restore CUDA graph persistent buffers for FP8 FlashMLA decode (#35175)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
haosdent
2026-03-27 00:13:39 +08:00
committed by GitHub
parent cb2263218e
commit 0aac2048bf

View File

@@ -172,6 +172,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
num_q_tokens_per_head_k,
1, # MQA for the decode path
)
# Copy FP8 metadata into persistent CUDA graph buffers
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None
n = tile_scheduler_metadata.size(0)
assert n <= self.cg_buf_tile_scheduler_metadata.size(0)
self.cg_buf_tile_scheduler_metadata[:n].copy_(tile_scheduler_metadata)
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[:n]
n = num_splits.size(0)
assert n <= self.cg_buf_num_splits.size(0)
self.cg_buf_num_splits[:n].copy_(num_splits)
num_splits = self.cg_buf_num_splits[:n]
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
scheduler_metadata.num_splits = num_splits