[MoE Refactor][16/N] Apply Refactor to NVFP4 (#31692)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
@@ -72,7 +72,6 @@ if HAS_TRITON:
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
CutlassExpertsW4A8Fp8,
cutlass_moe_fp4,
cutlass_moe_w4a8_fp8,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
@@ -95,7 +94,6 @@ if HAS_TRITON:
"fused_experts",
"get_config_file_name",
"GroupedTopk",
"cutlass_moe_fp4",
"cutlass_moe_w4a8_fp8",
"CutlassExpertsFp8",
"CutlassBatchedExpertsFp8",
@@ -336,6 +336,10 @@ class FusedMoEQuantConfig:
def use_int4_w4a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == "int4"

@property
def use_nvfp4_w4a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == "nvfp4"

@property
def ocp_mx_scheme(self) -> str | None:
if not hasattr(self, "_ocp_mx_scheme"):
@@ -690,6 +694,25 @@ def nvfp4_moe_quant_config(
)

def nvfp4_w4a16_moe_quant_config(
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for 16-bit activations and nvfp4 weights.
"""
return FusedMoEQuantConfig.make(
quant_dtype=None,
w1_scale=w1_scale,
w2_scale=w2_scale,
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
weight_dtype="nvfp4",
)
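For orientation, a hypothetical call on the Marlin W4A16 path, exercising the use_nvfp4_w4a16 property added above. This is a sketch, not the kernel's real layout: tensor shapes are placeholders, and it assumes the helper above is importable.

    import torch
    E = 8  # number of experts (illustrative)
    cfg = nvfp4_w4a16_moe_quant_config(
        g1_alphas=torch.ones(E),       # per-expert global scales for w13
        g2_alphas=torch.ones(E),       # per-expert global scales for w2
        w1_scale=torch.ones(E, 4, 4),  # nvfp4 block scales (placeholder shape)
        w2_scale=torch.ones(E, 4, 4),
    )
    assert cfg.use_nvfp4_w4a16  # 16-bit activations, nvfp4 weights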

def int4_w4a16_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
@@ -706,68 +706,6 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
)


def cutlass_moe_fp4(
a: torch.Tensor,
w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
m: int,
n: int,
k: int,
e: int,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
assert expert_map is None, (
"Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4."
)

# TODO(bnell): this feels a bit hacky
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
quant_config = FusedMoEQuantConfig.make(
quant_dtype=None,  # skip quantization in prepare/finalize
per_act_token_quant=quant_config.per_act_token_quant,
per_out_ch_quant=quant_config.per_out_ch_quant,
block_shape=quant_config.block_shape,
g1_alphas=quant_config.g1_alphas,
g2_alphas=quant_config.g2_alphas,
a1_gscale=quant_config.a1_gscale,
a2_gscale=quant_config.a2_gscale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
)

fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
max_experts_per_worker=e,
out_dtype=a.dtype,
quant_config=quant_config,
use_batched_format=False,
),
)

return fn(
hidden_states=a,
w1=w1_fp4,
w2=w2_fp4,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=apply_router_weight_on_input,
)
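The "two levels of quantization" referenced in the comment above, in one place. This is a sketch, not part of this diff: NVFP4 pairs per-16-element FP8 (E4M3) block scales with one FP32 per-tensor global scale. The constants and the amax convention below are assumptions consistent with the alphas/gscales wired up above; the exact kernel behavior may differ.

    import torch

    FLOAT4_E2M1_MAX = 6.0    # largest magnitude representable in FP4 (E2M1)
    FLOAT8_E4M3_MAX = 448.0  # largest magnitude representable in FP8 (E4M3)

    def nvfp4_global_scale(t: torch.Tensor) -> torch.Tensor:
        # Quantization multiplies by this scale; its inverse is what the
        # checkpoints store as the *_scale_2 / input global scale tensors.
        return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / t.abs().amax().float()

    a_gscale = nvfp4_global_scale(torch.randn(16, 256))   # activations
    w_gscale = nvfp4_global_scale(torch.randn(512, 256))  # one expert's weights
    # GEMM epilogue rescale: alpha = 1 / (a_gscale * w_gscale), i.e. the
    # g1_alphas = a1_scale * w1_scale_2 product carried in the quant config.
    alpha = 1.0 / (a_gscale * w_gscale)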

# W4A8
def run_cutlass_moe_w4a8_fp8(
output: torch.Tensor,
@@ -335,42 +335,3 @@ def flashinfer_cutedsl_moe_masked(
alpha_dtype=get_cute_dtype(w2_alpha),
) # in logical [m, k, l]
out = out.permute(2, 0, 1)


def flashinfer_cutedsl_moe_fp4(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)

fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(use_dp=False), # could be swapped later
FlashInferCuteDSLExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
),
)

return fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@@ -355,21 +355,17 @@ def create_flashinfer_prepare_finalize(
use_deepseek_fp8_block_scale: bool = False,
) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
"""Factory function to create the appropriate FlashInfer implementation."""
# TODO(rob): migrate non-DP cases to MoEPrepareAndFinalizeNoEP
# once we complete the FP8 refactor.
if use_nvfp4:
if enable_alltoallv:
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
else:
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)

# FP8 DP path currently supported via AllGather.
if use_dp:
if enable_alltoallv:
assert use_nvfp4
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
return FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=True,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
else:
# NOTE(rob): CUTLASS FP8 block quant executes the input
# quantization and grouped gemm in a single kernel.
return MoEPrepareAndFinalizeNoEP(defer_input_quant=use_deepseek_fp8_block_scale)
# CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization
# in a single call with the MoE experts kernel.
defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4
return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant)
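Summarizing the branchy logic above as a decision table (new code paths only; argument names as they appear in this hunk):

    # use_dp and enable_alltoallv -> FlashInferAllToAllMoEPrepareAndFinalize
    #                                (asserts use_nvfp4)
    # use_dp                      -> FlashInferAllGatherMoEPrepareAndFinalize
    # otherwise                   -> MoEPrepareAndFinalizeNoEP(
    #                                    defer_input_quant=use_deepseek_fp8_block_scale
    #                                                      or use_nvfp4)
    # i.e. in the non-DP case the kernels that fuse input quantization into the
    # experts GEMM (FP8 block-scale CUTLASS, NVFP4 CUTLASS) defer quantization.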
@@ -540,9 +540,10 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
# TODO (varun) : Enable activation quantization
assert (
quant_config.use_mxfp4_w4a16
or quant_config.use_nvfp4_w4a16
or quant_config.use_int4_w4a16
or quant_config.use_fp8_w8a16
), "Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16"
), "Supports only {mxfp,nvfp,int}4_w4a16 or fp8_w8a16"
self.w13_g_idx = w13_g_idx
self.w2_g_idx = w2_g_idx
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
@@ -555,7 +556,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
if self.quant_config.use_int4_w4a16:
return scalar_types.uint4b8.id
elif self.quant_config.use_mxfp4_w4a16:
elif self.quant_config.use_mxfp4_w4a16 or self.quant_config.use_nvfp4_w4a16:
return scalar_types.float4_e2m1f.id
elif (
self.quant_config.use_fp8_w8a16
@@ -692,6 +693,8 @@ class MarlinExperts(MarlinExpertsBase):
gating_output=None,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_scale1=self.g1_alphas,
global_scale2=self.g2_alphas,
quant_type_id=self.quant_type_id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,

@@ -38,9 +38,6 @@ from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
is_flashinfer_supporting_global_sf,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
from vllm.utils.math_utils import cdiv, round_up
@@ -1125,14 +1122,9 @@ class FusedMoE(CustomOp):
global_expert_id = expert_id
expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id)

allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False)
moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None)

use_global_sf = (
allow_flashinfer
and is_flashinfer_supporting_global_sf(moe_backend)
getattr(self.quant_method, "use_global_sf", False)
and "input_scale" in weight_name
and quant_method_name == "ModelOptNvFp4FusedMoE"
)

if expert_id == -1 and not use_global_sf:
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py (new file, 280 lines)
@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum

import torch

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_flashinfer_fp4_cutedsl_moe_available,
is_flashinfer_fp4_cutlass_moe_available,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
is_fp4_marlin_supported,
prepare_nvfp4_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
)

logger = init_logger(__name__)

class NvFp4MoeBackend(Enum):
FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
FLASHINFER_TRTLLM = "FlashInfer TRTLLM"
FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL"
VLLM_CUTLASS = "vLLM CUTLASS"
MARLIN = "vLLM MARLIN"


FLASHINFER_NVFP4_MOE_BACKENDS = [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
]

fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS,
FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM,
FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL,
}

def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
# Checks whether `backend` supports quantizing with the scaling factors
# of all experts in Expert Parallel mode, when the experts are not all
# on the same rank.
return backend in [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
]

def select_nvfp4_moe_backend() -> NvFp4MoeBackend:
def _make_log_backend(backend: NvFp4MoeBackend):
return f"Using {backend.value} backend for NvFp4 MoE"

if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN:
allow_flashinfer = (
is_flashinfer_fp4_cutlass_moe_available()
or is_flashinfer_fp4_cutedsl_moe_available()
)
if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4:
backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
else:
backend = NvFp4MoeBackend.VLLM_CUTLASS
elif is_fp4_marlin_supported():
backend = NvFp4MoeBackend.MARLIN
else:
raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.")

# Log a warning if a FlashInfer backend was requested but is unavailable.
if (
backend not in FLASHINFER_NVFP4_MOE_BACKENDS
and envs.VLLM_USE_FLASHINFER_MOE_FP4
):
logger.warning_once(
"Requested FlashInfer backend for NvFp4 MoE, but it's not available. "
"Falling back to %s for NvFp4 MoE",
backend.value,
scope="local",
)
else:
logger.info_once(_make_log_backend(backend), scope="local")
return backend
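The same selection, flattened into a decision table (a summary of the function above; the env flags are the ones it reads):

    # cutlass_fp4_supported() and not VLLM_TEST_FORCE_FP8_MARLIN:
    #     VLLM_USE_FLASHINFER_MOE_FP4 and FlashInfer FP4 MoE available
    #         -> fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
    #     else -> NvFp4MoeBackend.VLLM_CUTLASS
    # elif is_fp4_marlin_supported() -> NvFp4MoeBackend.MARLIN
    # else -> ValueError (no backend available)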
def convert_to_nvfp4_moe_kernel_format(
nvfp4_backend: NvFp4MoeBackend,
layer: torch.nn.Module,
w13: torch.Tensor,
w13_scale: torch.Tensor,
w13_scale_2: torch.Tensor,
a13_scale: torch.Tensor | None,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w2_scale_2: torch.Tensor,
a2_scale: torch.Tensor | None,
is_act_and_mul: bool,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
if (
nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
):
(
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
) = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
backend=nvfp4_backend,
layer=layer,
w13=w13,
w13_scale=w13_scale,
w13_scale_2=w13_scale_2,
a13_scale=a13_scale,
w2=w2,
w2_scale=w2_scale,
w2_scale_2=w2_scale_2,
a2_scale=a2_scale,
is_act_and_mul=is_act_and_mul,
)
elif nvfp4_backend == NvFp4MoeBackend.MARLIN:
a13_scale = None
a2_scale = None
(
w13,
w13_scale,
w13_scale_2,
w2,
w2_scale,
w2_scale_2,
) = prepare_nvfp4_moe_layer_for_marlin(
layer=layer,
w13=w13,
w13_scale=w13_scale,
w13_scale_2=w13_scale_2,
w2=w2,
w2_scale=w2_scale,
w2_scale_2=w2_scale_2,
)
else:
raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")

return (
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
)

def make_nvfp4_moe_quant_config(
backend: NvFp4MoeBackend,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_scale_2: torch.Tensor,
w2_scale_2: torch.Tensor,
a13_scale: torch.Tensor,
a2_scale: torch.Tensor,
) -> FusedMoEQuantConfig | None:
UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM]
if backend in UNSUPPORTED:
return None

elif backend == NvFp4MoeBackend.MARLIN:
return nvfp4_w4a16_moe_quant_config(
g1_alphas=w13_scale_2,
g2_alphas=w2_scale_2,
w1_scale=w13_scale,
w2_scale=w2_scale,
)

g1_alphas = a13_scale * w13_scale_2
g2_alphas = a2_scale * w2_scale_2
return nvfp4_moe_quant_config(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=(1.0 / a13_scale),
a2_gscale=(1.0 / a2_scale),
w1_scale=w13_scale,
w2_scale=w2_scale,
)
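A numeric sanity check of the alpha convention above (pure math, no vLLM): the on-disk global scales s_a and s_w arrive here already inverted, a13_scale = 1/s_a and w13_scale_2 = 1/s_w, so g1_alphas = a13_scale * w13_scale_2 = 1/(s_a * s_w) undoes both quantization multipliers in the GEMM epilogue, while a1_gscale = 1/a13_scale = s_a re-quantizes the activations.

    import torch
    s_a, s_w = torch.tensor(100.0), torch.tensor(50.0)  # example global scales
    a13_scale, w13_scale_2 = 1 / s_a, 1 / s_w           # inverted, as stored here
    g1_alpha = a13_scale * w13_scale_2                  # epilogue rescale
    a1_gscale = 1.0 / a13_scale                         # activation quant scale
    assert torch.isclose(g1_alpha, 1 / (s_a * s_w))
    assert torch.isclose(a1_gscale, s_a)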
def make_nvfp4_moe_kernel(
backend: NvFp4MoeBackend,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
) -> mk.FusedMoEModularKernel | None:
assert moe_config.dp_size == 1

UNSUPPORTED_BACKENDS = [
# TRTLLM does not use the modular kernel abstraction.
NvFp4MoeBackend.FLASHINFER_TRTLLM,
# CUTEDSL is used with BATCHED (masked) format only.
# TODO: add here once we support dp/ep via the oracle.
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
]

if backend in UNSUPPORTED_BACKENDS:
return None

elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
FlashInferExperts(
out_dtype=moe_config.in_dtype,
quant_config=quant_config,
ep_rank=moe_config.ep_rank,
ep_size=moe_config.ep_size,
tp_rank=moe_config.tp_rank,
tp_size=moe_config.tp_size,
use_dp=False,
use_deepseek_fp8_block_scale=False,
),
)

elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=moe_config.in_dtype,
# TODO(rob): see what impact this has on expert map?
max_experts_per_worker=moe_config.num_experts,
quant_config=quant_config,
),
)

elif backend == NvFp4MoeBackend.MARLIN:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=quant_config),
)

else:
raise ValueError(f"Unknown NvFp4 MoE backend: {backend}")
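The end-to-end shape of the oracle flow these helpers implement, as a sketch. The function below is illustrative and hypothetical; it assumes this module's public helpers are imported, and the layer attribute names follow the compressed-tensors method later in this diff.

    def _nvfp4_oracle_flow(layer, moe_config):
        # 1) Pick a backend once, from hardware + env flags.
        backend = select_nvfp4_moe_backend()
        # 2) Shuffle weights/scales into that backend's kernel format.
        (w13, w13_scale, w13_scale_2, a13_scale,
         w2, w2_scale, w2_scale_2, a2_scale) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=moe_config.is_act_and_mul,
        )
        # 3) Build the quant config (None for TRTLLM, which bypasses the
        #    modular kernel entirely).
        quant_config = make_nvfp4_moe_quant_config(
            backend=backend,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_scale_2=w13_scale_2,
            w2_scale_2=w2_scale_2,
            a13_scale=a13_scale,
            a2_scale=a2_scale,
        )
        # 4) Instantiate the modular kernel for the non-DP case.
        if quant_config is not None and moe_config.dp_size == 1:
            return make_nvfp4_moe_kernel(
                backend=backend, quant_config=quant_config, moe_config=moe_config
            )
        return None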
@@ -11,7 +11,6 @@ from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationStrategy,
)
from torch.nn.parameter import Parameter

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
@@ -34,12 +33,8 @@ from vllm.model_executor.layers.fused_moe.config import (
int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
@@ -51,6 +46,15 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_kernel,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
@@ -58,14 +62,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
flashinfer_trtllm_fp4_routed_moe,
select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe,
@@ -77,20 +76,15 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_moe_permute_scales,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

@@ -218,31 +212,19 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support,
)
if not moe.is_act_and_mul:
raise ValueError(
"CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
"support non-gated MoE models."
)

super().__init__(moe)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
self.nvfp4_backend = select_nvfp4_moe_backend()
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for CompressedTensorsW4A4Nvfp4MoEMethod."
)
elif self.use_marlin:
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.")
else:
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.")
self.kernel: mk.FusedMoEModularKernel | None = None

def create_weights(
self,
@@ -355,7 +337,13 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
set_weight_attrs(w2_input_scale, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# From packed to weight
"""
Convert NVFP4 MoE weights into kernel format and set up the kernel.
"""
# NOTE(rob): wN_weight_packed -> wN_weight is because ModularKernelMethod
# requires this naming convention. However, the name change breaks
# reloading because the state dict no longer matches disk. Once we
# remove MKM, we should revert this change to ensure compatibility.
layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False
)
@@ -366,144 +354,79 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
)
delattr(layer, "w2_weight_packed")

# reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
if self.allow_flashinfer:
w, s = reorder_w1w3_to_w3w1(
layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2
)
layer.w13_weight = torch.nn.Parameter(w, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False)

if not torch.allclose(
# Use a single gscale for w13.
if self.moe.is_act_and_mul and not torch.allclose(
layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]
):
logger.warning_once(
"w1_weight_global_scale must match w3_weight_global_scale. "
"Accuracy may be affected."
"Accuracy may be affected.",
)
w13_weight_global_scale = layer.w13_weight_global_scale[:, 0].contiguous()

# Take inverse of global scale saved to disk
layer.w13_weight_scale_2 = torch.nn.Parameter(
1 / layer.w13_weight_global_scale[:, 0], requires_grad=False
# Shuffle weights into the NvFp4 kernel format.
(
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
) = convert_to_nvfp4_moe_kernel_format(
nvfp4_backend=self.nvfp4_backend,
layer=layer,
w13=layer.w13_weight,
w13_scale=layer.w13_weight_scale,
w13_scale_2=(1.0 / w13_weight_global_scale),
a13_scale=(1.0 / layer.w13_input_global_scale),
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
w2_scale_2=(1.0 / layer.w2_weight_global_scale),
a2_scale=(1.0 / layer.w2_input_global_scale),
is_act_and_mul=self.moe.is_act_and_mul,
)

layer.w2_weight_scale_2 = torch.nn.Parameter(
1 / layer.w2_weight_global_scale.data, requires_grad=False
)
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w2_weight_scale", w2_scale)
layer.w13_weight_scale_2 = w13_scale_2
layer.w2_weight_scale_2 = w2_scale_2
layer.w13_input_scale = a13_scale
layer.w2_input_scale = a2_scale

if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
return
# w13
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
w13_input_global_scale = (
layer.w13_input_global_scale.min()
.to(torch.float32)
.expand(layer.num_experts)
)
else:
w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
torch.float32
)
layer.g1_alphas = torch.nn.Parameter(
((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
requires_grad=False,
)

layer.w13_input_scale_quant = torch.nn.Parameter(
(w13_input_global_scale), requires_grad=False
)

# w2
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
w2_input_global_scale = (
layer.w2_input_global_scale.min()
.to(torch.float32)
.expand(layer.num_experts)
)
else:
w2_input_global_scale = layer.w2_input_global_scale

layer.g2_alphas = torch.nn.Parameter(
((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)

layer.w2_input_scale_quant = torch.nn.Parameter(
(w2_input_global_scale), requires_grad=False
)

# TensorRT-LLM specific processing
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
# Prepare static weights for TRT-LLM kernel
# alternate: prepare_static_weight_layouts_for_trtllm_moe
(
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
) = prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w2_weight.size(-2), # hidden_size
layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts
)
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
layer.w13_weight_scale = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)

# Additional parameter needed for TRT-LLM
layer.g1_scale_c = Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
else:
# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)

layer.w2_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
# Initialize the kernel that will be called in apply().
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
use_dp = self.moe.dp_size > 1
if self.moe_quant_config is not None and not use_dp:
self.kernel = make_nvfp4_moe_kernel(
backend=self.nvfp4_backend,
quant_config=self.moe_quant_config,
moe_config=self.moe,
)

def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
if self.nvfp4_backend in UNSUPPORTED:
return None
elif not self.allow_flashinfer:
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables)

prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize

def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
@@ -514,7 +437,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
experts = select_nvfp4_gemm_impl(
self.moe,
self.moe_quant_config,
allow_flashinfer=self.allow_flashinfer,
allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS),
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
@@ -522,19 +445,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if (
self.use_marlin
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return None

return nvfp4_moe_quant_config(
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
w1_scale=layer.w13_weight_scale,
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w13_scale_2=layer.w13_weight_scale_2,
w2_scale_2=layer.w2_weight_scale_2,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)

def apply(
assert layer.activation == "silu", "Only SiLU activation is supported."

if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
):
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
)

return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@@ -566,79 +479,41 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=layer.e_score_correction_bias,
)

# Hidden_states in select_experts is only used to extract metadata
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
hidden_states=x_routing,
router_logits=router_logits,
)

if self.use_marlin:
return fused_marlin_moe(
# EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
assert layer.enable_eplb
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
global_num_experts=layer.global_num_experts,
)
else:
assert self.kernel is not None
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)

# FlashInfer fused experts path
elif self.allow_flashinfer:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
)

assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight
), "Flashinfer CUTLASS Fused MoE not applicable!"

assert self.moe_quant_config is not None

return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False, # TODO(shuw): fix later, now output is high prec
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

assert self.moe_quant_config is not None
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO(bnell): derive these from arguments
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
).to(x.dtype)

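How the m/n/k/e arguments passed to cutlass_moe_fp4 above relate to the packed FP4 tensors. A sketch with illustrative sizes; FP4 packs two values per uint8 byte, which is where the factors of 2 come from.

    import torch
    m, n, k, e = 4, 256, 128, 8          # tokens, intermediate, hidden, experts
    x = torch.zeros(m, k)                # activations: m = x.shape[0], k = x.shape[1]
    w13 = torch.zeros(e, 2 * n, k // 2, dtype=torch.uint8)  # packed fp4, 2 vals/byte
    w2 = torch.zeros(e, k, n // 2, dtype=torch.uint8)       # packed fp4
    assert w13.shape[0] == e             # e = w13.shape[0]
    assert w2.shape[2] * 2 == n          # n = w2.shape[2] * 2 (unpack the bytes)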
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
@@ -15,9 +15,7 @@ from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -30,6 +28,15 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -45,16 +52,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf,
select_cutlass_fp8_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -69,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin,
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@@ -89,7 +90,6 @@ from vllm.model_executor.parameter import (
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
@@ -1327,43 +1327,32 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
quant_config: ModelOptNvFp4Config,
layer: FusedMoE,
) -> None:
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
detect_nvfp4_moe_support, # noqa: E501
)

super().__init__(layer.moe_config)
self.quant_config = quant_config
self.layer = layer
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.marlin_input_dtype = None
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for ModelOptNvFp4FusedMoE."
self.nvfp4_backend = select_nvfp4_moe_backend()
# TODO: move this type of check into the oracle.
if (
not self.moe.is_act_and_mul
and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS
):
raise NotImplementedError(
"Non-gated activations are only supported by FlashInfer "
"CUTLASS NvFP4 MoE backend."
)
elif self.use_marlin:
logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
else:
logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")

self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
self.kernel: mk.FusedMoEModularKernel | None = None

def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
if self.nvfp4_backend in UNSUPPORTED:
return None
elif (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
@@ -1385,7 +1374,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
experts = select_nvfp4_gemm_impl(
self.moe,
self.moe_quant_config,
allow_flashinfer=self.allow_flashinfer,
allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
@@ -1405,11 +1394,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError(
"NVFP4 quantization was selected, "
" dynamic quantization is not supported."
)
assert self.quant_config.is_checkpoint_nvfp4_serialized

layer.num_experts = num_experts
layer.params_dtype = params_dtype
@@ -1498,14 +1483,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)

use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
self.flashinfer_moe_backend
global_sf_num_experts = (
global_num_experts if self.use_global_sf else num_experts
)
global_scale_num_experts = global_num_experts if use_global_sf else num_experts

w13_input_scale = PerTensorScaleParameter(
data=torch.empty(
global_scale_num_experts,
global_sf_num_experts,
2 if self.moe.is_act_and_mul else 1,
dtype=torch.float32,
),
@@ -1514,32 +1497,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.register_parameter("w13_input_scale", w13_input_scale)

w2_input_scale = PerTensorScaleParameter(
data=torch.empty(global_scale_num_experts, dtype=torch.float32),
data=torch.empty(global_sf_num_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_input_scale", w2_input_scale)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1 processing
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
"""
Convert NVFP4 MoE weights into kernel format and set up the kernel.
"""

if (
self.allow_flashinfer
and (
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
)
and self.moe.is_act_and_mul
):
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2
)

layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

# Common processing for w13_weight_scale_2
# Use a single gscale for w13.
if self.moe.is_act_and_mul and not torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
):
@@ -1547,136 +1515,47 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"w1_weight_scale_2 must match w3_weight_scale_2. "
"Accuracy may be affected."
)

w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

# Common processing for input scales and alphas
use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
self.flashinfer_moe_backend
)
if use_global_sf:
# For backends provided by FlashInfer, the input global scales are
# shared across all experts.
w13_input_scale = (
layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
)
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False,
(
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
) = convert_to_nvfp4_moe_kernel_format(
nvfp4_backend=self.nvfp4_backend,
layer=layer,
w13=layer.w13_weight,
w13_scale=layer.w13_weight_scale,
w13_scale_2=w13_weight_scale_2,
a13_scale=layer.w13_input_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
w2_scale_2=layer.w2_weight_scale_2,
a2_scale=layer.w2_input_scale,
is_act_and_mul=self.moe.is_act_and_mul,
)

# This is for quantization, so we need to invert it.
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False
)
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
replace_parameter(layer, "w13_input_scale", a13_scale)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w2_weight_scale", w2_scale)
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
replace_parameter(layer, "w2_input_scale", a2_scale)

# GEMM 2 processing
if use_global_sf:
# For backends provided by FlashInfer, the input global scales are
# shared across all experts.
w2_input_scale = (
layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
)
else:
w2_input_scale = layer.w2_input_scale
layer.g2_alphas = Parameter(
(w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)

# This is for quantization, so we need to invert it.
layer.w2_input_scale_quant = Parameter(
(1 / w2_input_scale).to(torch.float32), requires_grad=False
)

# TensorRT-LLM specific processing
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
# Prepare static weights for TRT-LLM kernel
# alternate: prepare_static_weight_layouts_for_trtllm_moe
(
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
) = prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w2_weight.size(-2), # hidden_size
layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts
)
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
layer.w13_weight_scale = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)

# Additional parameter needed for TRT-LLM
layer.g1_scale_c = Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
elif self.use_marlin:
# Marlin processing
prepare_moe_fp4_layer_for_marlin(layer)
del layer.g1_alphas
del layer.g2_alphas
del layer.w13_input_scale_quant
del layer.w2_input_scale_quant
else:
# Non-TRT-LLM processing (Cutlass or non-flashinfer)
w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
layer.w13_weight_scale = Parameter(
w13_blockscale_swizzled, requires_grad=False
)

w13_weight = layer.w13_weight
intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
if intermediate_size_pad:
# padding gated activations would require splitting w1 and w3
# and padding them individually
assert not self.moe.is_act_and_mul, (
"The intermediate size required padding, "
"but padding is not implemented for gated activations"
)

layer.w13_weight = Parameter(
torch.nn.functional.pad(
w13_weight, (0, 0, 0, intermediate_size_pad)
),
requires_grad=False,
)
layer.w2_weight = Parameter(
torch.nn.functional.pad(
layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
),
requires_grad=False,
)
layer.w2_weight_scale = Parameter(
torch.nn.functional.pad(
layer.w2_weight_scale, (0, intermediate_size_pad // 16)
),
requires_grad=False,
)

w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
layer.w2_weight_scale = Parameter(
w2_blockscale_swizzled, requires_grad=False
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
use_dp = self.moe.dp_size > 1
if self.moe_quant_config is not None and not use_dp:
self.kernel = make_nvfp4_moe_kernel(
backend=self.nvfp4_backend,
quant_config=self.moe_quant_config,
moe_config=self.moe,
)

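The padding arithmetic in the branch above, in one place. A sketch with an example value: two FP4 values are packed per byte, and one FP8 block scale covers 16 elements along the padded dimension, which is where the // 2 and // 16 divisors come from.

    pad = 32                   # example: swizzled scale rows minus weight rows
    w2_pad_packed = pad // 2   # packed-fp4 columns appended to w2_weight
    w2_scale_pad = pad // 16   # block-scale columns appended to w2_weight_scale
    assert w2_pad_packed * 2 == pad and w2_scale_pad * 16 == pad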
def prepare_dp_allgather_tensor(
@@ -1688,7 +1567,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
import flashinfer

a1_gscale = layer.w13_input_scale_quant
assert self.moe_quant_config is not None
a1_gscale = self.moe_quant_config.a1_gscale
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
hidden_states,
a1_gscale,
@@ -1700,19 +1580,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if (
self.use_marlin
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return None

return nvfp4_moe_quant_config(
w1_scale=layer.w13_weight_scale,
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
w13_scale_2=layer.w13_weight_scale_2,
w2_scale_2=layer.w2_weight_scale_2,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)

@property
|
||||
@@ -1725,18 +1600,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if not self.moe.is_act_and_mul:
    assert (
        self.allow_flashinfer
        and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
    ), (
        "Non-gated activations are only supported by the"
        " flashinfer CUTLASS backend for modelopt checkpoints"
    )

if (
    self.allow_flashinfer
    and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
    self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
    and not layer.enable_eplb
):
    return flashinfer_trtllm_fp4_moe(
@@ -1762,10 +1627,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    )

# EPLB path
if (
    self.allow_flashinfer
    and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
    assert layer.enable_eplb
    return flashinfer_trtllm_fp4_routed_moe(
        layer=layer,
        x=x,
@@ -1774,81 +1637,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
        top_k=layer.top_k,
        global_num_experts=layer.global_num_experts,
    )

if self.use_marlin:
    return fused_marlin_moe(
else:
    assert self.kernel is not None
    return self.kernel(
        x,
        layer.w13_weight,
        layer.w2_weight,
        None,
        None,
        layer.w13_weight_scale,
        layer.w2_weight_scale,
        router_logits,
        topk_weights,
        topk_ids,
        global_scale1=layer.w13_weight_scale_2,
        global_scale2=layer.w2_weight_scale_2,
        quant_type_id=scalar_types.float4_e2m1f.id,
        apply_router_weight_on_input=layer.apply_router_weight_on_input,
        global_num_experts=layer.global_num_experts,
        expert_map=layer.expert_map,
        input_dtype=self.marlin_input_dtype,
    )

elif self.allow_flashinfer:
    assert self.flashinfer_moe_backend in (
        FlashinferMoeBackend.CUTLASS,
        FlashinferMoeBackend.CUTEDSL,
    )
    if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
            flashinfer_cutlass_moe_fp4,
        )

        flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
    else:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
            flashinfer_cutedsl_moe_fp4,
        )

        flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4

    assert self.moe_quant_config is not None
    return flashinfer_fn_moe_fp4(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        quant_config=self.moe_quant_config,
        inplace=False,
        activation=layer.activation,
        global_num_experts=layer.global_num_experts,
        expert_map=layer.expert_map,
        apply_router_weight_on_input=layer.apply_router_weight_on_input,
    )
else:
    # If no modular kernel is provided, use cutlass_moe_fp4 for the TP case
    # only (no EP).
    from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

    assert self.moe_quant_config is not None
    return cutlass_moe_fp4(
        a=x,
        w1_fp4=layer.w13_weight,
        w2_fp4=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        quant_config=self.moe_quant_config,
        expert_map=layer.expert_map,
        apply_router_weight_on_input=layer.apply_router_weight_on_input,
        # TODO: derive from arguments
        m=x.shape[0],
        n=layer.w2_weight.shape[2] * 2,
        k=x.shape[1],
        e=layer.w13_weight.shape[0],
    )
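
For readers tracing the `# TODO: derive from arguments` above, a minimal sketch of how `m`, `n`, `k`, and `e` fall out of the tensor shapes (all shapes assumed for illustration; NVFP4 packs two 4-bit values per uint8 byte):

import torch

# Hypothetical sizes: 8 tokens, hidden_size 1024, intermediate 512, 16 experts.
x = torch.empty(8, 1024)                                       # activations
w13 = torch.empty(16, 2 * 512, 1024 // 2, dtype=torch.uint8)   # packed FP4
w2 = torch.empty(16, 1024, 512 // 2, dtype=torch.uint8)        # packed FP4

m, k = x.shape         # number of tokens, hidden size
e = w13.shape[0]       # number of experts
n = w2.shape[2] * 2    # unpack: two FP4 values per stored byte
assert (m, n, k, e) == (8, 512, 1024, 16)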

ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod

@@ -2,10 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility helpers for NVFP4 + FlashInfer fused-MoE path"""

from typing import TYPE_CHECKING

import torch

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
@@ -20,12 +23,23 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
    create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    swizzle_blockscale,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
    has_flashinfer_cutedsl_grouped_gemm_nt_masked,
    has_flashinfer_cutlass_fused_moe,
)

if TYPE_CHECKING:
    from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
        NvFp4MoeBackend,
    )

logger = init_logger(__name__)


__all__ = [
    "is_flashinfer_fp4_cutlass_moe_available",
    "is_flashinfer_fp4_cutedsl_moe_available",
@@ -273,10 +287,9 @@ def flashinfer_trtllm_fp4_moe(
        hidden_states_fp4, hidden_states_scale_linear_fp4 = x
    else:
        # hidden_states is not already quantized; quantize it to FP4 here.
        a1_gscale = layer.w13_input_scale_quant
        (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
            x,
            a1_gscale,
            layer.a1_gscale,
            is_sf_swizzled_layout=False,
        )

@@ -369,10 +382,9 @@ def flashinfer_trtllm_fp4_routed_moe(
        hidden_states_fp4, hidden_states_scale_linear_fp4 = x
    else:
        # Quantize input to FP4
        a1_gscale = layer.w13_input_scale_quant
        (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
            x,
            a1_gscale,
            layer.a1_gscale,
            is_sf_swizzled_layout=False,
        )

@@ -410,3 +422,93 @@ def flashinfer_trtllm_fp4_routed_moe(
    )[0]

    return out
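
Both TRTLLM entry points above share the same input-handling pattern: the caller may pass either a raw activation tensor or an already-quantized (fp4 data, block scales) pair. A minimal sketch of that pattern, where `quantize_fp4` is a hypothetical stand-in for `flashinfer.fp4_quantize`:

import torch

def get_fp4_inputs(x, a1_gscale, quantize_fp4):
    if isinstance(x, tuple):
        # Pre-quantized inputs arrive as a (data, scales) pair and pass through.
        hidden_states_fp4, hidden_states_scale = x
    else:
        # Otherwise quantize on the fly with the layer's global input scale.
        hidden_states_fp4, hidden_states_scale = quantize_fp4(x, a1_gscale)
    return hidden_states_fp4, hidden_states_scale
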
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
    backend: "NvFp4MoeBackend",
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor,
    is_act_and_mul: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    # Delayed import for circular dependency avoidance.
    from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
        NvFp4MoeBackend,
        is_global_sf_supported_for_nvfp4_backend,
    )

    assert backend in [
        NvFp4MoeBackend.VLLM_CUTLASS,
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
    ]

    # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
    if is_act_and_mul and backend in [
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
    ]:
        w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale)

    # For some FI kernels, the input scales are shared by all experts.
    if is_global_sf_supported_for_nvfp4_backend(backend):
        num_experts = w13.shape[0]
        a13_scale = a13_scale.max().to(torch.float32).expand(num_experts)
        a2_scale = a2_scale.max().to(torch.float32).expand(num_experts)
    else:
        a13_scale = a13_scale.max(dim=1).values.to(torch.float32)
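    # Illustrative values (assumed): with a13_scale = [[0.5, 0.7], [0.9, 0.4]]
    # laid out as (experts, shards),
    #   the global variant:     .max() -> 0.9, expanded to [0.9, 0.9]
    #   the per-expert variant: .max(dim=1).values -> [0.7, 0.9]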

    # Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels.
    if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
        w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
            w13,
            w2,
            w13_scale,
            w2_scale,
            w2.size(-2),  # hidden_size
            w13.size(-2) // 2,  # intermediate_size
            w13.size(0),  # num_experts
        )

        # We do not need to make this a parameter, because
        # it is not used during the weight (re)-loading process.
        layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
        layer.a1_gscale = 1.0 / a13_scale
        layer.g1_alphas = a13_scale * w13_scale_2
        layer.g2_alphas = a2_scale * w2_scale_2
    else:
        # Swizzle the block scales for other FI NVFP4 MoE kernels.
        w13_scale = swizzle_blockscale(w13_scale)

        # Apply padding if needed.
        pad_size = w13_scale.size(1) - w13.size(1)
        if pad_size > 0:
            if is_act_and_mul:
                raise NotImplementedError(
                    f"Intermediate size padding for w1 and w3 is required "
                    f"by the {backend.value} NvFp4 backend, but is not "
                    "currently supported."
                )
            w13 = torch.nn.functional.pad(w13, (0, 0, 0, pad_size))
            w2 = torch.nn.functional.pad(w2, (0, pad_size // 2, 0, 0))
            w2_scale = torch.nn.functional.pad(w2_scale, (0, pad_size // 16))

        w2_scale = swizzle_blockscale(w2_scale)

    return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale
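
A worked example of the padding arithmetic above. All values are assumed; in particular, the swizzle padding the scale's intermediate dimension to a multiple of 128 is purely illustrative, not a statement of swizzle_blockscale's contract:

# Hypothetical unpadded intermediate size of 96 rows, padded to 128:
inter = 96                        # w13.size(1) before padding
inter_padded = 128                # w13_scale.size(1) after swizzle_blockscale
pad_size = inter_padded - inter   # 32 FP4 rows appended to w13
w2_pad = pad_size // 2            # 16 bytes: two FP4 values packed per byte
w2_scale_pad = pad_size // 16     # 2 scales: one block scale per 16 elements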

@@ -8,6 +8,7 @@ import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    USE_FP32_REDUCE_DEFAULT,
    get_marlin_input_dtype,
    marlin_make_workspace_new,
    marlin_permute_bias,
    marlin_permute_scales,
@@ -226,6 +227,106 @@ def prepare_fp4_layer_for_marlin(
    return

def prepare_nvfp4_moe_layer_for_marlin(
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
) -> tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used, leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads."
    )

    input_dtype = get_marlin_input_dtype(prefix="")
    if input_dtype is not None and input_dtype.itemsize == 1:
        raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")

    GROUP_SIZE = 16
    E = layer.num_experts
    K = layer.hidden_size
    N = layer.intermediate_size_per_partition

    device = w13.device
    param_dtype = layer.params_dtype
    is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1

    # WORKSPACE
    layer.workspace = marlin_make_workspace_new(device, 4)
    perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT
    # Repack weights to marlin format
    def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
        tensor_list = []
        if "w13" in name:
            size_n, size_k = N * 2, K
        else:
            size_n, size_k = K, N

        assert weight.shape == (E, size_n, size_k // 2)

        for i in range(E):
            qweight = weight[i].view(torch.int32).T.contiguous()
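            # The uint8 FP4 weight packs two 4-bit values per byte; the int32
            # view above groups eight 4-bit values per element, and the
            # transpose makes rows run along size_k, which appears to be the
            # packed layout gptq_marlin_repack consumes for num_bits=4.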

            marlin_qweight = ops.gptq_marlin_repack(
                b_q_weight=qweight,
                perm=perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=4,
                is_a_8bit=is_a_8bit,
            )
            tensor_list.append(marlin_qweight)

        return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)

    w13 = repack_weight(w13, "w13")
    w2 = repack_weight(w2, "w2")

    # WEIGHT SCALES
    # Permute scales
    def permute_scales(
        scales: torch.Tensor, g_scales: torch.Tensor, name: str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        scales = scales.to(param_dtype)
        g_scales = g_scales.to(param_dtype)

        tensor_list = []
        if "w13" in name:
            size_n, size_k = N * 2, K
        else:
            size_n, size_k = K, N

        for i in range(E):
            scale = scales[i].T
            marlin_scales = marlin_permute_scales(
                s=scale,
                size_k=size_k,
                size_n=size_n,
                group_size=GROUP_SIZE,
                is_a_8bit=is_a_8bit,
            )
            marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
            tensor_list.append(marlin_scales)

        scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        g_scales = nvfp4_marlin_process_global_scale(g_scales)
        return scales, g_scales

    w13_scale, w13_scale_2 = permute_scales(w13_scale, w13_scale_2, "w13")
    w2_scale, w2_scale_2 = permute_scales(w2_scale, w2_scale_2, "w2")

    return w13, w13_scale, w13_scale_2, w2, w2_scale, w2_scale_2
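
A hypothetical call, showing the packed shapes the asserts above imply (E experts, hidden size K, per-partition intermediate size N; all sizes and dtypes are assumptions for illustration):

import torch

E, K, N = 8, 4096, 2048
w13 = torch.zeros(E, 2 * N, K // 2, dtype=torch.uint8)  # packed FP4, 2 values/byte
w2 = torch.zeros(E, K, N // 2, dtype=torch.uint8)
w13_scale = torch.zeros(E, 2 * N, K // 16)              # one block scale per 16 values
w2_scale = torch.zeros(E, K, N // 16)
w13_scale_2 = torch.zeros(E)                            # per-expert global scales
w2_scale_2 = torch.zeros(E)

# `layer` supplies num_experts / hidden_size / intermediate_size_per_partition /
# params_dtype as used above; the function returns the Marlin-repacked counterparts.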


def prepare_moe_fp4_layer_for_marlin(
    layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None: