[MoE Refactor][4/N] Marlin Fp8 Mk (#31036)
This commit is contained in:
@@ -32,8 +32,8 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
fp8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
@@ -95,7 +95,6 @@ from vllm.model_executor.parameter import (
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
@@ -729,7 +728,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
self.marlin_input_dtype = None
|
||||
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||
@@ -1048,7 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
|
||||
layer.w13_weight.data = w13_weight.data
|
||||
|
||||
if self.use_marlin:
|
||||
if self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
prepare_moe_fp8_layer_for_marlin(
|
||||
layer, False, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
@@ -1091,10 +1089,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
self.use_inplace = False
|
||||
|
||||
elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
|
||||
elif self.fp8_backend in [
|
||||
Fp8MoeBackend.DEEPGEMM,
|
||||
Fp8MoeBackend.TRITON,
|
||||
Fp8MoeBackend.MARLIN,
|
||||
]:
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
MarlinExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
@@ -1102,12 +1107,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
config = self.get_fused_moe_quant_config(layer)
|
||||
assert config is not None
|
||||
self.moe_quant_config = config
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonOrDeepGemmExperts(
|
||||
use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
||||
moe_kernel = (
|
||||
MarlinExperts(quant_config=self.moe_quant_config)
|
||||
if use_marlin
|
||||
else TritonOrDeepGemmExperts(
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
|
||||
),
|
||||
allow_deep_gemm=allow_deep_gemm,
|
||||
)
|
||||
)
|
||||
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(), moe_kernel
|
||||
)
|
||||
self.use_inplace = True
|
||||
|
||||
@@ -1116,9 +1128,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if (
|
||||
current_platform.is_xpu()
|
||||
or self.rocm_aiter_moe_enabled
|
||||
or self.use_marlin
|
||||
self.rocm_aiter_moe_enabled
|
||||
or self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||
):
|
||||
return None
|
||||
@@ -1150,7 +1161,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
|
||||
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
||||
assert (
|
||||
self.fp8_backend != Fp8MoeBackend.MARLIN
|
||||
) and not self.rocm_aiter_moe_enabled, (
|
||||
"Marlin and ROCm AITER are not supported with all2all yet."
|
||||
)
|
||||
|
||||
@@ -1207,8 +1220,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
if self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
return fp8_w8a16_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
block_shape=self.weight_block_size,
|
||||
)
|
||||
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=(
|
||||
@@ -1314,29 +1331,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
elif self.use_marlin:
|
||||
# TODO(rob): convert this to MK.
|
||||
assert layer.activation == "silu", (
|
||||
f"{layer.activation} not supported for Marlin MoE."
|
||||
)
|
||||
result = fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
None,
|
||||
None,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
workspace=layer.workspace,
|
||||
)
|
||||
else:
|
||||
result = self.kernel(
|
||||
x,
|
||||
@@ -1495,7 +1489,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
|
||||
# Rushuffle weights for MARLIN if needed.
|
||||
if self.use_marlin:
|
||||
if self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
prepare_moe_fp8_layer_for_marlin(
|
||||
layer, False, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user