[Attention] MLA decode optimizations (#12528)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
@@ -736,17 +736,25 @@ class ModelConfig:
     def get_hidden_size(self) -> int:
         return self.hf_text_config.hidden_size
 
+    @property
+    def is_deepseek_mla(self) -> bool:
+        # TODO add deepseek_v3
+        return hasattr(self.hf_text_config,
+                       "model_type") and (self.hf_text_config.model_type
+                                          in ('deepseek_v2'))
+
     def get_head_size(self) -> int:
         # TODO remove hard code
-        if hasattr(self.hf_text_config,
-                   "model_type") and (self.hf_text_config.model_type
-                                      in ('deepseek_v2', 'deepseek_v3')):
-            qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
-                                       0)
-            qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim",
-                                       0)
-            if qk_rope_head_dim and qk_nope_head_dim:
-                return qk_rope_head_dim + qk_nope_head_dim
+        if self.is_deepseek_mla:
+            if self.use_mla:
+                return self.hf_text_config.kv_lora_rank
+            else:
+                qk_rope_head_dim = getattr(self.hf_text_config,
+                                           "qk_rope_head_dim", 0)
+                qk_nope_head_dim = getattr(self.hf_text_config,
+                                           "qk_nope_head_dim", 0)
+                if qk_rope_head_dim and qk_nope_head_dim:
+                    return qk_rope_head_dim + qk_nope_head_dim
 
         if self.is_attention_free:
             return 0
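In the new get_head_size branch, the cached "head" for an MLA model is the compressed KV latent rather than the per-head key/value vectors. A minimal sketch (not vLLM code) of what the branch computes, using DeepSeek-V2's published dimensions (kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128) as an example:

    from types import SimpleNamespace

    # Stand-in for hf_text_config with DeepSeek-V2's dimensions.
    cfg = SimpleNamespace(kv_lora_rank=512, qk_rope_head_dim=64,
                          qk_nope_head_dim=128)

    def head_size(cfg, use_mla: bool) -> int:
        if use_mla:
            # MLA caches the compressed latent, so head size is the latent rank.
            return cfg.kv_lora_rank
        # Otherwise the cache holds full per-head keys: rope + nope parts.
        return cfg.qk_rope_head_dim + cfg.qk_nope_head_dim

    assert head_size(cfg, use_mla=True) == 512   # MLA path
    assert head_size(cfg, use_mla=False) == 192  # 64 + 128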
@@ -805,6 +813,10 @@ class ModelConfig:
 
     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU."""
+        if self.use_mla:
+            # When using MLA during decode it becomes MQA
+            return 1
+
         total_num_kv_heads = self.get_total_num_kv_heads()
         # If tensor parallelism is used, we divide the number of KV heads by
         # the tensor parallel size. We will replicate the KV heads in the
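Treating MLA decode as MQA with a single KV head is what shrinks the cache: every query head reads the same shared latent. Illustrative arithmetic only (not vLLM's actual cache layout), assuming hypothetical DeepSeek-V2-like dims of 128 KV heads of size 192 without MLA versus one 512-wide latent head with MLA:

    def cache_elems_per_token(num_kv_heads: int, head_size: int,
                              entries: int) -> int:
        # entries = 2 for separate K and V tensors, 1 for a shared latent.
        return num_kv_heads * head_size * entries

    full_kv = cache_elems_per_token(num_kv_heads=128, head_size=192, entries=2)
    mla_kv = cache_elems_per_token(num_kv_heads=1, head_size=512, entries=1)
    print(full_kv, mla_kv, full_kv // mla_kv)  # 49152 512 96 -> ~96x fewer elements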
@@ -955,6 +967,11 @@ class ModelConfig:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_cross_encoder_model(architectures)
 
+    @property
+    def use_mla(self) -> bool:
+        use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
+        return use_mla
+
     @property
     def supported_runner_types(self) -> Set[RunnerType]:
         return {_TASK_RUNNER[task] for task in self.supported_tasks}
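The use_mla property gives the feature a single off switch: MLA is enabled automatically for deepseek_v2-family models unless VLLM_MLA_DISABLE is set. A usage sketch (the model name is an example only; any MLA-capable checkpoint applies):

    import os
    # Set before vLLM reads its env config; forces the full-KV fallback path.
    os.environ["VLLM_MLA_DISABLE"] = "1"

    from vllm import LLM

    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")
    print(llm.generate(["MLA is disabled; decode uses full KV caching."]))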