[NVIDIA] Add SM100 Flashinfer Cutlass MoE fp8 backend (#22357)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
@@ -23,8 +24,11 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
|
||||
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
|
||||
FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||
flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
|
||||
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
|
||||
select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
@@ -145,7 +149,7 @@ class Fp8Config(QuantizationConfig):
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Fp8MoEMethod(self, layer.moe_config)
|
||||
return Fp8MoEMethod(self, layer)
|
||||
elif isinstance(layer, Attention):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
@@ -482,16 +486,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
|
||||
super().__init__(layer.moe_config)
|
||||
self.layer = layer
|
||||
self.quant_config = quant_config
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
|
||||
self.flashinfer_moe_enabled = False
|
||||
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
|
||||
self.fused_experts: Optional[
|
||||
mk.FusedMoEModularKernel] = None # type: ignore
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
logger.info_once(
|
||||
"Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")
|
||||
self.flashinfer_moe_enabled = True
|
||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||
)
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
@@ -531,6 +539,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"CutlassBlockScaledGroupedGemm not supported on the current "
|
||||
"platform.")
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
layer=self.layer,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
|
||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
@@ -678,7 +700,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w2_weight, layer.w2_weight_scale_inv,
|
||||
layer.w2_input_scale)
|
||||
elif self.flashinfer_moe_enabled:
|
||||
elif self.flashinfer_moe_backend is not None:
|
||||
# NOTE: weights have to be swapped since the activation is
|
||||
# applied on different half for flashinfer vs vllm
|
||||
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
|
||||
@@ -686,9 +708,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight_scale_inv.data)
|
||||
w2_weight = layer.w2_weight.data
|
||||
w2_weight_scale_inv = layer.w2_weight_scale_inv.data
|
||||
if not self.block_quant:
|
||||
register_moe_scaling_factors(layer)
|
||||
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
|
||||
else:
|
||||
w13_weight = layer.w13_weight.data
|
||||
w13_weight_scale_inv = layer.w13_weight_scale_inv.data
|
||||
@@ -834,6 +853,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
# NOTE: weights have to be swapped since the activation is
|
||||
# applied on different half for flashinfer vs vllm
|
||||
assert not self.block_quant
|
||||
register_moe_scaling_factors(layer)
|
||||
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
|
||||
if self.flashinfer_moe_backend == \
|
||||
FlashinferMoeBackend.TENSORRT_LLM:
|
||||
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
|
||||
layer.w13_weight.data = w13_weight.data
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||
# Activations not quantized for marlin.
|
||||
@@ -892,6 +922,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
per_act_token_quant=False,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
experts = select_cutlass_fp8_gemm_impl(
|
||||
moe,
|
||||
self.layer,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
else:
|
||||
logger.debug(
|
||||
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
|
||||
@@ -930,25 +967,66 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
assert logical_to_physical_map is not None
|
||||
assert logical_replica_count is not None
|
||||
assert isinstance(layer, FusedMoE)
|
||||
if not self.flashinfer_moe_enabled:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
enable_eplb=enable_eplb,
|
||||
expert_map=expert_map,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
assert scoring_func == 'sigmoid', (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
||||
if self.block_quant:
|
||||
assert (renormalize and use_grouped_topk
|
||||
and custom_routing_function is None)
|
||||
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits.to(torch.float32),
|
||||
routing_bias=e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
w13_weight_scale_inv=layer.w13_weight_scale_inv,
|
||||
w2_weight=layer.w2_weight,
|
||||
w2_weight_scale_inv=layer.w2_weight_scale_inv,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
routed_scaling=1.0,
|
||||
)
|
||||
else:
|
||||
assert (not renormalize
|
||||
and custom_routing_function is not None)
|
||||
return apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
enable_eplb=enable_eplb,
|
||||
expert_map=expert_map,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
@@ -988,63 +1066,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
elif self.flashinfer_moe_enabled:
|
||||
assert activation == 'silu'
|
||||
assert scoring_func == 'sigmoid'
|
||||
if self.block_quant:
|
||||
assert (renormalize and use_grouped_topk
|
||||
and custom_routing_function is None)
|
||||
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits.to(torch.float32),
|
||||
routing_bias=e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
w13_weight_scale_inv=layer.w13_weight_scale_inv,
|
||||
w2_weight=layer.w2_weight,
|
||||
w2_weight_scale_inv=layer.w2_weight_scale_inv,
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert self.block_quant is None
|
||||
assert (not renormalize and custom_routing_function is not None)
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
assert scoring_func == 'sigmoid', (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
||||
if self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
routed_scaling=1.0,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
assert (not renormalize
|
||||
and custom_routing_function is not None)
|
||||
return apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
elif self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
return fused_experts(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -27,8 +26,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
|
||||
select_nvfp4_gemm_impl)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
|
||||
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
|
||||
FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||
flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
|
||||
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
|
||||
select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear, is_fp4_marlin_supported,
|
||||
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
|
||||
@@ -49,11 +51,6 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
||||
|
||||
|
||||
class FlashinferMoeBackend(Enum):
|
||||
TENSORRT_LLM = "TensorRT-LLM"
|
||||
CUTLASS = "CUTLASS"
|
||||
|
||||
|
||||
class ModelOptFp8Config(QuantizationConfig):
|
||||
"""Config class for ModelOpt FP8."""
|
||||
|
||||
@@ -179,7 +176,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
elif isinstance(layer, Attention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ModelOptFp8MoEMethod(self, layer.moe_config)
|
||||
return ModelOptFp8MoEMethod(self, layer)
|
||||
return None
|
||||
|
||||
|
||||
@@ -278,18 +275,49 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: ModelOptFp8Config,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> None:
|
||||
super().__init__(moe)
|
||||
super().__init__(layer.moe_config)
|
||||
self.layer = layer
|
||||
self.quant_config = quant_config
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_fp8_supported)
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
self.flashinfer_moe_enabled = False
|
||||
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
|
||||
self.fused_experts: Optional[
|
||||
mk.FusedMoEModularKernel] = None # type: ignore
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
logger.info_once(
|
||||
"Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
|
||||
self.flashinfer_moe_enabled = True
|
||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if self.fused_experts is not None or \
|
||||
self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
layer=self.layer,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
experts = select_cutlass_fp8_gemm_impl(
|
||||
moe,
|
||||
self.layer,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -433,11 +461,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
|
||||
requires_grad=False)
|
||||
|
||||
if self.flashinfer_moe_enabled:
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
||||
layer.w2_weight)
|
||||
register_moe_scaling_factors(layer)
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
||||
layer.w2_weight)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -461,14 +490,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
|
||||
|
||||
if self.flashinfer_moe_enabled:
|
||||
assert activation == 'silu'
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
assert not renormalize
|
||||
return apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=layer,
|
||||
@@ -495,6 +523,36 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert not renormalize
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
if self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts)
|
||||
return fused_experts(
|
||||
@@ -951,20 +1009,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
self.flashinfer_moe_backend = None
|
||||
|
||||
if self.allow_flashinfer:
|
||||
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
||||
if flashinfer_moe_backend == "throughput":
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
logger.info_once("Using FlashInfer CUTLASS kernels for "
|
||||
"ModelOptNvFp4FusedMoE.")
|
||||
elif flashinfer_moe_backend == "latency":
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||
logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
|
||||
"ModelOptNvFp4FusedMoE.")
|
||||
else:
|
||||
allowed_backends = ["throughput", "latency"]
|
||||
raise ValueError(
|
||||
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
|
||||
f" expected one of {allowed_backends}")
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
logger.info_once(
|
||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||
" for ModelOptNvFp4FusedMoE.")
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
|
||||
@@ -1,9 +1,26 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
FlashInferCutlassMoEPrepareAndFinalize)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashinferMoeBackend(Enum):
|
||||
TENSORRT_LLM = "TensorRT-LLM"
|
||||
CUTLASS = "CUTLASS"
|
||||
|
||||
|
||||
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
|
||||
|
||||
@@ -144,3 +161,98 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
||||
layer.register_parameter(
|
||||
'output2_scales_scalar',
|
||||
torch.nn.Parameter(output2_scales, requires_grad=False))
|
||||
layer.register_parameter(
|
||||
'w2_input_scale_inv',
|
||||
torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False))
|
||||
|
||||
|
||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe: Optional[FusedMoEConfig],
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(
|
||||
use_dp, a1_gscale=layer.w13_input_scale)
|
||||
|
||||
|
||||
def select_cutlass_fp8_gemm_impl(
|
||||
moe: Optional[FusedMoEConfig],
|
||||
layer: torch.nn.Module,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return a GEMM *experts* implementation for fused-MoE layers"""
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
|
||||
"FusedMoE flashinfer kernels are only supported for Llama4"
|
||||
|
||||
if moe is not None:
|
||||
return FlashInferExperts(
|
||||
g1_alphas=layer.output1_scales_gate_scalar,
|
||||
g2_alphas=layer.output2_scales_scalar,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
out_dtype=moe.in_dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
ep_rank=moe.moe_parallel_config.ep_rank,
|
||||
ep_size=moe.moe_parallel_config.ep_size,
|
||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||
tp_size=moe.moe_parallel_config.tp_size,
|
||||
)
|
||||
|
||||
assert out_dtype is not None, (
|
||||
"If moe config is None, out_dtype must be passed")
|
||||
return FlashInferExperts(
|
||||
g1_alphas=layer.output1_scales_gate_scalar,
|
||||
g2_alphas=layer.output2_scales_scalar,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
out_dtype=out_dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_cutlass_moe_fp8(
|
||||
hidden_states: torch.Tensor,
|
||||
layer: torch.nn.Module,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
fused_experts = mk.FusedMoEModularKernel(
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None,
|
||||
layer=layer),
|
||||
select_cutlass_fp8_gemm_impl(moe=None,
|
||||
layer=layer,
|
||||
out_dtype=hidden_states.dtype))
|
||||
|
||||
return fused_experts(
|
||||
hidden_states,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
||||
if flashinfer_moe_backend == "throughput":
|
||||
return FlashinferMoeBackend.CUTLASS
|
||||
elif flashinfer_moe_backend == "latency":
|
||||
return FlashinferMoeBackend.TENSORRT_LLM
|
||||
|
||||
allowed_backends = ["throughput", "latency"]
|
||||
raise ValueError(
|
||||
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
|
||||
f" expected one of {allowed_backends}")
|
||||
|
||||
Reference in New Issue
Block a user