[MoE Refactor][4/N] Marlin Fp8 Mk (#31036)

commit b471092d3a (parent 93cabc417c)
Author: Robert Shaw
Date: 2025-12-21 12:37:42 -05:00
Committed by: GitHub
4 changed files with 85 additions and 63 deletions


@@ -32,8 +32,8 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     RoutingMethodType,
     fp8_w8a8_moe_quant_config,
+    fp8_w8a16_moe_quant_config,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -95,7 +95,6 @@ from vllm.model_executor.parameter import (
 )
 from vllm.model_executor.utils import replace_parameter, set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
@@ -729,7 +728,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
         self.marlin_input_dtype = None
-        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
         if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
             self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
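
Aside: `self.use_marlin` was a cached copy of information the `Fp8MoeBackend` enum already carries, so this hunk removes the boolean and the later hunks compare against the enum at each call site. A minimal sketch of the pattern, assuming an enum shaped roughly like vLLM's (the member list below is abridged and hypothetical, not the file's literal definition):

from enum import Enum, auto

class Fp8MoeBackend(Enum):
    # Abridged, illustrative member list; the real enum lives in this file.
    TRITON = auto()
    DEEPGEMM = auto()
    MARLIN = auto()
    FLASHINFER_TRTLLM = auto()

def uses_marlin(backend: Fp8MoeBackend) -> bool:
    # Single source of truth: no cached boolean that can drift out of sync
    # with the backend selection.
    return backend == Fp8MoeBackend.MARLIN
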
@@ -1048,7 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
             layer.w13_weight.data = w13_weight.data
-        if self.use_marlin:
+        if self.fp8_backend == Fp8MoeBackend.MARLIN:
             prepare_moe_fp8_layer_for_marlin(
                 layer, False, input_dtype=self.marlin_input_dtype
             )
@@ -1091,10 +1089,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             self.use_inplace = False
-        elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
+        elif self.fp8_backend in [
+            Fp8MoeBackend.DEEPGEMM,
+            Fp8MoeBackend.TRITON,
+            Fp8MoeBackend.MARLIN,
+        ]:
             from vllm.model_executor.layers.fused_moe import (
                 TritonOrDeepGemmExperts,
             )
+            from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+                MarlinExperts,
+            )
             from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                 MoEPrepareAndFinalizeNoEP,
             )
@@ -1102,12 +1107,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             config = self.get_fused_moe_quant_config(layer)
             assert config is not None
             self.moe_quant_config = config
-            self.kernel = mk.FusedMoEModularKernel(
-                MoEPrepareAndFinalizeNoEP(),
-                TritonOrDeepGemmExperts(
+            use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
+            allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
+            moe_kernel = (
+                MarlinExperts(quant_config=self.moe_quant_config)
+                if use_marlin
+                else TritonOrDeepGemmExperts(
                     quant_config=self.moe_quant_config,
-                    allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
-                ),
+                    allow_deep_gemm=allow_deep_gemm,
+                )
             )
+            self.kernel = mk.FusedMoEModularKernel(
+                MoEPrepareAndFinalizeNoEP(), moe_kernel
+            )
             self.use_inplace = True
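
The net effect of the two hunks above: Marlin becomes one more experts implementation behind the shared `FusedMoEModularKernel` wrapper rather than a bespoke code path. A condensed sketch of the resulting construction, using the class names from the diff; the free helper function is hypothetical, and `Fp8MoeBackend` is the enum defined earlier in this file (sketched above):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import TritonOrDeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)

def build_fp8_moe_kernel(fp8_backend, quant_config):
    # Hypothetical helper condensing the hunk above: pick an experts
    # implementation, then wrap it with the no-EP prepare/finalize stage.
    if fp8_backend == Fp8MoeBackend.MARLIN:
        experts = MarlinExperts(quant_config=quant_config)
    else:
        experts = TritonOrDeepGemmExperts(
            quant_config=quant_config,
            allow_deep_gemm=(fp8_backend == Fp8MoeBackend.DEEPGEMM),
        )
    return mk.FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), experts)
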
@@ -1116,9 +1128,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
-            current_platform.is_xpu()
-            or self.rocm_aiter_moe_enabled
-            or self.use_marlin
+            self.rocm_aiter_moe_enabled
+            or self.fp8_backend == Fp8MoeBackend.MARLIN
             or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
             return None
@@ -1150,7 +1161,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             TritonOrDeepGemmExperts,
         )
-        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
+        assert (
+            self.fp8_backend != Fp8MoeBackend.MARLIN
+        ) and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet."
         )
@@ -1207,8 +1220,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        if self.use_marlin:
-            return None
+        if self.fp8_backend == Fp8MoeBackend.MARLIN:
+            return fp8_w8a16_moe_quant_config(
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                block_shape=self.weight_block_size,
+            )
         return fp8_w8a8_moe_quant_config(
             w1_scale=(
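
Returning a real quant config instead of `None` is what lets `MarlinExperts` participate in the modular-kernel path: Marlin is a weight-only kernel, running fp8 weights against 16-bit activations, so the w8a16 config carries exactly what it needs (weight scales and the optional block shape, no activation scales). Sketch of the dispatch this hunk arrives at; the w8a8 branch is abridged:

def get_fused_moe_quant_config(self, layer):
    if self.fp8_backend == Fp8MoeBackend.MARLIN:
        # Weight-only fp8: weight scales plus block shape, no activation scales.
        return fp8_w8a16_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            block_shape=self.weight_block_size,
        )
    return fp8_w8a8_moe_quant_config(...)  # unchanged w8a8 path, abridged
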
@@ -1314,29 +1331,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 expert_map=layer.expert_map,
                 quant_config=self.moe_quant_config,
             )
-        elif self.use_marlin:
-            # TODO(rob): convert this to MK.
-            assert layer.activation == "silu", (
-                f"{layer.activation} not supported for Marlin MoE."
-            )
-            result = fused_marlin_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                None,
-                None,
-                layer.w13_weight_scale,
-                layer.w2_weight_scale,
-                router_logits,
-                topk_weights,
-                topk_ids,
-                quant_type_id=scalar_types.float8_e4m3fn.id,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                global_num_experts=layer.global_num_experts,
-                expert_map=layer.expert_map,
-                input_dtype=self.marlin_input_dtype,
-                workspace=layer.workspace,
-            )
         else:
             result = self.kernel(
                 x,
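
With the bespoke `fused_marlin_moe` branch gone, every backend in this family reaches `apply()` through the same call; only the experts implementation inside `self.kernel` differs. Roughly, the unified call path looks like the following sketch; the argument list is abridged to the positional arguments visible in the surrounding context, and the trailing arguments are elided rather than guessed (see `FusedMoEModularKernel` for the full signature):

# Illustrative shape of the unified call, not the literal code:
result = self.kernel(
    x,                 # hidden states
    layer.w13_weight,  # fused gate/up projection weights
    layer.w2_weight,   # down projection weights
    topk_weights,
    topk_ids,
    ...,               # remaining routing/activation kwargs elided
)
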
@@ -1495,7 +1489,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
         replace_parameter(layer, "w2_weight", shuffled_w2)
         # Reshuffle weights for MARLIN if needed.
-        if self.use_marlin:
+        if self.fp8_backend == Fp8MoeBackend.MARLIN:
             prepare_moe_fp8_layer_for_marlin(
                 layer, False, input_dtype=self.marlin_input_dtype
             )