[Attention] Deepseek v3 MLA support with FP8 compute (#12601)

This PR implements Deepseek V3 support by performing matrix absorption directly on the fp8 weights.
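For context: MLA's "matrix absorption" folds the KV up-projection matrices into the query and output paths so attention can run directly against the compressed latent cache; with an fp8 checkpoint the stored weights must be dequantized (or the product taken in fp8) when absorbing. A rough sketch of the idea, with illustrative names and simple per-tensor scales (these are assumptions, not vLLM's actual parameters or API):

import torch

# Minimal sketch of MLA "matrix absorption" with fp8 weights -- illustrative
# only; tensor names, shapes, and per-tensor scales are assumptions.
kv_lora_rank = 512        # DeepSeek-V3 latent (compressed) KV dimension
qk_nope_head_dim = 128    # non-rope part of each query/key head
v_head_dim = 128

def dequant_fp8(w_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical per-tensor dequantization; real checkpoints may use
    # per-channel or block-wise scales.
    return (w_fp8.to(torch.float32) * scale).to(torch.bfloat16)

# Per-head up-projection weights as stored in the checkpoint (fp8 + scale).
W_UK_fp8 = torch.randn(kv_lora_rank, qk_nope_head_dim).to(torch.float8_e4m3fn)
W_UV_fp8 = torch.randn(kv_lora_rank, v_head_dim).to(torch.float8_e4m3fn)
uk_scale = torch.tensor(0.05)
uv_scale = torch.tensor(0.05)

W_UK = dequant_fp8(W_UK_fp8, uk_scale)   # (kv_lora_rank, qk_nope_head_dim)
W_UV = dequant_fp8(W_UV_fp8, uv_scale)   # (kv_lora_rank, v_head_dim)

# "Absorption": rather than expanding every cached latent vector back to full
# head size, fold W_UK into the query and W_UV into the output path, so the
# attention dot products run directly in the kv_lora_rank space.
q_nope = torch.randn(qk_nope_head_dim, dtype=torch.bfloat16)
q_latent = W_UK @ q_nope                  # (kv_lora_rank,) query in latent space

attn_latent = torch.randn(kv_lora_rank, dtype=torch.bfloat16)  # softmax(scores) @ cached latents
o_head = attn_latent @ W_UV               # (v_head_dim,) then projected by o_proj as usual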

---------

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Author: Lucas Wilkinson
Date: 2025-02-01 00:52:51 -05:00
Committed by: GitHub
Parent: 3e1c76cf3a
Commit: baeded2569
10 changed files with 579 additions and 84 deletions


@@ -739,18 +739,19 @@ class ModelConfig:
     @property
     def is_deepseek_mla(self) -> bool:
-        # TODO add deepseek_v3
-        return hasattr(self.hf_text_config,
-                       "model_type") and (self.hf_text_config.model_type
-                                          in ('deepseek_v2'))
+        return (hasattr(self.hf_text_config, "model_type")) \
+            and (self.hf_text_config.model_type in \
+                ('deepseek_v2', 'deepseek_v3'))\
+            and (self.hf_text_config.kv_lora_rank is not None)

     def get_head_size(self) -> int:
         # TODO remove hard code
         if self.is_deepseek_mla:
+            qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
+                                       0)
             if self.use_mla:
-                return self.hf_text_config.kv_lora_rank
+                return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
             else:
                 qk_rope_head_dim = getattr(self.hf_text_config,
                                            "qk_rope_head_dim", 0)
                 qk_nope_head_dim = getattr(self.hf_text_config,
                                            "qk_nope_head_dim", 0)
                 if qk_rope_head_dim and qk_nope_head_dim:
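As a quick sanity check of the new head-size math, plugging in DeepSeek-V3's config values (kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128):

# DeepSeek-V3 values from its HF config.json.
kv_lora_rank = 512
qk_rope_head_dim = 64
qk_nope_head_dim = 128

# With MLA, the KV cache holds the compressed latent plus the rope part of
# the key, so the per-token "head size" becomes:
mla_head_size = kv_lora_rank + qk_rope_head_dim        # 576 (was 512 before this change)

# Without MLA, the head size stays the full query/key head dimension:
dense_head_size = qk_rope_head_dim + qk_nope_head_dim  # 192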
@@ -969,6 +970,32 @@ class ModelConfig:
     @property
     def use_mla(self) -> bool:
+        if self.quantization is not None and self.quantization not in [\
+                "fp8", "compressed-tensors"]:
+            logger.warning(
+                "MLA is not supported with %s quantization. "
+                "Disabling MLA.", self.quantization)
+            return False
+
+        # If using a "compressed-tensors" checkpoint, check that all groups
+        # have fp8 for both weights and activations.
+        if self.quantization == "compressed-tensors":
+            quant_config = self._parse_quant_hf_config()
+            for group_name, cfg in quant_config.get("config_groups",
+                                                    ("", {})).items():
+                act_cfg = cfg.get("input_activations", {})
+                act_type = None if act_cfg is None else act_cfg.get("type", "")
+                w_cfg = cfg.get("weights", {})
+                w_type = None if w_cfg is None else w_cfg.get("type", "")
+                if act_type != "fp8" or w_type != "fp8":
+                    logger.warning(
+                        "compressed-tensors MLA support requires fp8 "
+                        "activations and weights in group '%s', but got "
+                        "activations type '%s' and weights type '%s'.\n "
+                        "Full config: %s", group_name, act_type, w_type,
+                        quant_config)
+                    return False
+
         use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
         return use_mla
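For reference, the check above walks the checkpoint's compressed-tensors quantization config and requires fp8 for both activations and weights in every group. A config group that would pass looks roughly like the sketch below (the dict shape is illustrative and trimmed to the fields the check reads; real configs carry more keys):

# Sketch of the relevant part of a compressed-tensors quantization_config.
quant_config = {
    "config_groups": {
        "group_0": {
            "input_activations": {"type": "fp8", "strategy": "tensor"},
            "weights": {"type": "fp8", "strategy": "tensor"},
        },
    },
}

# Mirrors the loop above: every group must be fp8 for both activations and
# weights, otherwise MLA is disabled and the dense attention path is used.
def all_groups_fp8(quant_config: dict) -> bool:
    for group_name, cfg in quant_config.get("config_groups", {}).items():
        act_cfg = cfg.get("input_activations", {})
        act_type = None if act_cfg is None else act_cfg.get("type", "")
        w_cfg = cfg.get("weights", {})
        w_type = None if w_cfg is None else w_cfg.get("type", "")
        if act_type != "fp8" or w_type != "fp8":
            return False
    return True

assert all_groups_fp8(quant_config)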