[Bugfix] fixes the decoding metadata of dense mla's fp8 kvcache. (#27144)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -19,7 +19,7 @@ else()
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
flashmla
|
flashmla
|
||||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
|
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
|
||||||
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
|
GIT_TAG 28417e516fcbf6257a422ba117ef5b6f44da5682
|
||||||
GIT_PROGRESS TRUE
|
GIT_PROGRESS TRUE
|
||||||
CONFIGURE_COMMAND ""
|
CONFIGURE_COMMAND ""
|
||||||
BUILD_COMMAND ""
|
BUILD_COMMAND ""
|
||||||
@@ -66,6 +66,7 @@ if(FLASH_MLA_ARCHS)
|
|||||||
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
|
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
|
||||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
|
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
|
||||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
|
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
|
||||||
|
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu
|
||||||
)
|
)
|
||||||
|
|
||||||
set(FlashMLA_INCLUDES
|
set(FlashMLA_INCLUDES
|
||||||
|
|||||||
@@ -102,6 +102,12 @@ def get_mla_metadata(
|
|||||||
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||||
- num_splits: (batch_size + 1), dtype torch.int32.
|
- num_splits: (batch_size + 1), dtype torch.int32.
|
||||||
"""
|
"""
|
||||||
|
if is_fp8_kvcache and topk is None:
|
||||||
|
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
|
||||||
|
cache_seqlens,
|
||||||
|
num_q_tokens_per_head_k,
|
||||||
|
num_heads_k,
|
||||||
|
)
|
||||||
return torch.ops._flashmla_C.get_mla_decoding_metadata(
|
return torch.ops._flashmla_C.get_mla_decoding_metadata(
|
||||||
cache_seqlens,
|
cache_seqlens,
|
||||||
num_q_tokens_per_head_k,
|
num_q_tokens_per_head_k,
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
|
|
||||||
self.cg_buf_tile_scheduler_metadata = None
|
self.cg_buf_tile_scheduler_metadata = None
|
||||||
self.cg_buf_num_splits = None
|
self.cg_buf_num_splits = None
|
||||||
|
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
|
||||||
|
|
||||||
device_properties = torch.cuda.get_device_properties(self.device)
|
device_properties = torch.cuda.get_device_properties(self.device)
|
||||||
num_sms = device_properties.multi_processor_count
|
num_sms = device_properties.multi_processor_count
|
||||||
@@ -123,6 +124,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
seq_lens_device,
|
seq_lens_device,
|
||||||
self.num_q_heads,
|
self.num_q_heads,
|
||||||
1, # MQA for the decode path
|
1, # MQA for the decode path
|
||||||
|
is_fp8_kvcache=self.is_fp8_kvcache,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||||
|
|||||||
Reference in New Issue
Block a user