committed by
Robert Shaw
parent
daa2784bb9
commit
7d98f09b1c
@@ -86,7 +86,23 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
|
||||
return not moe_parallel_config.enable_eplb
|
||||
|
||||
|
||||
def is_supported_config_trtllm(
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return True
|
||||
|
||||
|
||||
def is_supported_config_trtllm_fp8(
|
||||
moe_config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
@@ -115,13 +131,17 @@ def is_supported_config_trtllm(
|
||||
return False, _make_reason("routing method")
|
||||
elif activation_format != mk.FusedMoEActivationFormat.Standard:
|
||||
return False, _make_reason("activation format")
|
||||
elif not _supports_router_logits_dtype(
|
||||
moe_config.router_logits_dtype, moe_config.routing_method
|
||||
):
|
||||
return False, _make_reason("float32 router_logits with non-DeepSeekV3 routing")
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
@@ -135,7 +155,7 @@ def flashinfer_fused_moe_blockscale_fp8(
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: list[int],
|
||||
routing_method_type: int = int(RoutingMethodType.DeepSeekV3),
|
||||
routing_method_type: int,
|
||||
routed_scaling: float | None = 1.0,
|
||||
) -> torch.Tensor:
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||
@@ -148,6 +168,13 @@ def flashinfer_fused_moe_blockscale_fp8(
|
||||
# Routing kernel expects #experts <= #threads 512
|
||||
assert global_num_experts <= 512
|
||||
|
||||
# The DeepSeekV3 routing method requires float32 router logits.
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
routing_logits = routing_logits.to(torch.float32)
|
||||
|
||||
if routing_bias is not None:
|
||||
routing_bias = routing_bias.to(x.dtype)
|
||||
|
||||
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
|
||||
# NOTE: scales of hidden states have to be transposed!
|
||||
a_sf_t = a_sf.t().contiguous()
|
||||
@@ -175,7 +202,7 @@ def flashinfer_fused_moe_blockscale_fp8(
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
|
||||
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
|
||||
is_supported_config_trtllm,
|
||||
is_supported_config_trtllm_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
@@ -212,7 +212,7 @@ def select_fp8_moe_backend(
|
||||
|
||||
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
@@ -239,7 +239,7 @@ def select_fp8_moe_backend(
|
||||
]:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
@@ -308,7 +308,7 @@ def select_fp8_moe_backend(
|
||||
for backend in AVAILABLE_BACKENDS:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
|
||||
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
int4_w4a16_moe_quant_config,
|
||||
int4_w4afp8_moe_quant_config,
|
||||
int8_w8a8_moe_quant_config,
|
||||
@@ -1072,17 +1071,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
if self.block_quant:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
|
||||
e_score_correction_bias = (
|
||||
layer.e_score_correction_bias.to(x.dtype)
|
||||
if layer.e_score_correction_bias is not None
|
||||
else None
|
||||
)
|
||||
routing_method_type = layer.routing_method_type
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits.to(torch.float32)
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3
|
||||
else router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
routing_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
w13_weight_scale_inv=layer.w13_weight_scale,
|
||||
@@ -1096,7 +1087,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.weight_block_size,
|
||||
routing_method_type=routing_method_type,
|
||||
routing_method_type=layer.routing_method_type,
|
||||
routed_scaling=layer.routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
@@ -990,17 +989,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
if self.block_quant:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
|
||||
e_score_correction_bias = (
|
||||
layer.e_score_correction_bias.to(x.dtype)
|
||||
if layer.e_score_correction_bias is not None
|
||||
else None
|
||||
)
|
||||
routing_method_type = layer.routing_method_type
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits.to(torch.float32)
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3
|
||||
else router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
routing_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
w13_weight_scale_inv=layer.w13_weight_scale_inv,
|
||||
@@ -1014,7 +1005,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.weight_block_size,
|
||||
routing_method_type=routing_method_type,
|
||||
routing_method_type=layer.routing_method_type,
|
||||
routed_scaling=layer.routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -107,6 +107,7 @@ class MiniMaxM2MoE(nn.Module):
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
router_logits_dtype=torch.float32,
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
|
||||
Reference in New Issue
Block a user