[MoE Refactor] Integrate Naive Prepare Finalize into MK (#32567)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: amirkl94 <203507526+amirkl94@users.noreply.github.com>
This commit is contained in:
@@ -7,17 +7,27 @@ import torch
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
|
||||
FlashInferA2APrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNaiveEP,
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
if has_pplx():
|
||||
from .pplx_prepare_finalize import (
|
||||
@@ -70,20 +80,46 @@ def maybe_make_prepare_finalize(
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig | None,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
allow_new_interface: bool = False,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
# NOTE(rob): we are migrating each quant_method to hold the MK
|
||||
# in all cases. The allow_new_interface=False flag allows us to fall
|
||||
# back to the old method for methods that have not yet been migrated.
|
||||
#
|
||||
# In old method:
|
||||
# * maybe_init_modular_kernel() calls this function. If we are
|
||||
# using no DP/EP or naive all2all, this function
|
||||
# returns None and no ModularKernelMethod is created. If non-naive
|
||||
# all2all is used, this returns a PrepareAndFinalize object and
|
||||
# a ModularKernelMethod is created.
|
||||
# In new method:
|
||||
# * maybe_make_prepare_finalize() is called from the oracle. We
|
||||
# always return a PrepareAndFinalize object and the quant method
|
||||
# holds the ModularKernel.
|
||||
if not moe.moe_parallel_config.use_all2all_kernels:
|
||||
return None
|
||||
if not allow_new_interface:
|
||||
return None
|
||||
|
||||
# For DP/TP case, fall back to naive P/F.
|
||||
if moe.moe_parallel_config.dp_size > 1:
|
||||
logger.info_once(
|
||||
"Detected DP deployment with no --enable-expert-parallel. "
|
||||
"Falling back to AllGather+ReduceScatter dispatch/combine."
|
||||
)
|
||||
return MoEPrepareAndFinalizeNaiveEP(
|
||||
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
|
||||
num_dispatchers=(
|
||||
get_ep_group().device_communicator.all2all_manager.world_size
|
||||
),
|
||||
)
|
||||
else:
|
||||
return MoEPrepareAndFinalizeNoEP()
|
||||
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
|
||||
|
||||
# TODO(rob): update this as part of the MoE refactor.
|
||||
assert not moe.use_flashinfer_cutlass_kernels, (
|
||||
"Must be created in modelopt.py or fp8.py"
|
||||
)
|
||||
|
||||
if moe.use_pplx_kernels:
|
||||
assert quant_config is not None
|
||||
|
||||
@@ -203,4 +239,16 @@ def maybe_make_prepare_finalize(
|
||||
use_fp8_dispatch=use_fp8_dispatch,
|
||||
)
|
||||
|
||||
elif moe.use_fi_all2allv_kernels:
|
||||
assert quant_config is not None
|
||||
prepare_finalize = FlashInferA2APrepareAndFinalize(
|
||||
num_dispatchers=all2all_manager.world_size,
|
||||
)
|
||||
|
||||
elif moe.use_naive_all2all_kernels and allow_new_interface:
|
||||
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
|
||||
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
|
||||
num_dispatchers=all2all_manager.world_size,
|
||||
)
|
||||
|
||||
return prepare_finalize
|
||||
|
||||
@@ -20,7 +20,6 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
@@ -862,6 +861,7 @@ class FusedMoEParallelConfig:
|
||||
|
||||
use_ep: bool # whether to use EP or not
|
||||
all2all_backend: str # all2all backend for MoE communication
|
||||
is_sequence_parallel: bool # whether sequence parallelism is used
|
||||
enable_eplb: bool # whether to enable expert load balancing
|
||||
|
||||
@property
|
||||
@@ -883,6 +883,12 @@ class FusedMoEParallelConfig:
|
||||
def use_deepep_ll_kernels(self):
|
||||
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
|
||||
|
||||
@property
def use_fi_all2allv_kernels(self):
    """Whether all2all kernels are in use with the "flashinfer_all2allv" backend."""
    backend_matches = self.all2all_backend == "flashinfer_all2allv"
    return self.use_all2all_kernels and backend_matches
|
||||
|
||||
@property
def use_batched_activation_format(self):
    """True-ish when either the DeepEP low-latency or the pplx kernels are selected."""
    deepep_ll = self.use_deepep_ll_kernels
    # Same value an `or` expression would yield: the first operand when it is
    # truthy, otherwise the second operand.
    return deepep_ll if deepep_ll else self.use_pplx_kernels
|
||||
@@ -1014,6 +1020,7 @@ class FusedMoEParallelConfig:
|
||||
ep_rank=0,
|
||||
use_ep=False,
|
||||
all2all_backend=vllm_parallel_config.all2all_backend,
|
||||
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
|
||||
enable_eplb=vllm_parallel_config.enable_eplb,
|
||||
)
|
||||
# DP + EP / TP + EP / DP + TP + EP
|
||||
@@ -1033,6 +1040,7 @@ class FusedMoEParallelConfig:
|
||||
ep_rank=ep_rank,
|
||||
use_ep=True,
|
||||
all2all_backend=vllm_parallel_config.all2all_backend,
|
||||
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
|
||||
enable_eplb=vllm_parallel_config.enable_eplb,
|
||||
)
|
||||
|
||||
@@ -1051,6 +1059,7 @@ class FusedMoEParallelConfig:
|
||||
use_ep=False,
|
||||
all2all_backend="naive",
|
||||
enable_eplb=False,
|
||||
is_sequence_parallel=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -1145,12 +1154,9 @@ class FusedMoEConfig:
|
||||
return self.moe_parallel_config.use_mori_kernels
|
||||
|
||||
@property
|
||||
def use_flashinfer_cutlass_kernels(self):
|
||||
"""
|
||||
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
|
||||
"""
|
||||
return (
|
||||
envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
and has_flashinfer_cutlass_fused_moe()
|
||||
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput"
|
||||
)
|
||||
def use_fi_all2allv_kernels(self):
|
||||
return self.moe_parallel_config.use_fi_all2allv_kernels
|
||||
|
||||
@property
|
||||
def use_naive_all2all_kernels(self):
|
||||
return self.moe_parallel_config.use_naive_all2all_kernels
|
||||
|
||||
@@ -103,7 +103,14 @@ def run_cutlass_moe_fp8(
|
||||
or a2_scale.size(0) == a1q.shape[0]
|
||||
), "Intermediate scale shape mismatch"
|
||||
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
|
||||
if expert_map is not None:
|
||||
|
||||
# NOTE(rob): the expert_map is used for the STANDARD case and
|
||||
# the batched format is used by the BATCHED case.
|
||||
# TODO(rob): update the MK interface to only pass the expert_map
|
||||
# during the STANDARD case to make this clearer across all kernels.
|
||||
if use_batched_format:
|
||||
assert expert_num_tokens is not None
|
||||
else:
|
||||
assert expert_num_tokens is None
|
||||
|
||||
# We have two modes: batched experts and non-batched experts.
|
||||
@@ -379,7 +386,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
# needed for STANDARD activation format kernels in DP/EP mode.
|
||||
# Note that the BATCHED activation format does not use
|
||||
# the expert map for identifying experts.
|
||||
return not moe_parallel_config.use_all2all_kernels
|
||||
return not (
|
||||
moe_parallel_config.use_fi_all2allv_kernels
|
||||
or moe_parallel_config.use_deepep_ht_kernels
|
||||
)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
@@ -641,10 +651,8 @@ def run_cutlass_moe_fp4(
|
||||
|
||||
|
||||
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
@staticmethod
|
||||
def expects_unquantized_inputs(
|
||||
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
|
||||
) -> bool:
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -148,7 +148,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return True
|
||||
# NOTE(rob): discovered an IMA with this combination. Needs investigation.
|
||||
return not moe_parallel_config.use_fi_all2allv_kernels
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -103,6 +103,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
num_experts: int,
|
||||
a1_scale: torch.Tensor | None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool,
|
||||
) -> Callable:
|
||||
has_scales = token_scales is not None
|
||||
|
||||
@@ -174,6 +175,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_topk_weights,
|
||||
a1_scale,
|
||||
quant_config,
|
||||
defer_input_quant=defer_input_quant,
|
||||
)
|
||||
|
||||
def _receiver(
|
||||
@@ -187,6 +189,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_topk_weights: torch.Tensor | None,
|
||||
a1_scale: torch.Tensor | None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool,
|
||||
) -> mk.PrepareResultType:
|
||||
if event.event is not None:
|
||||
event.current_stream_wait()
|
||||
@@ -221,14 +224,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_num_tokens_per_expert_list, device=expert_x.device
|
||||
)
|
||||
|
||||
# Dispatch and Quant
|
||||
# DeepEP kernels only support dispatching block-quantized
|
||||
# activation scales.
|
||||
# Dispatch in bfloat16 and quantize afterwards
|
||||
if not quant_config.is_block_quantized:
|
||||
# * For non-block quant, dispatch in b16 and quantize now as
|
||||
# DeepEP kernels only support dispatching block scales.
|
||||
# * For expert kernels that require unquantized inputs,
|
||||
# defer quantization to FusedMoEExpertsPermuteUnpermute.
|
||||
if not quant_config.is_block_quantized and not defer_input_quant:
|
||||
# Quantize after dispatch.
|
||||
expert_x_scale = None
|
||||
if expert_x.numel() != 0:
|
||||
# TODO: support per_act_token_quant,
|
||||
expert_x, expert_x_scale = moe_kernel_quantize_input(
|
||||
expert_x,
|
||||
a1_scale,
|
||||
@@ -257,6 +261,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.ReceiverType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
@@ -266,8 +271,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
)
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
if quant_config.is_block_quantized:
|
||||
# Quant and Dispatch
|
||||
# * DeepEP only supports fp8 block scales so quantize
|
||||
# before the dispatch for these models.
|
||||
# * For all other quantization, dispatch after.
|
||||
# * For expert kernels that require unquantized inputs,
|
||||
# defer quantization to FusedMoEExpertsPermuteUnpermute.
|
||||
if quant_config.is_block_quantized and not defer_input_quant:
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_scale,
|
||||
@@ -281,7 +290,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
else:
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
a1_post_scale = quant_config.a1_scale
|
||||
a1_post_scale = (
|
||||
quant_config.a1_gscale
|
||||
if quant_config.quant_dtype == "nvfp4"
|
||||
else quant_config.a1_scale
|
||||
)
|
||||
|
||||
return self._do_dispatch(
|
||||
tokens=a1q,
|
||||
@@ -291,6 +304,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
num_experts=num_experts,
|
||||
a1_scale=a1_post_scale,
|
||||
quant_config=quant_config,
|
||||
defer_input_quant=defer_input_quant,
|
||||
)
|
||||
|
||||
def prepare(
|
||||
@@ -302,6 +316,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
receiver = self.prepare_async(
|
||||
a1,
|
||||
@@ -311,6 +326,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config,
|
||||
defer_input_quant,
|
||||
)
|
||||
return receiver()
|
||||
|
||||
|
||||
@@ -242,7 +242,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[Callable, mk.ReceiverType]:
|
||||
if defer_input_quant:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not support defer_input_quant=True. "
|
||||
"Please select an MoE kernel that accepts quantized inputs."
|
||||
)
|
||||
|
||||
hidden_size = a1.size(1)
|
||||
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
|
||||
f"Hidden Size {hidden_size} not in supported list of hidden sizes"
|
||||
@@ -344,7 +351,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if defer_input_quant:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not support defer_input_quant=True. "
|
||||
"Please select an MoE kernel that accepts quantized inputs."
|
||||
)
|
||||
hook, receiver = self.prepare_async(
|
||||
a1,
|
||||
topk_weights,
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
All2AllManagerBase,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
def get_local_sizes():
    """Return `get_chunk_sizes_across_dp_rank()` of the active forward context's DP metadata."""
    dp_metadata = get_forward_context().dp_metadata
    return dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
|
||||
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """FlashInfer all-to-all prepare/finalize for MoE layers.

    ``prepare`` routes the activations through ``flashinfer_alltoall_dispatch``
    and keeps the returned ``alltoall_info``; ``finalize`` hands that same
    object to ``flashinfer_alltoall_combine`` and copies the combined result
    into ``output``.
    """

    def __init__(
        self,
        num_dispatchers: int = 1,
    ):
        super().__init__()
        self.num_dispatchers_ = num_dispatchers
        self.all2all_manager = get_ep_group().device_communicator.all2all_manager
        # Written by prepare() and read by finalize(). Initialised here so the
        # attribute always exists and finalize() can detect a missing prepare().
        self.alltoall_info = None

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """Activation format used by this implementation (Standard)."""
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> int | None:
        """Always ``None`` for this implementation."""
        return None

    def topk_indices_dtype(self) -> torch.dtype | None:
        """Always ``None`` for this implementation."""
        return None

    def num_dispatchers(self) -> int:
        """Return the dispatcher count supplied to the constructor."""
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        """Always ``False`` for this implementation."""
        return False

    def _apply_router_weight_on_input(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
    ) -> None:
        """Multiply ``a1`` by the router weights when requested.

        Does nothing unless ``apply_router_weight_on_input`` is true. Only
        ``topk == 1`` is supported; any other value trips the assertion.
        """
        if apply_router_weight_on_input:
            topk = topk_ids.size(1)
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1"
            )
            # NOTE: in-place multiply, so the caller's `a1` tensor is modified.
            a1.mul_(topk_weights.to(a1.dtype))

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.PrepareResultType:
        """Dispatch activations and routing data through the FlashInfer A2A.

        Optionally applies the router weights to ``a1`` first, then calls
        ``flashinfer_alltoall_dispatch`` with ``quant_config.a1_gscale`` as the
        input scale. ``expert_map`` is accepted but not used here.

        Returns:
            ``(a1q, a1q_scale, None, topk_ids, topk_weights)`` where the
            tensors are the ones returned by the dispatch helper.
        """
        self._apply_router_weight_on_input(
            a1, topk_weights, topk_ids, apply_router_weight_on_input
        )
        global_num_tokens_cpu = get_local_sizes()
        top_k = topk_ids.size(1)

        # Keep alltoall_info on the instance; finalize() needs it for combine.
        (self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = (
            flashinfer_alltoall_dispatch(
                self.all2all_manager,
                global_num_tokens_cpu,
                a1,
                quant_config.a1_gscale,
                topk_ids,
                topk_weights,
                top_k,
                num_experts,
                quant_config,
                defer_input_quant=defer_input_quant,
            )
        )

        return a1q, a1q_scale, None, topk_ids, topk_weights

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """Combine expert outputs via the FlashInfer A2A and copy into ``output``.

        ``topk_weights``, ``apply_router_weight_on_input`` and
        ``weight_and_reduce_impl`` are accepted but not used here.

        Raises:
            RuntimeError: if no ``alltoall_info`` has been recorded, i.e.
                ``prepare`` has not run on this instance yet.
        """
        if self.alltoall_info is None:
            raise RuntimeError(
                f"{self.__class__.__name__}.finalize() called before prepare()"
            )

        top_k = topk_ids.size(1)
        token_count = output.shape[0]
        fused_expert_output = flashinfer_alltoall_combine(
            self.all2all_manager,
            fused_expert_output,
            top_k=top_k,
            token_count=token_count,
            alltoall_info=self.alltoall_info,
        )
        output.copy_(fused_expert_output)
|
||||
|
||||
|
||||
def flashinfer_alltoall_dispatch(
    all2all_manager: All2AllManagerBase,
    global_num_tokens_cpu: list[int],
    x: torch.Tensor,
    gs: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    top_k: int,
    num_experts: int,
    quant_config: FusedMoEQuantConfig,
    defer_input_quant: bool = False,
):
    """Dispatch activations and routing data with the FlashInfer MNNVL all-to-all.

    Args:
        all2all_manager: supplies rank, world size and the workspace tensors.
        global_num_tokens_cpu: per-rank token counts; their maximum is used as
            ``max_num_token``. When ``None``, ``x.shape[0]`` is used instead.
        x: activations to dispatch.
        gs: scale passed to ``moe_kernel_quantize_input`` (ignored when
            ``defer_input_quant`` is true).
        topk_ids: routing ids handed to the all-to-all prepare step.
        topk_weights: routing weights handed to the all-to-all prepare step.
        top_k: forwarded to the all-to-all prepare step.
        num_experts: forwarded (twice) to the all-to-all prepare step.
        quant_config: provides quant dtype, per-token flag and block shape.
        defer_input_quant: when true, ``x`` is dispatched as-is and no scale
            tensor is produced.

    Returns:
        ``(alltoall_info, topk_ids, topk_weights, x, x_sf)``; ``x_sf`` is
        ``None`` when quantization is deferred.
    """
    from flashinfer.comm.trtllm_alltoall import MnnvlMoe

    # Run the check outside the `assert` statement: the call would otherwise be
    # stripped together with the assert under `python -O`.
    workspace_ready = all2all_manager.ensure_alltoall_workspace_initialized()
    assert workspace_ready, "FlashInfer AllToAll workspace not available"

    ep_rank = all2all_manager.rank
    ep_size = all2all_manager.world_size
    max_num_token = (
        max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
    )
    # The prepare step's topk_weights are reinterpreted as the original dtype below.
    orig_topk_weights_dtype = topk_weights.dtype
    alltoall_info, topk_ids, topk_weights, _ = (
        MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
            topk_ids,
            topk_weights,
            None,
            all2all_manager.prepare_workspace_tensor,
            max_num_token,
            ep_rank,
            ep_size,
            num_experts,
            num_experts,
            top_k,
        )
    )
    topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)

    if not defer_input_quant:
        x, x_sf = moe_kernel_quantize_input(
            x,
            gs,
            quant_config.quant_dtype,
            quant_config.per_act_token_quant,
            quant_config.block_shape,
            # NOTE: swizzling pads the scales to multiple of 128
            # which makes the scales tensor different shape than
            # the hidden states, breaking the A2A kernel. So, we
            # delay the swizzling until after the A2A.
            is_fp4_scale_swizzled=False,
        )
        x = MnnvlMoe.mnnvl_moe_alltoallv(
            x,
            alltoall_info,
            all2all_manager.workspace_tensor,
            ep_rank,
            ep_size,
        )

        x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
            x_sf,
            alltoall_info,
            all2all_manager.workspace_tensor,
            ep_rank,
            ep_size,
        )

        # Swizzle after the A2A if nvfp4.
        if quant_config.quant_dtype == "nvfp4":
            if x_sf.element_size() == 1:
                x_sf = x_sf.view(torch.uint8)
            x_sf = nvfp4_block_scale_interleave(x_sf)
    else:
        # Deferred quantization: dispatch the activations unquantized and
        # return no scale tensor.
        x_sf = None
        x = MnnvlMoe.mnnvl_moe_alltoallv(
            x,
            alltoall_info,
            all2all_manager.workspace_tensor,
            ep_rank,
            ep_size,
        )
    return alltoall_info, topk_ids, topk_weights, x, x_sf
|
||||
|
||||
|
||||
def flashinfer_alltoall_combine(
    all2all_manager: All2AllManagerBase,
    output: torch.Tensor,
    top_k: int,
    token_count: int,
    alltoall_info,
):
    """Combine expert outputs with the FlashInfer MNNVL all-to-all.

    Args:
        all2all_manager: supplies rank, world size and the workspace tensor.
        output: tensor handed to ``mnnvl_moe_alltoallv_combine``.
        top_k: forwarded to the combine call.
        token_count: forwarded to the combine call.
        alltoall_info: the info object returned by the matching dispatch.

    Returns:
        Whatever ``MnnvlMoe.mnnvl_moe_alltoallv_combine`` returns.
    """
    from flashinfer.comm.trtllm_alltoall import MnnvlMoe

    # Run the check outside the `assert` statement: the call would otherwise be
    # stripped together with the assert under `python -O`.
    workspace_ready = all2all_manager.ensure_alltoall_workspace_initialized()
    assert workspace_ready, "FlashInfer AllToAll workspace not available"

    return MnnvlMoe.mnnvl_moe_alltoallv_combine(
        output,
        alltoall_info,
        all2all_manager.workspace_tensor,
        ep_rank=all2all_manager.rank,
        ep_size=all2all_manager.world_size,
        top_k=top_k,
        token_count=token_count,
    )
|
||||
@@ -78,16 +78,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# - skip input activation quantization (kernel applies scaling)
|
||||
self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
|
||||
|
||||
@staticmethod
|
||||
def expects_unquantized_inputs(
|
||||
moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
|
||||
) -> bool:
|
||||
# NVFP4 TP kernels and FP8 block-quantized kernels apply
|
||||
# input quantization inside FusedMoEPermuteExpertsUnpermute.
|
||||
return (
|
||||
quant_config.use_nvfp4_w4a4
|
||||
and not moe_config.moe_parallel_config.use_all2all_kernels
|
||||
) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized)
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return self.quant_config.use_fp8_w8a8 and self.quant_config.is_block_quantized
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
@@ -144,10 +137,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# FLASHINFER_CUTLASS currently uses its own P/F, which does not
|
||||
# work with SP. This will be removed in follow up after we get
|
||||
# rid of the FlashInfer specific P/F function.
|
||||
return (
|
||||
moe_parallel_config.dp_size == 1
|
||||
or moe_parallel_config.dp_size == moe_parallel_config.ep_size
|
||||
)
|
||||
# TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As.
|
||||
return not moe_parallel_config.is_sequence_parallel
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
@@ -194,8 +185,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
"""
|
||||
workspace1 = (M, K)
|
||||
workspace2 = (0,)
|
||||
# For TP, the quantization is fused with fused_moe call.
|
||||
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
|
||||
# For NVFP4, the output is stored in a packed int8 format,
|
||||
# so the actual hidden dim is 2x the size of K here.
|
||||
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
|
||||
# The workspace is determined by `aq`, since it comes after any
|
||||
# potential communication op and is involved in the expert computation.
|
||||
return (workspace1, workspace2, output_shape)
|
||||
|
||||
@@ -1,373 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
All2AllManagerBase,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
def get_local_sizes():
    """Return `get_chunk_sizes_across_dp_rank()` of the active forward context's DP metadata."""
    dp_metadata = get_forward_context().dp_metadata
    return dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
|
||||
class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """Common state and helpers shared by the FlashInfer prepare/finalize variants."""

    def __init__(
        self,
        use_dp: bool,
        num_dispatchers: int = 1,
        use_deepseek_fp8_block_scale: bool = False,
    ):
        super().__init__()
        self.use_dp = use_dp
        self.num_dispatchers_ = num_dispatchers
        self.local_tokens = None
        # DeepSeek-style FP8 block-scale switch: when set, activations are not
        # quantized in prepare() and the kernel consumes the weight block scales.
        self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """Activation format used by these implementations (Standard)."""
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> int | None:
        """Always ``None``."""
        return None

    def topk_indices_dtype(self) -> torch.dtype | None:
        """Always ``None``."""
        return None

    def num_dispatchers(self) -> int:
        """Return the dispatcher count supplied to the constructor."""
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        """Always ``False``."""
        return False

    def _apply_router_weight_on_input(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
    ) -> None:
        """Scale ``a1`` in place by the router weights when the flag is set."""
        if not apply_router_weight_on_input:
            return
        num_selected = topk_ids.size(1)
        assert num_selected == 1, (
            "apply_router_weight_on_input is only implemented for topk=1"
        )
        a1.mul_(topk_weights.to(a1.dtype))
|
||||
|
||||
|
||||
class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize):
    """Prepare/finalize variant that uses the FlashInfer all-to-all when DP is on."""

    def __init__(
        self,
        use_dp: bool,
        num_dispatchers: int = 1,
        use_deepseek_fp8_block_scale: bool = False,
    ):
        super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
        self.alltoall_info = None
        # The all2all manager is only looked up for the DP case.
        self.all2all_manager = (
            get_ep_group().device_communicator.all2all_manager
            if self.use_dp
            else None
        )

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> mk.PrepareResultType:
        """Quantize locally (no DP) or dispatch through the all-to-all (DP).

        Returns ``(a1q, a1q_scale, None, topk_ids, topk_weights)``.
        """
        self._apply_router_weight_on_input(
            a1, topk_weights, topk_ids, apply_router_weight_on_input
        )

        if self.use_dp:
            # DP: hand everything to the FlashInfer all-to-all dispatch and
            # remember alltoall_info for finalize().
            sizes_per_rank = get_local_sizes()
            num_selected = topk_ids.size(1)
            dispatched = flashinfer_alltoall_dispatch(
                self.all2all_manager,
                sizes_per_rank,
                a1,
                quant_config.a1_gscale,
                topk_ids,
                topk_weights,
                num_selected,
                num_experts,
                quant_config,
                use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
            )
            (self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = dispatched
        elif self.use_deepseek_fp8_block_scale:
            # No DP, block-scale path: activations pass through unquantized.
            a1q, a1q_scale = a1, None
        else:
            # No DP: quantize the activations here.
            a1q, a1q_scale = moe_kernel_quantize_input(
                a1,
                quant_config.a1_gscale,
                quant_config.quant_dtype,
                quant_config.per_act_token_quant,
                quant_config.block_shape,
                is_fp4_scale_swizzled=not self.use_dp,
            )

        return a1q, a1q_scale, None, topk_ids, topk_weights

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """Combine through the all-to-all when DP is on, then copy into ``output``."""
        result = fused_expert_output
        if self.use_dp:
            result = flashinfer_alltoall_combine(
                self.all2all_manager,
                result,
                top_k=topk_ids.size(1),
                token_count=output.shape[0],
                alltoall_info=self.alltoall_info,
            )
        output.copy_(result)
|
||||
|
||||
|
||||
class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize):
|
||||
def __init__(
|
||||
self,
|
||||
use_dp: bool,
|
||||
num_dispatchers: int = 1,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
):
|
||||
super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
self._apply_router_weight_on_input(
|
||||
a1, topk_weights, topk_ids, apply_router_weight_on_input
|
||||
)
|
||||
is_nvfp4 = quant_config.quant_dtype == "nvfp4"
|
||||
if not self.use_dp and is_nvfp4:
|
||||
return a1, None, None, topk_ids, topk_weights
|
||||
|
||||
if not self.use_deepseek_fp8_block_scale:
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=not self.use_dp,
|
||||
)
|
||||
else:
|
||||
# Block-scale path: pass activations through, omit per-token scales
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
|
||||
if self.use_dp:
|
||||
# Build gather list conditionally - omit a1q_scale if None
|
||||
# (block-scale path)
|
||||
gather_list = [topk_weights, topk_ids, a1q]
|
||||
if a1q_scale is not None:
|
||||
gather_list.append(a1q_scale)
|
||||
gathered = get_dp_group().all_gatherv(
|
||||
gather_list,
|
||||
dim=0,
|
||||
sizes=get_local_sizes(),
|
||||
)
|
||||
topk_weights, topk_ids, a1q, a1q_scale = gathered
|
||||
else:
|
||||
gathered = get_dp_group().all_gatherv(
|
||||
gather_list,
|
||||
dim=0,
|
||||
sizes=get_local_sizes(),
|
||||
)
|
||||
topk_weights, topk_ids, a1q = gathered
|
||||
a1q_scale = None
|
||||
|
||||
if is_nvfp4 and a1q_scale is not None:
|
||||
if a1q_scale.element_size() == 1:
|
||||
a1q_scale = a1q_scale.view(torch.uint8)
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
def finalize(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
    """Write the expert output into ``output``, reduce-scattering under DP.

    ``weight_and_reduce_impl`` must be a ``TopKWeightAndReduceNoOP``; it is
    only type-checked, never applied. ``topk_weights``, ``topk_ids`` and
    ``apply_router_weight_on_input`` are not read here.

    With ``self.use_dp`` set, ``fused_expert_output`` is first passed through
    the DP group's ``reduce_scatterv`` along dim 0 using ``get_local_sizes()``.
    The result is copied in place into ``output``.
    """
    assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP)

    result = fused_expert_output
    if self.use_dp:
        dp_group = get_dp_group()
        result = dp_group.reduce_scatterv(result, dim=0, sizes=get_local_sizes())
    output.copy_(result)
|
||||
|
||||
|
||||
def flashinfer_alltoall_dispatch(
    all2all_manager: All2AllManagerBase,
    global_num_tokens_cpu: list[int],
    x: torch.Tensor,
    gs: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    top_k: int,
    num_experts: int,
    quant_config: FusedMoEQuantConfig,
    use_deepseek_fp8_block_scale: bool = False,
):
    """Dispatch tokens to their experts via FlashInfer's MNNVL all-to-all.

    The routing metadata is prepared with
    ``MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather``. Unless
    ``use_deepseek_fp8_block_scale`` is set, ``x`` is then quantized with
    ``gs`` as the scale argument. The activations, and the scale factors
    when quantized, are exchanged with ``MnnvlMoe.mnnvl_moe_alltoallv``.

    Requires the manager's all-to-all workspace to be initialized.

    Returns:
        ``(alltoall_info, topk_ids, topk_weights, x, x_sf)``, where ``x_sf``
        is ``None`` on the block-scale path.
    """
    from flashinfer.comm.trtllm_alltoall import MnnvlMoe

    assert all2all_manager.ensure_alltoall_workspace_initialized(), (
        "FlashInfer AllToAll workspace not available"
    )

    ep_rank = all2all_manager.rank
    ep_size = all2all_manager.world_size

    # Fall back to the local token count when no global counts are supplied.
    if global_num_tokens_cpu is None:
        max_num_token = x.shape[0]
    else:
        max_num_token = max(global_num_tokens_cpu)

    weights_dtype = topk_weights.dtype
    alltoall_info, topk_ids, topk_weights, _ = (
        MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
            topk_ids,
            topk_weights,
            None,
            all2all_manager.prepare_workspace_tensor,
            max_num_token,
            ep_rank,
            ep_size,
            num_experts,
            num_experts,
            top_k,
        )
    )
    # Reinterpret the returned weights with the dtype the caller passed in.
    topk_weights = topk_weights.view(dtype=weights_dtype)

    def exchange(tensor):
        # One all-to-all over the shared workspace for a single tensor.
        return MnnvlMoe.mnnvl_moe_alltoallv(
            tensor,
            alltoall_info,
            all2all_manager.workspace_tensor,
            ep_rank,
            ep_size,
        )

    if use_deepseek_fp8_block_scale:
        # Block-scale path: pass activations through without quantization
        return alltoall_info, topk_ids, topk_weights, exchange(x), None

    x, x_sf = moe_kernel_quantize_input(
        x,
        gs,
        quant_config.quant_dtype,
        quant_config.per_act_token_quant,
        quant_config.block_shape,
        is_fp4_scale_swizzled=False,  # delay swizzle to after comm
    )
    x = exchange(x)
    x_sf = exchange(x_sf)
    if quant_config.quant_dtype == "nvfp4":
        x_sf = nvfp4_block_scale_interleave(x_sf)
    return alltoall_info, topk_ids, topk_weights, x, x_sf
|
||||
|
||||
|
||||
def flashinfer_alltoall_combine(
    all2all_manager: All2AllManagerBase,
    output: torch.Tensor,
    top_k: int,
    token_count: int,
    alltoall_info,
):
    """Combine expert outputs via FlashInfer's MNNVL all-to-all.

    Forwards ``output`` and the ``alltoall_info`` from the matching dispatch
    to ``MnnvlMoe.mnnvl_moe_alltoallv_combine``, together with the manager's
    workspace tensor, rank and world size, and returns its result.

    Requires the manager's all-to-all workspace to be initialized.
    """
    from flashinfer.comm.trtllm_alltoall import MnnvlMoe

    assert all2all_manager.ensure_alltoall_workspace_initialized(), (
        "FlashInfer AllToAll workspace not available"
    )

    workspace = all2all_manager.workspace_tensor
    combined = MnnvlMoe.mnnvl_moe_alltoallv_combine(
        output,
        alltoall_info,
        workspace,
        ep_rank=all2all_manager.rank,
        ep_size=all2all_manager.world_size,
        top_k=top_k,
        token_count=token_count,
    )
    return combined
|
||||
|
||||
|
||||
def create_flashinfer_prepare_finalize(
    use_dp: bool,
    use_nvfp4: bool = False,
    enable_alltoallv: bool = False,
    use_deepseek_fp8_block_scale: bool = False,
) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
    """Factory function to create the appropriate FlashInfer implementation.

    * Without DP: a ``MoEPrepareAndFinalizeNoEP`` whose input quantization is
      deferred when either ``use_deepseek_fp8_block_scale`` or ``use_nvfp4``
      is set.
    * With DP and ``enable_alltoallv``: the all-to-all variant, which asserts
      ``use_nvfp4``.
    * With DP otherwise: the all-gather variant.
    """
    if not use_dp:
        # CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization
        # in a single call with the MoE experts kernel.
        return MoEPrepareAndFinalizeNoEP(
            defer_input_quant=use_deepseek_fp8_block_scale or use_nvfp4
        )

    if enable_alltoallv:
        assert use_nvfp4
        return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)

    return FlashInferAllGatherMoEPrepareAndFinalize(
        use_dp=True,
        use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
    )
|
||||
@@ -533,7 +533,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if defer_input_quant:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not support defer_input_quant=True. "
|
||||
"Please select an MoE kernel that accepts quantized inputs."
|
||||
)
|
||||
assert a1.dim() == 2
|
||||
assert topk_ids.dim() == 2
|
||||
assert topk_ids.size(0) == a1.size(0)
|
||||
|
||||
@@ -593,7 +593,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Report whether these experts support ``moe_parallel_config``.

    Every parallel configuration is supported except one whose
    ``use_fi_all2allv_kernels`` flag is set.
    """
    return not moe_parallel_config.use_fi_all2allv_kernels
|
||||
|
||||
@property
|
||||
def quant_type_id(self) -> int:
|
||||
|
||||
@@ -1951,7 +1951,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Report whether these experts support ``moe_parallel_config``.

    Every parallel configuration is supported except one whose
    ``use_fi_all2allv_kernels`` flag is set.
    """
    return not moe_parallel_config.use_fi_all2allv_kernels
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -5,6 +5,7 @@ from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
@@ -26,6 +27,19 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
super().__init__()
|
||||
self.moe: FusedMoEConfig = moe
|
||||
self.moe_quant_config: FusedMoEQuantConfig | None = None
|
||||
self.moe_mk: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
@property
def supports_internal_mk(self) -> bool:
    """Whether this quant method already holds its own modular kernel.

    NOTE(rob): temporary attribute marking quant methods that have completed
    the migration to the new internal MK interface.
    """
    has_kernel = self.moe_mk is not None
    return has_kernel
|
||||
|
||||
@property
def mk_owns_shared_expert(self) -> bool:
    """Whether the internal modular kernel carries the shared experts.

    True only when a modular kernel is held and its ``shared_experts`` is
    set.

    NOTE(rob): temporary attribute, part of the migration to the new
    internal MK interface.
    """
    kernel = self.moe_mk
    if kernel is None:
        return False
    return kernel.shared_experts is not None
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
@@ -91,6 +105,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@property
def topk_indices_dtype(self) -> torch.dtype | None:
    """Dtype reported by the modular kernel's prepare/finalize stage.

    Returns ``None`` when no modular kernel is held; otherwise forwards to
    ``prepare_finalize.topk_indices_dtype()`` on the kernel.
    """
    kernel = self.moe_mk
    if kernel is None:
        return None
    return kernel.prepare_finalize.topk_indices_dtype()
|
||||
|
||||
@property
|
||||
|
||||
@@ -30,11 +30,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
):
|
||||
super().__init__(old_quant_method.moe)
|
||||
self.moe_quant_config = old_quant_method.moe_quant_config
|
||||
self.fused_experts = experts
|
||||
self.moe_mk = experts
|
||||
self.disable_expert_map = getattr(
|
||||
old_quant_method,
|
||||
"disable_expert_map",
|
||||
not self.fused_experts.supports_expert_map(),
|
||||
not self.moe_mk.supports_expert_map(),
|
||||
)
|
||||
self.old_quant_method = old_quant_method
|
||||
assert not self.old_quant_method.is_monolithic
|
||||
@@ -57,10 +57,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return self.fused_experts.prepare_finalize.topk_indices_dtype()
|
||||
|
||||
@property
|
||||
def supports_eplb(self) -> bool:
|
||||
return self.old_quant_method.supports_eplb
|
||||
@@ -96,7 +92,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.fused_experts(
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
|
||||
@@ -571,9 +571,6 @@ class FusedMoE(CustomOp):
|
||||
device=vllm_config.device_config.device,
|
||||
routing_method=self.routing_method_type,
|
||||
)
|
||||
self.moe_config_use_flashinfer_cutlass_kernels = (
|
||||
self.moe_config.use_flashinfer_cutlass_kernels
|
||||
)
|
||||
if self.use_mori_kernels:
|
||||
assert self.rocm_aiter_fmoe_enabled, (
|
||||
"Mori needs to be used with aiter fused_moe for now."
|
||||
@@ -646,6 +643,11 @@ class FusedMoE(CustomOp):
|
||||
# This is called after all weight loading and post-processing, so it
|
||||
# should be safe to swap out the quant_method.
|
||||
def maybe_init_modular_kernel(self) -> None:
|
||||
# NOTE(rob): WIP refactor. For quant methods that own the MK
|
||||
# we create the MK during process_weights_after_loading.
|
||||
if self.quant_method.supports_internal_mk or self.quant_method.is_monolithic:
|
||||
return None
|
||||
|
||||
self.ensure_moe_quant_config_init()
|
||||
# routing_tables only needed for round-robin expert placement with
|
||||
# DeepEP all2all backend.
|
||||
@@ -728,14 +730,6 @@ class FusedMoE(CustomOp):
|
||||
def use_mori_kernels(self):
|
||||
return self.moe_parallel_config.use_mori_kernels
|
||||
|
||||
@property
|
||||
def use_flashinfer_cutlass_kernels(self):
|
||||
return (
|
||||
self.moe_quant_config is not None
|
||||
and self.moe_quant_config.quant_dtype == "nvfp4"
|
||||
and self.moe_config_use_flashinfer_cutlass_kernels
|
||||
)
|
||||
|
||||
@property
|
||||
def use_marlin_kernels(self):
|
||||
return getattr(self.quant_method, "use_marlin", False)
|
||||
@@ -746,7 +740,7 @@ class FusedMoE(CustomOp):
|
||||
self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||
or self.moe_parallel_config.use_mori_kernels
|
||||
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
|
||||
or self.moe_parallel_config.use_fi_all2allv_kernels
|
||||
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
|
||||
|
||||
@property
|
||||
@@ -1532,7 +1526,7 @@ class FusedMoE(CustomOp):
|
||||
assert self.quant_method is not None
|
||||
return (
|
||||
isinstance(self.quant_method, FusedMoEModularMethod)
|
||||
and self.quant_method.fused_experts.output_is_reduced()
|
||||
and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr]
|
||||
)
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
|
||||
@@ -1765,7 +1759,7 @@ class FusedMoE(CustomOp):
|
||||
self.ensure_dp_chunking_init()
|
||||
|
||||
has_separate_shared_experts = (
|
||||
not isinstance(self.quant_method, FusedMoEModularMethod)
|
||||
not self.quant_method.mk_owns_shared_expert
|
||||
and self.shared_experts is not None
|
||||
)
|
||||
|
||||
@@ -1789,8 +1783,10 @@ class FusedMoE(CustomOp):
|
||||
hidden_states, router_logits, has_separate_shared_experts
|
||||
)
|
||||
|
||||
do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance(
|
||||
self.quant_method, FusedMoEModularMethod
|
||||
# NOTE(rob): once we finish migrating all the quant methods to use
|
||||
# MKs, we can remove the naive dispatch/combine path from here.
|
||||
do_naive_dispatch_combine = (
|
||||
self.dp_size > 1 and not self.quant_method.supports_internal_mk
|
||||
)
|
||||
|
||||
ctx = get_forward_context()
|
||||
@@ -1818,7 +1814,7 @@ class FusedMoE(CustomOp):
|
||||
else:
|
||||
hidden_states_to_dispatch = hidden_states
|
||||
|
||||
dispatch_res = get_ep_group().dispatch(
|
||||
dispatch_res = get_ep_group().dispatch_router_logits(
|
||||
hidden_states_to_dispatch,
|
||||
router_logits,
|
||||
self.is_sequence_parallel,
|
||||
|
||||
@@ -180,6 +180,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool,
|
||||
) -> PrepareResultType:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel.
|
||||
@@ -192,6 +193,9 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
- quant_config: Quantization info provided by the fused experts.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEPermuteExpertsUnpermute
|
||||
in cases where the compute kernel expects unquantized inputs
|
||||
|
||||
Returns a tuple of:
|
||||
- quantized + dispatched a.
|
||||
@@ -220,6 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool,
|
||||
) -> tuple[Callable, ReceiverType] | ReceiverType:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel
|
||||
@@ -235,6 +240,9 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
space to the local expert space of the expert parallel shard.
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEPermuteExpertsUnpermute
|
||||
in cases where the compute kernel expects unquantized inputs
|
||||
|
||||
Returns a callback or a hook callback pair that when invoked waits for
|
||||
results from other workers and has the same return signature as
|
||||
@@ -407,10 +415,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
@staticmethod
|
||||
def expects_unquantized_inputs(
|
||||
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
|
||||
) -> bool:
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
"""
|
||||
Whether or not the PrepareFinalize should defer input quantization
|
||||
in the prepare step. If True, then the Experts kernel will
|
||||
@@ -1069,6 +1075,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
|
||||
)
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
@@ -1081,6 +1088,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
|
||||
)
|
||||
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
|
||||
@@ -58,6 +58,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
"""
|
||||
Returns a tuple of:
|
||||
@@ -69,6 +70,11 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
- Optional dispatched expert topk IDs
|
||||
- Optional dispatched expert topk weight
|
||||
"""
|
||||
if defer_input_quant:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not support defer_input_quant=True. "
|
||||
"Please select an MoE kernel that accepts quantized inputs."
|
||||
)
|
||||
assert not apply_router_weight_on_input, (
|
||||
"mori does not support apply_router_weight_on_input=True now."
|
||||
)
|
||||
|
||||
@@ -8,6 +8,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -17,9 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
|
||||
is_supported_config_trtllm,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
get_flashinfer_moe_backend,
|
||||
@@ -465,68 +465,52 @@ def make_fp8_moe_quant_config(
|
||||
)
|
||||
|
||||
|
||||
def make_fp8_moe_kernel_for_mkm(
|
||||
def make_fp8_moe_kernel(
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
fp8_backend: Fp8MoeBackend,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> tuple[mk.FusedMoEModularKernel, bool]:
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=moe_quant_config,
|
||||
routing_tables=routing_tables,
|
||||
allow_new_interface=True,
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
|
||||
|
||||
# Create Experts.
|
||||
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens_per_rank is not None
|
||||
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens is not None
|
||||
experts = experts_cls(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
quant_config=moe_quant_config,
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
)
|
||||
else:
|
||||
experts = experts_cls(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
quant_config=moe_quant_config,
|
||||
)
|
||||
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
|
||||
|
||||
def make_fp8_moe_kernel(
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
fp8_backend: Fp8MoeBackend,
|
||||
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
) -> tuple[mk.FusedMoEModularKernel, bool]:
|
||||
# TODO(rob): unify after we merge tp and dp/ep.
|
||||
if (
|
||||
moe_config.moe_parallel_config.use_all2all_kernels
|
||||
and moe_config.moe_parallel_config.all2all_backend
|
||||
not in ["allgather_reducescatter", "naive"]
|
||||
):
|
||||
raise ValueError(
|
||||
"Fp8 Oracle should not create non-naive A2A P/F. "
|
||||
"This should happen via the ModularKernelMethod."
|
||||
)
|
||||
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = MoEPrepareAndFinalizeNoEP(
|
||||
defer_input_quant=experts_cls.expects_unquantized_inputs(
|
||||
moe_config, moe_quant_config
|
||||
),
|
||||
)
|
||||
|
||||
# Create Experts.
|
||||
experts = experts_cls(
|
||||
moe_config=moe_config,
|
||||
quant_config=moe_quant_config,
|
||||
)
|
||||
|
||||
# NOTE(rob): we only want the mk to control the shared_expert
|
||||
# if using all2all (for SBO). bnell is making this explicit in
|
||||
# the new MoE runner class.
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts=None,
|
||||
shared_experts=(
|
||||
shared_experts
|
||||
if moe_config.moe_parallel_config.use_all2all_kernels
|
||||
else None
|
||||
),
|
||||
moe_parallel_config=moe_config.moe_parallel_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,6 +7,9 @@ import torch
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -14,9 +17,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
nvfp4_moe_quant_config,
|
||||
nvfp4_w4a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
is_supported_config_trtllm,
|
||||
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
|
||||
@@ -391,67 +391,51 @@ def make_nvfp4_moe_quant_config(
|
||||
)
|
||||
|
||||
|
||||
def make_nvfp4_moe_kernel_for_mkm(
|
||||
def make_nvfp4_moe_kernel(
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=moe_quant_config,
|
||||
routing_tables=routing_tables,
|
||||
allow_new_interface=True,
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
logger.info_once("Using %s", prepare_finalize.__class__.__name__)
|
||||
|
||||
# Create Experts.
|
||||
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens_per_rank is not None
|
||||
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens is not None
|
||||
experts = experts_cls(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
quant_config=moe_quant_config,
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
)
|
||||
else:
|
||||
experts = experts_cls(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
quant_config=moe_quant_config,
|
||||
)
|
||||
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
|
||||
|
||||
def make_nvfp4_moe_kernel(
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
# TODO(rob): unify after we merge tp and dp/ep.
|
||||
if (
|
||||
moe_config.moe_parallel_config.use_all2all_kernels
|
||||
and moe_config.moe_parallel_config.all2all_backend
|
||||
not in ["allgather_reducescatter", "naive"]
|
||||
):
|
||||
raise ValueError(
|
||||
"NvFP4 Oracle should not create non-naive A2A P/F. "
|
||||
"This should happen via the ModularKernelMethod."
|
||||
)
|
||||
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = MoEPrepareAndFinalizeNoEP(
|
||||
defer_input_quant=experts_cls.expects_unquantized_inputs(
|
||||
moe_config, moe_quant_config
|
||||
),
|
||||
)
|
||||
|
||||
# Create Experts.
|
||||
experts = experts_cls(
|
||||
moe_config=moe_config,
|
||||
quant_config=moe_quant_config,
|
||||
)
|
||||
|
||||
# NOTE(rob): we only want the mk to control the shared_expert
|
||||
# if using all2all (for SBO). bnell is making this explicit in
|
||||
# the new MoE runner class.
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts=None,
|
||||
shared_experts=(
|
||||
shared_experts
|
||||
if moe_config.moe_parallel_config.use_all2all_kernels
|
||||
else None
|
||||
),
|
||||
moe_parallel_config=moe_config.moe_parallel_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -106,7 +106,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[Callable, mk.ReceiverType]:
|
||||
if defer_input_quant:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not support defer_input_quant=True. "
|
||||
"Please select an MoE kernel that accepts quantized inputs."
|
||||
)
|
||||
|
||||
num_tokens = a1.size(0) # M
|
||||
hidden_dim = a1.size(-1) # K
|
||||
|
||||
@@ -274,6 +281,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
hook, receiver = self.prepare_async(
|
||||
a1,
|
||||
@@ -283,6 +291,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config,
|
||||
defer_input_quant=defer_input_quant,
|
||||
)
|
||||
hook()
|
||||
return receiver()
|
||||
|
||||
@@ -4,19 +4,133 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous,
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
    """Prepare/finalize stage built on the EP group's dispatch and combine.

    ``prepare`` optionally quantizes the activations and sends them, with the
    top-k routing data, through ``get_ep_group().dispatch``. ``finalize``
    applies the top-k weight-and-reduce step and sends the result through
    ``get_ep_group().combine``.
    """

    def __init__(
        self,
        is_sequence_parallel: bool = False,
        num_dispatchers: int = 1,
    ) -> None:
        super().__init__()
        self.is_sequence_parallel = is_sequence_parallel
        self._num_dispatchers = num_dispatchers

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """Activations use the ``Standard`` format."""
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> int | None:
        """No per-rank token limit is reported."""
        return None

    def topk_indices_dtype(self) -> torch.dtype | None:
        """No particular dtype is requested for the top-k indices."""
        return None

    def num_dispatchers(self) -> int:
        """Return the dispatcher count given at construction."""
        return self._num_dispatchers

    def output_is_reduced(self) -> bool:
        """Always ``False`` for this implementation."""
        return False

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.PrepareResultType:
        """Optionally quantize ``a1`` and dispatch it over the EP group.

        ``num_experts`` and ``expert_map`` are not read here.

        Returns:
            ``(a1q, a1q_scale, None, topk_ids, topk_weights)`` after
            dispatch. ``a1q_scale`` is ``None`` when quantization is
            deferred.
        """
        if apply_router_weight_on_input:
            topk = topk_ids.size(1)
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1"
            )
            # Note: do not use inplace for shared experts overlap
            a1 = a1 * topk_weights.to(a1.dtype)

        use_nvfp4 = quant_config.use_nvfp4_w4a4

        # Start from the "defer input quantization to the MoE kernel"
        # outcome and overwrite it only when quantizing here.
        a1q, a1q_scale = a1, None
        if not defer_input_quant:
            input_scale = (
                quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale
            )
            a1q, a1q_scale = moe_kernel_quantize_input(
                a1,
                input_scale,
                quant_config.quant_dtype,
                quant_config.per_act_token_quant,
                quant_config.block_shape,
                # NOTE: swizzling pads the scales to multiple of 128
                # which makes the scales tensor different shape than
                # the hidden states, breaking the A2A kernel. So, we
                # delay the swizzling until after the A2A.
                is_fp4_scale_swizzled=False,
            )

        # Scales are gathered only when they exist and are not a scalar:
        # a scalar scale comes from static quantization and is replicated
        # on all ranks, and there is no scale when quantization is deferred.
        gather_scales = a1q_scale is not None and a1q_scale.ndim != 0

        dispatched = get_ep_group().dispatch(
            a1q,
            topk_weights,
            topk_ids,
            is_sequence_parallel=self.is_sequence_parallel,
            extra_tensors=[a1q_scale] if gather_scales else None,
        )

        if not gather_scales:
            a1q, topk_weights, topk_ids = dispatched
            return a1q, a1q_scale, None, topk_ids, topk_weights

        a1q, topk_weights, topk_ids, scales = dispatched
        assert scales is not None and len(scales) == 1
        a1q_scale = scales[0]
        if quant_config.quant_dtype == "nvfp4":
            assert a1q_scale is not None
            # One-byte scale elements are reinterpreted as uint8 before
            # the interleave helper is applied.
            if a1q_scale.element_size() == 1:
                a1q_scale = a1q_scale.view(torch.uint8)
            a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

        return a1q, a1q_scale, None, topk_ids, topk_weights

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """Weight-and-reduce the expert output, combine it, fill ``output``.

        A ``TopKWeightAndReduceDelegate`` is replaced by a fresh
        ``TopKWeightAndReduceContiguous`` before being applied. The reduced
        tensor goes through the EP group's ``combine`` and is copied in place
        into ``output``.
        """
        reducer = weight_and_reduce_impl
        if isinstance(reducer, TopKWeightAndReduceDelegate):
            reducer = TopKWeightAndReduceContiguous()

        reduced = reducer.apply(
            output=None,
            fused_expert_output=fused_expert_output,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

        combined = get_ep_group().combine(
            reduced, is_sequence_parallel=self.is_sequence_parallel
        )
        output.copy_(combined)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
def __init__(self, defer_input_quant: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.defer_input_quant = defer_input_quant
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
@@ -42,6 +156,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
@@ -54,12 +169,17 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
|
||||
# which use a single kernel call for quant + experts.
|
||||
if self.defer_input_quant:
|
||||
if defer_input_quant:
|
||||
return a1, None, None, None, None
|
||||
|
||||
input_sf = (
|
||||
quant_config.a1_gscale
|
||||
if quant_config.use_nvfp4_w4a4
|
||||
else quant_config.a1_scale
|
||||
)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_scale,
|
||||
input_sf,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
|
||||
@@ -287,17 +287,14 @@ def rocm_aiter_fused_experts(
|
||||
|
||||
|
||||
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
@staticmethod
|
||||
def expects_unquantized_inputs(
|
||||
fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
|
||||
) -> bool:
|
||||
# AITER fused MoE kernels handle input quantization internally.
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
return rocm_aiter_ops.is_fused_moe_enabled()
|
||||
@@ -329,7 +326,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return True
|
||||
return not moe_parallel_config.use_fi_all2allv_kernels
|
||||
|
||||
def supports_expert_map(self):
|
||||
return True
|
||||
|
||||
@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE):
|
||||
use_overlapped
|
||||
and not (
|
||||
(self.enable_eplb and backend != "allgather_reducescatter")
|
||||
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
|
||||
or self.moe_parallel_config.use_fi_all2allv_kernels
|
||||
)
|
||||
and self._shared_experts is not None
|
||||
)
|
||||
|
||||
@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_kernel_for_mkm,
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
@@ -53,7 +52,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
is_global_sf_supported_for_nvfp4_backend,
|
||||
make_mxfp4_moe_quant_config,
|
||||
make_nvfp4_moe_kernel,
|
||||
make_nvfp4_moe_kernel_for_mkm,
|
||||
make_nvfp4_moe_quant_config,
|
||||
select_nvfp4_moe_backend,
|
||||
)
|
||||
@@ -67,7 +65,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
@@ -243,7 +240,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
self.group_size = 32
|
||||
self.mxfp4_backend = NvFp4MoeBackend.MARLIN
|
||||
self.experts_cls = MarlinExperts
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -320,7 +316,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
def process_weights_after_loading(self, layer: FusedMoE) -> None:
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
layer.w13_weight_packed.data, requires_grad=False
|
||||
)
|
||||
@@ -335,10 +331,12 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config is not None:
|
||||
self.kernel = make_nvfp4_moe_kernel(
|
||||
self.moe_mk = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
|
||||
def apply(
|
||||
@@ -348,8 +346,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.kernel is not None
|
||||
return self.kernel(
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -380,19 +378,10 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
activation_key=None if use_a16 else kNvfp4Dynamic,
|
||||
)
|
||||
|
||||
# Delay creation of the kernel until after process-weights.
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
|
||||
self.nvfp4_backend
|
||||
)
|
||||
|
||||
@property
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
if self.kernel is not None:
|
||||
return self.kernel.prepare_finalize.topk_indices_dtype()
|
||||
return None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -506,7 +495,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
def process_weights_after_loading(self, layer: FusedMoE) -> None:
|
||||
"""
|
||||
Convert NVFP4 MoE weights into kernel format and setup the kernel.
|
||||
"""
|
||||
@@ -572,48 +561,33 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config and (
|
||||
(not self.moe.moe_parallel_config.use_all2all_kernels)
|
||||
or self.moe.moe_parallel_config.use_naive_all2all_kernels
|
||||
):
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.kernel = make_nvfp4_moe_kernel(
|
||||
self.moe_mk = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
|
||||
# For no-EP case, don't use the MKM framework.
|
||||
if not self.moe.moe_parallel_config.use_all2all_kernels:
|
||||
return None
|
||||
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
self.moe,
|
||||
use_deepseek_fp8_block_scale=False,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.moe_quant_config is not None
|
||||
assert self.experts_cls is not None
|
||||
return make_nvfp4_moe_kernel_for_mkm(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
experts_cls=self.experts_cls,
|
||||
prepare_finalize=prepare_finalize,
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
@@ -684,8 +658,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
)
|
||||
else:
|
||||
assert self.kernel is not None
|
||||
return self.kernel(
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -759,15 +733,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
allow_vllm_cutlass=True,
|
||||
)
|
||||
|
||||
# Delay creation of the kernel until after process-weights.
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
@property
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
if self.kernel is not None:
|
||||
return self.kernel.prepare_finalize.topk_indices_dtype()
|
||||
return None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -927,7 +892,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
def process_weights_after_loading(self, layer: FusedMoE) -> None:
|
||||
# Allow for accessing weights and scales in standard way.
|
||||
w13 = layer.w13_weight
|
||||
w2 = layer.w2_weight
|
||||
@@ -989,49 +954,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config and (
|
||||
(not self.moe.moe_parallel_config.use_all2all_kernels)
|
||||
or self.moe.moe_parallel_config.use_naive_all2all_kernels
|
||||
):
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.kernel, self.use_inplace = make_fp8_moe_kernel(
|
||||
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
# For no-EP case, don't use the MKM framework.
|
||||
if not self.moe.moe_parallel_config.use_all2all_kernels:
|
||||
return None
|
||||
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
self.moe,
|
||||
use_deepseek_fp8_block_scale=self.block_quant,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.moe_quant_config is not None
|
||||
assert self.experts_cls is not None
|
||||
return make_fp8_moe_kernel_for_mkm(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
experts_cls=self.experts_cls,
|
||||
prepare_finalize=prepare_finalize,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
@@ -1120,8 +1070,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
assert self.kernel is not None
|
||||
return self.kernel(
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
|
||||
@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_kernel_for_mkm,
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
@@ -53,7 +52,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
@@ -679,15 +677,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
allow_vllm_cutlass=False,
|
||||
)
|
||||
|
||||
# Delay creation of the kernel until after process-weights.
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
@property
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
if self.kernel is not None:
|
||||
return self.kernel.prepare_finalize.topk_indices_dtype()
|
||||
return None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: Module,
|
||||
@@ -813,7 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def _setup_kernel(
|
||||
self,
|
||||
layer: Module,
|
||||
layer: FusedMoE,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
@@ -845,16 +834,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config and (
|
||||
(not self.moe.moe_parallel_config.use_all2all_kernels)
|
||||
or self.moe.moe_parallel_config.use_naive_all2all_kernels
|
||||
):
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.kernel, self.use_inplace = make_fp8_moe_kernel(
|
||||
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
@@ -909,33 +897,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
# For no-EP case, don't use the MKM framework.
|
||||
if not self.moe.moe_parallel_config.use_all2all_kernels:
|
||||
return None
|
||||
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
self.moe,
|
||||
use_deepseek_fp8_block_scale=self.block_quant,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.moe_quant_config is not None
|
||||
assert self.experts_cls is not None
|
||||
return make_fp8_moe_kernel_for_mkm(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
experts_cls=self.experts_cls,
|
||||
prepare_finalize=prepare_finalize,
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
@@ -1037,9 +1011,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.kernel is not None
|
||||
assert self.moe_mk is not None
|
||||
assert not self.is_monolithic
|
||||
return self.kernel(
|
||||
return self.moe_mk(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
|
||||
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_kernel_for_mkm,
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
@@ -35,7 +34,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
convert_to_nvfp4_moe_kernel_format,
|
||||
is_global_sf_supported_for_nvfp4_backend,
|
||||
make_nvfp4_moe_kernel,
|
||||
make_nvfp4_moe_kernel_for_mkm,
|
||||
make_nvfp4_moe_quant_config,
|
||||
select_nvfp4_moe_backend,
|
||||
)
|
||||
@@ -54,13 +52,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
|
||||
flashinfer_trtllm_fp4_moe,
|
||||
flashinfer_trtllm_fp4_routed_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
@@ -739,47 +735,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
activation_key=kFp8StaticTensorSym,
|
||||
)
|
||||
|
||||
# Delay creation of the kernel until after process-weights.
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
@property
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
if self.kernel is not None:
|
||||
return self.kernel.prepare_finalize.topk_indices_dtype()
|
||||
return None
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
# TRT LLM not supported with all2all yet.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
# For no-EP case, don't use the MKM framework.
|
||||
if not self.moe.moe_parallel_config.use_all2all_kernels:
|
||||
return None
|
||||
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
self.moe,
|
||||
use_deepseek_fp8_block_scale=False,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.moe_quant_config is not None
|
||||
assert self.experts_cls is not None
|
||||
return make_fp8_moe_kernel_for_mkm(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
experts_cls=self.experts_cls,
|
||||
prepare_finalize=prepare_finalize,
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
@@ -863,7 +835,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def _setup_kernel(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer: FusedMoE,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
@@ -893,11 +865,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.kernel, self.use_inplace = make_fp8_moe_kernel(
|
||||
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
@@ -998,8 +972,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
f"but got {layer.activation}"
|
||||
)
|
||||
|
||||
assert self.kernel is not None
|
||||
return self.kernel(
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -1379,50 +1353,27 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
activation_key=kNvfp4Dynamic,
|
||||
)
|
||||
|
||||
# Delay creation of the kernel until after process-weights.
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
|
||||
self.nvfp4_backend
|
||||
)
|
||||
|
||||
@property
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
if self.kernel is not None:
|
||||
return self.kernel.prepare_finalize.topk_indices_dtype()
|
||||
return None
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
|
||||
# For no-EP case, don't use the MKM framework.
|
||||
if not self.moe.moe_parallel_config.use_all2all_kernels:
|
||||
return None
|
||||
# For now, fp4 moe only works with the flashinfer dispatcher.
|
||||
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||
self.moe
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.moe_quant_config is not None
|
||||
assert self.experts_cls is not None
|
||||
return make_nvfp4_moe_kernel_for_mkm(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
experts_cls=self.experts_cls,
|
||||
prepare_finalize=prepare_finalize,
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def uses_weight_scale_2_pattern(self) -> bool:
|
||||
@@ -1547,7 +1498,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
def process_weights_after_loading(self, layer: FusedMoE) -> None:
|
||||
"""
|
||||
Convert NVFP4 MoE weights into kernel format and setup the kernel.
|
||||
"""
|
||||
@@ -1599,15 +1550,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config and (
|
||||
(not self.moe.moe_parallel_config.use_all2all_kernels)
|
||||
or self.moe.moe_parallel_config.use_naive_all2all_kernels
|
||||
):
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.kernel = make_nvfp4_moe_kernel(
|
||||
self.moe_mk = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -1708,8 +1658,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
)
|
||||
else:
|
||||
assert self.kernel is not None
|
||||
return self.kernel(
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
|
||||
@@ -15,9 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEParallelConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
@@ -42,7 +39,6 @@ __all__ = [
|
||||
"is_flashinfer_fp4_cutlass_moe_available",
|
||||
"is_flashinfer_fp4_cutedsl_moe_available",
|
||||
"reorder_w1w3_to_w3w1",
|
||||
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
|
||||
]
|
||||
|
||||
#
|
||||
@@ -163,17 +159,6 @@ def reorder_w1w3_to_w3w1(
|
||||
)
|
||||
|
||||
|
||||
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||
moe: FusedMoEConfig,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||
use_dp = moe.moe_parallel_config.dp_size > 1
|
||||
enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv"
|
||||
return create_flashinfer_prepare_finalize(
|
||||
use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv
|
||||
)
|
||||
|
||||
|
||||
def prepare_static_weights_for_trtllm_fp4_moe(
|
||||
# args_dequant,
|
||||
# args,
|
||||
|
||||
@@ -4,15 +4,8 @@ from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
@@ -163,18 +156,6 @@ def make_fp8_moe_alpha_scales_for_fi(
|
||||
return g1_alphas, g2_alphas
|
||||
|
||||
|
||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
||||
# Propagate block-scale flag so prepare/finalize can skip act quantization
|
||||
# and inform the kernel to consume per-block weight scales.
|
||||
return create_flashinfer_prepare_finalize(
|
||||
use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
|
||||
)
|
||||
|
||||
|
||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||
backend_map = {
|
||||
"throughput": FlashinferMoeBackend.CUTLASS,
|
||||
|
||||
@@ -14,7 +14,6 @@ from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
@@ -169,9 +168,10 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
||||
# modular kernels could invoke deep_gemm_moe_fp8
|
||||
return True
|
||||
|
||||
mk: FusedMoEModularKernel = module.quant_method.fused_experts
|
||||
# Further check if the ModularKernel implementation uses the DeepGemmExperts
|
||||
return isinstance(mk.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts))
|
||||
return isinstance(
|
||||
module.quant_method.moe_mk, (DeepGemmExperts, TritonOrDeepGemmExperts)
|
||||
)
|
||||
|
||||
|
||||
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
|
||||
|
||||
Reference in New Issue
Block a user