[MoE][Refactor] Remove most arguments to FusedMoEMethodBase.apply (#29066)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
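
The hunks below drop the per-call routing arguments from Fp8MoEMethod.apply and read them from the FusedMoE layer instead. As a rough before/after sketch of the shape of the change (ToyMoELayer and ToyFp8MoEMethod are hypothetical names for illustration, not vLLM's actual classes):

# Hedged sketch only: hypothetical classes illustrating the refactor, not the
# real vLLM code.
from dataclasses import dataclass

import torch


@dataclass
class ToyMoELayer:
    # Routing configuration that the old API threaded through apply() as kwargs.
    top_k: int = 2
    renormalize: bool = True
    activation: str = "silu"
    global_num_experts: int = -1
    expert_map: torch.Tensor | None = None


class ToyFp8MoEMethod:
    # Before: every quantization method repeated the long routing argument list.
    def apply_old(self, layer, x, router_logits, top_k, renormalize,
                  activation="silu", global_num_experts=-1, expert_map=None):
        return x  # stand-in for the fused expert computation

    # After: apply() takes only the layer, the activations, and the router
    # logits; the routing configuration is read off the layer object.
    def apply_new(self, layer: ToyMoELayer, x: torch.Tensor,
                  router_logits: torch.Tensor) -> torch.Tensor:
        assert layer.activation == "silu"
        top_k = layer.top_k  # was a keyword argument before the refactor
        return x  # stand-in for the fused expert computation

The per-backend branches in the diff (FlashInfer TensorRT-LLM, Marlin, CUTLASS, and the default fused_experts path) all follow the same pattern, swapping each forwarded keyword argument for the corresponding layer attribute.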
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Callable
 from enum import Enum
 from functools import partial
 from typing import TYPE_CHECKING, Any, Optional
@@ -1242,41 +1241,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: int | None = None,
-        num_expert_group: int | None = None,
-        global_num_experts: int = -1,
-        expert_map: torch.Tensor | None = None,
-        custom_routing_function: Callable | None = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: torch.Tensor | None = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: torch.Tensor | None = None,
-        logical_to_physical_map: torch.Tensor | None = None,
-        logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
-            assert expert_load_view is not None
-            assert logical_to_physical_map is not None
-            assert logical_replica_count is not None
-            assert isinstance(layer, FusedMoE)

         if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
-            assert activation == "silu", (
-                f"Expected 'silu' activation but got {activation}"
+            if layer.enable_eplb:
+                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
+            assert layer.activation == "silu", (
+                f"Expected 'silu' activation but got {layer.activation}"
             )

             if self.block_quant:
                 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                 e_score_correction_bias = (
-                    e_score_correction_bias.to(x.dtype)
-                    if e_score_correction_bias is not None
+                    layer.e_score_correction_bias.to(x.dtype)
+                    if layer.e_score_correction_bias is not None
                     else None
                 )
                 routing_method_type = layer.routing_method_type
@@ -1290,29 +1268,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     w13_weight_scale_inv=layer.w13_weight_scale_inv,
                     w2_weight=layer.w2_weight,
                     w2_weight_scale_inv=layer.w2_weight_scale_inv,
-                    global_num_experts=global_num_experts,
-                    top_k=top_k,
-                    num_expert_group=num_expert_group,
-                    topk_group=topk_group,
+                    global_num_experts=layer.global_num_experts,
+                    top_k=layer.top_k,
+                    num_expert_group=layer.num_expert_group,
+                    topk_group=layer.topk_group,
                     intermediate_size=layer.intermediate_size_per_partition,
                     expert_offset=layer.ep_rank * layer.local_num_experts,
                     local_num_experts=layer.local_num_experts,
                     block_shape=self.weight_block_size,
                     routing_method_type=routing_method_type,
-                    routed_scaling=routed_scaling_factor,
+                    routed_scaling=layer.routed_scaling_factor,
                 )
             else:
-                assert not renormalize and custom_routing_function is not None
+                assert (
+                    not layer.renormalize and layer.custom_routing_function is not None
+                )
                 result = apply_flashinfer_per_tensor_scale_fp8(
                     layer=layer,
                     hidden_states=x,
                     router_logits=router_logits,
-                    routing_bias=e_score_correction_bias,
-                    global_num_experts=global_num_experts,
-                    top_k=top_k,
-                    num_expert_group=num_expert_group,
-                    topk_group=topk_group,
-                    apply_router_weight_on_input=apply_router_weight_on_input,
+                    routing_bias=layer.e_score_correction_bias,
+                    global_num_experts=layer.global_num_experts,
+                    top_k=layer.top_k,
+                    num_expert_group=layer.num_expert_group,
+                    topk_group=layer.topk_group,
+                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                 )

         select_result = layer.select_experts(
@@ -1333,13 +1313,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                expert_map=expert_map,
+                activation=layer.activation,
+                apply_router_weight_on_input=layer.apply_router_weight_on_input,
+                expert_map=layer.expert_map,
                 quant_config=self.moe_quant_config,
             )
         elif self.use_marlin:
-            assert activation == "silu", f"{activation} not supported for Marlin MoE."
+            assert layer.activation == "silu", (
+                f"{layer.activation} not supported for Marlin MoE."
+            )
             result = fused_marlin_moe(
                 x,
                 layer.w13_weight,
@@ -1352,20 +1334,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 topk_weights,
                 topk_ids,
                 quant_type_id=scalar_types.float8_e4m3fn.id,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                global_num_experts=global_num_experts,
-                expert_map=expert_map,
+                apply_router_weight_on_input=layer.apply_router_weight_on_input,
+                global_num_experts=layer.global_num_experts,
+                expert_map=layer.expert_map,
                 input_dtype=self.marlin_input_dtype,
                 workspace=layer.workspace,
             )
         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
-            assert activation == "silu", (
-                f"Expected 'silu' activation but got {activation}"
+            assert layer.activation == "silu", (
+                f"Expected 'silu' activation but got {layer.activation}"
             )
             if not self.block_quant:
-                assert not renormalize and custom_routing_function is not None
-                assert scoring_func == "sigmoid", (
-                    f"Expected 'sigmoid' scoring func but got {scoring_func}"
+                assert (
+                    not layer.renormalize and layer.custom_routing_function is not None
+                )
+                assert layer.scoring_func == "sigmoid", (
+                    f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
                 )
             # Delegate to CUTLASS FlashInfer path; function already bound with
             # use_deepseek_fp8_block_scale for block-quant when applicable
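The comment at the end of the hunk above refers to a callable that was pre-bound with the block-scale flag when the method was set up, so this call site does not pass it. A minimal sketch of that pattern, with made-up names (cutlass_moe_kernel, ToyCutlassMethod), not the real vLLM symbols:

# Sketch of pre-binding a flag with functools.partial; the real code binds
# use_deepseek_fp8_block_scale when the quant method is initialized.
from functools import partial


def cutlass_moe_kernel(x, topk_weights, topk_ids, *, use_deepseek_fp8_block_scale=False):
    # Stand-in for the CUTLASS FlashInfer MoE kernel.
    return x


class ToyCutlassMethod:
    def __init__(self, block_quant: bool):
        # Bind the flag once so apply() can call the kernel uniformly.
        self.kernel = partial(
            cutlass_moe_kernel, use_deepseek_fp8_block_scale=block_quant
        )

    def apply(self, x, topk_weights, topk_ids):
        return self.kernel(x, topk_weights, topk_ids)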
@@ -1375,10 +1359,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 topk_weights,
                 topk_ids,
                 inplace=False,
-                activation=activation,
-                global_num_experts=global_num_experts,
-                expert_map=expert_map,
-                apply_router_weight_on_input=apply_router_weight_on_input,
+                activation=layer.activation,
+                global_num_experts=layer.global_num_experts,
+                expert_map=layer.expert_map,
+                apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
@@ -1390,10 +1374,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 inplace=True,
-                activation=activation,
-                global_num_experts=global_num_experts,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                expert_map=expert_map,
+                activation=layer.activation,
+                global_num_experts=layer.global_num_experts,
+                apply_router_weight_on_input=layer.apply_router_weight_on_input,
+                expert_map=layer.expert_map,
                 quant_config=self.moe_quant_config,
                 allow_deep_gemm=self.allow_deep_gemm,
                 allow_cutlass_block_scaled_grouped_gemm=(
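For the default path above, the flow after the refactor is: layer.select_experts turns the router logits into top-k routing weights and expert ids, which are then handed to the fused experts kernel together with configuration read from the layer. A rough sketch of that data flow, using illustrative toy_* helpers whose signatures are not the real select_experts/fused_experts APIs:

# Illustrative sketch of the default path's data flow; names and signatures are
# simplified stand-ins, not vLLM's exact APIs.
from types import SimpleNamespace

import torch


def toy_select_experts(router_logits: torch.Tensor, top_k: int):
    # Pick the top-k experts per token and normalize their routing weights.
    topk_vals, topk_ids = torch.topk(router_logits, top_k, dim=-1)
    return torch.softmax(topk_vals, dim=-1), topk_ids


def toy_fused_experts(x, topk_weights, topk_ids, activation="silu"):
    # Stand-in for the fused expert GEMMs plus activation.
    return x


def toy_apply(layer, x, router_logits):
    # Routing configuration comes from the layer, mirroring the refactored call sites.
    topk_weights, topk_ids = toy_select_experts(router_logits, layer.top_k)
    return toy_fused_experts(x, topk_weights, topk_ids, activation=layer.activation)


layer = SimpleNamespace(top_k=2, activation="silu")
out = toy_apply(layer, torch.randn(4, 16), torch.randn(4, 8))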