[Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)
@@ -331,11 +331,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             calculate_kv_scales = False
         self.quant_config = quant_config
 
-        # Initialize KV cache quantization attributes
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        _init_kv_cache_quant(self, quant_config, prefix)
-
         dtype = torch.get_default_dtype()
         self.attn_backend = get_attn_backend(
             self.head_size,
@@ -347,6 +342,36 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             num_heads=self.num_heads,
         )
 
+        # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
+        # Automatically convert fp8 kv-cache format to "fp8_ds_mla"
+        if (
+            self.attn_backend.get_name() == "FLASHMLA_SPARSE"
+            and kv_cache_dtype.startswith("fp8")
+            and kv_cache_dtype != "fp8_ds_mla"
+        ):
+            assert cache_config is not None
+            cache_config.cache_dtype = "fp8_ds_mla"
+            kv_cache_dtype = "fp8_ds_mla"
+            logger.info_once(
+                "Using DeepSeek's fp8_ds_mla KV cache format. To use standard "
+                "fp8 kv-cache format, please set `--attention-backend "
+                "FLASHINFER_MLA_SPARSE`"
+            )
+
+        if (
+            self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
+            and kv_cache_dtype.startswith("fp8")
+        ):
+            logger.info_once(
+                "Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
+                "KV cache format, please set `--attention-backend FLASHMLA_SPARSE`"
+            )
+
+        # Initialize KV cache quantization attributes
+        self.kv_cache_dtype = kv_cache_dtype
+        self.calculate_kv_scales = calculate_kv_scales
+        _init_kv_cache_quant(self, quant_config, prefix)
+
         if (
             cache_config is not None
             and cache_config.enable_prefix_caching
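
For readers skimming the diff, the net effect of the new branch is a small dtype-coercion rule. The sketch below is illustrative only: choose_kv_cache_dtype is a hypothetical helper, not vLLM API; the backend names and dtype strings are the ones that appear in the diff above.

import logging

logger = logging.getLogger(__name__)


def choose_kv_cache_dtype(backend_name: str, kv_cache_dtype: str) -> str:
    """Return the KV-cache format the MLA layer ends up with (sketch only)."""
    if (
        backend_name == "FLASHMLA_SPARSE"
        and kv_cache_dtype.startswith("fp8")
        and kv_cache_dtype != "fp8_ds_mla"
    ):
        # FlashMLA sparse attention only understands DeepSeek's fp8_ds_mla
        # cache layout, so a plain fp8 request is upgraded automatically.
        logger.info("Coercing %s to fp8_ds_mla for FLASHMLA_SPARSE", kv_cache_dtype)
        return "fp8_ds_mla"
    # FLASHINFER_MLA_SPARSE (and every other backend) keeps the requested
    # format, e.g. standard fp8.
    return kv_cache_dtype


assert choose_kv_cache_dtype("FLASHMLA_SPARSE", "fp8") == "fp8_ds_mla"
assert choose_kv_cache_dtype("FLASHINFER_MLA_SPARSE", "fp8") == "fp8"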