[Attention] Deepseek v3 MLA support with FP8 compute (#12601)
This PR implements Deepseek V3 support by performing matrix absorption of the fp8 weights.

---------

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
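For readers unfamiliar with the "matrix absorption" trick: in MLA the per-head key up-projection can be folded into the query path, so attention scores are computed directly against the compressed KV latent instead of decompressed keys. The toy sketch below is not code from this PR (dimensions and tensor names are invented); it only checks that the absorbed and naive score paths agree.

import torch

# Toy sizes for illustration only; real DeepSeek models use e.g. kv_lora_rank=512.
d_latent, d_head, seq = 16, 8, 4

W_UK = torch.randn(d_latent, d_head)   # latent -> key up-projection (one head)
c_kv = torch.randn(seq, d_latent)      # compressed KV latent cache
q_nope = torch.randn(seq, d_head)      # no-RoPE query part for the same head

# Naive path: decompress the keys, then dot them with the queries.
k_nope = c_kv @ W_UK                   # (seq, d_head)
scores_naive = q_nope @ k_nope.T       # (seq, seq)

# Absorbed path: fold W_UK into the query so scores are taken on the latents,
# which is what lets the KV cache hold only the compressed latent per token.
q_absorbed = q_nope @ W_UK.T           # (seq, d_latent)
scores_absorbed = q_absorbed @ c_kv.T  # (seq, seq)

assert torch.allclose(scores_naive, scores_absorbed, atol=1e-5)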
@@ -739,18 +739,19 @@ class ModelConfig:
 
     @property
     def is_deepseek_mla(self) -> bool:
-        # TODO add deepseek_v3
-        return hasattr(self.hf_text_config,
-                       "model_type") and (self.hf_text_config.model_type
-                                          in ('deepseek_v2'))
+        return (hasattr(self.hf_text_config, "model_type")) \
+            and (self.hf_text_config.model_type in \
+                ('deepseek_v2', 'deepseek_v3'))\
+            and (self.hf_text_config.kv_lora_rank is not None)
 
     def get_head_size(self) -> int:
         # TODO remove hard code
         if self.is_deepseek_mla:
             qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
                                        0)
+            if self.use_mla:
+                return self.hf_text_config.kv_lora_rank
             return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
         else:
             qk_rope_head_dim = getattr(self.hf_text_config,
                                        "qk_rope_head_dim", 0)
             qk_nope_head_dim = getattr(self.hf_text_config,
                                        "qk_nope_head_dim", 0)
             if qk_rope_head_dim and qk_nope_head_dim:
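To make the head-size logic in the hunk above concrete, here is a small sketch with a hypothetical stand-in for hf_text_config using DeepSeek-V3-like values (kv_lora_rank=512, qk_rope_head_dim=64); it only mirrors the two branches and is not the vLLM implementation.

from types import SimpleNamespace

# Hypothetical config stand-in; the attribute values mimic DeepSeek-V3.
hf_text_config = SimpleNamespace(model_type="deepseek_v3",
                                 kv_lora_rank=512,
                                 qk_rope_head_dim=64)

# With MLA enabled, the KV cache stores only the compressed latent per token.
head_size_mla = hf_text_config.kv_lora_rank                    # 512
# Without MLA, the rotary portion is stored alongside the latent.
head_size_no_mla = (hf_text_config.kv_lora_rank +
                    hf_text_config.qk_rope_head_dim)           # 576
print(head_size_mla, head_size_no_mla)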
@@ -969,6 +970,32 @@ class ModelConfig:
 
+    @property
+    def use_mla(self) -> bool:
+        if self.quantization is not None and self.quantization not in [\
+                "fp8", "compressed-tensors"]:
+            logger.warning(
+                "MLA is not supported with %s quantization. "
+                "Disabling MLA.", self.quantization)
+            return False
+
+        # If using a "compressed-tensors" checkpoint, check that all groups
+        # have fp8 for both weights and activations.
+        if self.quantization == "compressed-tensors":
+            quant_config = self._parse_quant_hf_config()
+            for group_name, cfg in quant_config.get("config_groups",
+                                                    ("", {})).items():
+                act_cfg = cfg.get("input_activations", {})
+                act_type = None if act_cfg is None else act_cfg.get("type", "")
+                w_cfg = cfg.get("weights", {})
+                w_type = None if w_cfg is None else w_cfg.get("type", "")
+                if act_type != "fp8" or w_type != "fp8":
+                    logger.warning(
+                        "compressed-tensors MLA support requires fp8 "
+                        "activations and weights in group '%s', but got "
+                        "activations type '%s' and weights type '%s'.\n "
+                        "Full config: %s", group_name, act_type, w_type,
+                        quant_config)
+                    return False
+
+        use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
+        return use_mla
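As a rough illustration of the compressed-tensors check above: a quantization config shaped like the hypothetical one below (only the fields the check reads are shown) would keep MLA enabled, because every group declares fp8 for both weights and activations.

# Hypothetical, minimal "compressed-tensors" config; group and field values
# are illustrative, not taken from a real checkpoint.
quant_config = {
    "config_groups": {
        "group_0": {
            "input_activations": {"type": "fp8"},
            "weights": {"type": "fp8"},
        },
    },
}

# Mirror of the per-group condition in use_mla(): every group must be fp8/fp8.
for group_name, cfg in quant_config["config_groups"].items():
    act_type = (cfg.get("input_activations") or {}).get("type", "")
    w_type = (cfg.get("weights") or {}).get("type", "")
    assert act_type == "fp8" and w_type == "fp8", group_name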