Files
vllm/vllm/model_executor/layers/fused_moe/prepare_finalize.py
Shu Wang c7d8724e78 [Core] FlashInfer CUTLASS fused MoE backend (NVFP4) (#20037)
Signed-off-by: shuw <shuw@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
2025-07-17 21:32:45 -07:00

82 lines
3.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> Optional[int]:
return None
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return None
def num_dispatchers(self) -> int:
return 1
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]],
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
if (extra_prepare_args is not None
and extra_prepare_args.get("skip_quant", True)):
# Skip quantization if explicitly requested
return a1, None, None, None, None
a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
if (extra_finalize_args is not None
and extra_finalize_args.get("skip_weight_reduce", True)):
assert output.shape == fused_expert_output.shape
output.copy_(fused_expert_output)
else:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(
output=output,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input)