[Feat][Bugfix] Enable additional dimension for FlashInfer MLA and fix routing dtype (#36931)
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
6d53efd2a5
commit
367cf5cd3e
@@ -47,7 +47,11 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.attention import Attention
|
from vllm.model_executor.layers.attention import Attention
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
|
GateLinear,
|
||||||
|
RoutingMethodType,
|
||||||
|
SharedFusedMoE,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
ColumnParallelLinear,
|
ColumnParallelLinear,
|
||||||
@@ -333,8 +337,12 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
# NOTE(rob): this is a hack until we finish off the PR for
|
# NOTE(rob): this is a hack until we finish off the PR for
|
||||||
# merging TRTLLM kernels into the MK framework. Then we can
|
# merging TRTLLM kernels into the MK framework. Then we can
|
||||||
# query the MonolithicMK for the expected router logits.
|
# query the MonolithicMK for the expected router logits.
|
||||||
|
# NOTE(dbari): Use BF16 if routing is not Deepseek, e.g. Mistral Large 3
|
||||||
self.gate.set_out_dtype(
|
self.gate.set_out_dtype(
|
||||||
torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
|
torch.float32
|
||||||
|
if self.experts.quant_method.is_monolithic
|
||||||
|
and self.experts.routing_method_type == RoutingMethodType.DeepSeekV3
|
||||||
|
else torch.bfloat16
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -1197,6 +1205,11 @@ class DeepseekV2Model(nn.Module):
|
|||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
else:
|
else:
|
||||||
|
if input_ids is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Either input_ids or inputs_embeds must be provided "
|
||||||
|
"to DeepseekV2Model.forward"
|
||||||
|
)
|
||||||
hidden_states = self.embed_input_ids(input_ids)
|
hidden_states = self.embed_input_ids(input_ids)
|
||||||
residual = None
|
residual = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -75,16 +75,16 @@ class FlashInferMLABackend(MLACommonBackend):
|
|||||||
use_sparse: bool,
|
use_sparse: bool,
|
||||||
device_capability: DeviceCapability,
|
device_capability: DeviceCapability,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
# FlashInfer MLA kernel requires qk_nope_head_dim == 128
|
# FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128]
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
if vllm_config.model_config is not None:
|
if vllm_config.model_config is not None:
|
||||||
hf_text_config = vllm_config.model_config.hf_text_config
|
hf_text_config = vllm_config.model_config.hf_text_config
|
||||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||||
if qk_nope_head_dim != 128:
|
if qk_nope_head_dim not in [64, 128]:
|
||||||
return (
|
return (
|
||||||
f"FlashInfer MLA kernel requires qk_nope_head_dim == 128, "
|
f"FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128], "
|
||||||
f"but got {qk_nope_head_dim}"
|
f"but got {qk_nope_head_dim}"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user