[Bugfix] Restore CUDA graph persistent buffers for FP8 FlashMLA decode (#35175)
Signed-off-by: haosdent <haosdent@gmail.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -172,6 +172,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
num_q_tokens_per_head_k,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
# Copy FP8 metadata into persistent CUDA graph buffers
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
||||
assert self.cg_buf_num_splits is not None
|
||||
n = tile_scheduler_metadata.size(0)
|
||||
assert n <= self.cg_buf_tile_scheduler_metadata.size(0)
|
||||
self.cg_buf_tile_scheduler_metadata[:n].copy_(tile_scheduler_metadata)
|
||||
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[:n]
|
||||
|
||||
n = num_splits.size(0)
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
self.cg_buf_num_splits[:n].copy_(num_splits)
|
||||
num_splits = self.cg_buf_num_splits[:n]
|
||||
|
||||
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
|
||||
scheduler_metadata.num_splits = num_splits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user