[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)

Robert Shaw
2026-01-07 19:42:33 -05:00
committed by GitHub
parent ffc0a2798b
commit 5dcd7ef1f2
38 changed files with 1439 additions and 1528 deletions
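
For context, this change routes FP8 MoE backend selection and kernel construction through the new oracle helpers in vllm/model_executor/layers/fused_moe/oracle/fp8.py. Below is a minimal sketch of the resulting call flow, reconstructed from the call sites in this diff; argument lists are abbreviated and partly assumed, and the free variables (layer, moe, and the weight/scale tensors) are assumed to come from the surrounding FusedMoE method.

    from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
        convert_to_fp8_moe_kernel_format,
        make_fp8_moe_kernel,
        make_fp8_moe_quant_config,
        select_fp8_moe_backend,
    )

    # 1) Pick the backend once, when the quant method is constructed.
    fp8_backend = select_fp8_moe_backend(
        block_quant=block_quant,
        tp_size=layer.moe_parallel_config.tp_size,
        with_lora_support=moe.is_lora_enabled,
    )

    # 2) After weight loading: shuffle weights/scales into the backend's format ...
    w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
        fp8_backend=fp8_backend, layer=layer,
        w13=w13, w2=w2, w13_scale=w13_scale, w2_scale=w2_scale,
        w13_input_scale=w13_input_scale, w2_input_scale=w2_input_scale,
    )

    # 3) ... build the backend-specific quant config ...
    quant_config = make_fp8_moe_quant_config(
        fp8_backend=fp8_backend, w1_scale=w13_scale, w2_scale=w2_scale,
        a1_scale=w13_input_scale, a2_scale=w2_input_scale,
    )

    # 4) ... and construct the modular kernel. FlashInfer TRTLLM bypasses the
    #    modular-kernel abstraction and is applied directly in apply().
    kernel, use_inplace = make_fp8_moe_kernel(
        layer=layer, moe_quant_config=quant_config,
        moe_config=moe, fp8_backend=fp8_backend,
    )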


@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
import torch
@@ -27,13 +26,17 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -46,25 +49,20 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
apply_flashinfer_per_tensor_scale_fp8,
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend,
make_fp8_moe_alpha_scales_for_fi,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
deepgemm_post_process_fp8_weight_block,
maybe_post_process_fp8_weight_block,
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy,
process_fp8_weight_tensor_strategy_moe,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -73,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@@ -81,12 +78,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
all_close_1d,
cutlass_block_fp8_supported,
cutlass_fp8_supported,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
@@ -96,11 +91,8 @@ from vllm.model_executor.parameter import (
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -110,107 +102,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
class Fp8MoeBackend(Enum):
NONE = 0
FLASHINFER_TRTLLM = 1
FLASHINFER_CUTLASS = 2
DEEPGEMM = 3
MARLIN = 4
TRITON = 5
AITER = 6
def get_fp8_moe_backend(
block_quant: bool,
moe_parallel_config: FusedMoEParallelConfig,
with_lora_support: bool,
) -> Fp8MoeBackend | None:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
if current_platform.is_xpu():
return None
if with_lora_support:
return Fp8MoeBackend.TRITON
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
and (
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
):
backend = get_flashinfer_moe_backend()
if backend == FlashinferMoeBackend.TENSORRT_LLM:
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant and current_platform.is_device_capability_family(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization on SM100. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
return Fp8MoeBackend.FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
if current_platform.is_rocm():
use_marlin = False
if use_marlin:
logger.info_once("Using Marlin backend for FP8 MoE")
return Fp8MoeBackend.MARLIN
# Determine if we should use DeepGEMM with block-quantized weights:
# - If explicitly set by user, respect their choice
# - If not explicitly set (default), disable when TP size is >= 8
moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
moe_use_deep_gemm = False
logger.info_once(
"DeepGEMM MoE is disabled by default when TP size is >= 8. "
"Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
scope="local",
)
# Determine if we should use DeepGEMM (top-level enable switch)
# - If explicitly set by user, respect their choice
# - If the platform does not support DeepGEMM, disable it
# This helps avoid warning messages on unsupported platforms.
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
if not is_deep_gemm_supported():
use_deep_gemm = False
logger.info_once(
"DeepGEMM is disabled because the platform does not support it.",
scope="local",
)
if use_deep_gemm and moe_use_deep_gemm and block_quant:
if not has_deep_gemm():
logger.warning_once(
"DeepGEMM backend requested but not available.", scope="local"
)
elif is_deep_gemm_supported():
logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
return Fp8MoeBackend.DEEPGEMM
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
return Fp8MoeBackend.AITER
# default to Triton
logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
@@ -348,7 +239,6 @@ class Fp8Config(QuantizationConfig):
moe_quant_method = Fp8MoEMethod(self, layer)
else:
moe_quant_method = Fp8OnlineMoEMethod(self, layer)
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
@@ -736,40 +626,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(layer.moe_config)
self.layer = layer
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.weight_scale_name = (
"weight_scale_inv" if self.block_quant else "weight_scale"
)
self.fp8_backend = get_fp8_moe_backend(
self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
self.fp8_backend = select_fp8_moe_backend(
block_quant=self.block_quant,
tp_size=layer.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled,
)
self.marlin_input_dtype = None
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
if self.block_quant and self.weight_block_size != [128, 128]:
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports block "
"size [128, 128]."
)
if not self.block_quant:
if layer.renormalize or layer.custom_routing_function is not None:
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend does custom routing "
f"function or renormalization, but got {layer.renormalize} and "
f"{layer.custom_routing_function}."
)
if layer.scoring_func != "sigmoid":
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports "
f"'sigmoid' scoring function, but got {layer.scoring_func}."
)
if layer.activation != "silu":
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
@@ -778,12 +652,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
dynamic_per_token = (
not self.block_quant and self.quant_config.activation_scheme != "static"
)
if self.flashinfer_moe_backend is not None and dynamic_per_token:
if dynamic_per_token and self.fp8_backend in [
Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS,
]:
raise NotImplementedError(
"FlashInfer FP8 MoE backend does not support dynamic per token "
"activation quantization."
)
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights(
self,
layer: Module,
@@ -907,148 +786,43 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale = None
layer.w2_input_scale = None
def _convert_weights_to_kernel_format(
def _setup_kernel(
self,
layer: Module,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> None:
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
assert self.block_quant
w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=w13_weight,
ws=w13_weight_scale,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=w2_weight,
ws=w2_weight_scale,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
elif self.fp8_backend == Fp8MoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
)
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
(
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
) = prepare_moe_fp8_layer_for_marlin(
layer,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
input_dtype=self.marlin_input_dtype,
)
layer.workspace = workspace
elif self.fp8_backend in [
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
w13_weight = swap_w13_to_w31(w13_weight)
if self.block_quant:
w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
else:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer=layer,
w13_weight_scale=w13_weight_scale,
w13_input_scale=w13_input_scale,
w2_weight_scale=w2_weight_scale,
w2_input_scale=w2_input_scale,
)
elif self.fp8_backend == Fp8MoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
)
# Shuffle weights to runtime format.
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
fp8_backend=self.fp8_backend,
layer=layer,
w13=w13,
w2=w2,
w13_scale=w13_scale,
w2_scale=w2_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
def _setup_kernel(self, layer: Module) -> None:
"""Setup Modular Kernel for TP Case"""
# NOTE(rob): this is a WIP refactor. We are first migrating
# all of the kernels in the TP case to use mk. Once this is
# done, then we will initialize the TP case and DP/EP case
# via the same code path (i.e. via maybe_init_modular_kernel).
# NOTE(rob): in progress migrating all into this format.
from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
# Flashinfer TRTLLM does not use the modular kernel abstraction.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
# Setup modular kernel for TP case.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.use_inplace = True
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.kernel = mk.FusedMoEModularKernel(
# TODO: make defer_input_quant an attr of the FlashInferExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=self.block_quant),
FlashInferExperts(
out_dtype=layer.orig_dtype,
quant_config=self.moe_quant_config,
ep_rank=self.moe.ep_rank,
ep_size=self.moe.ep_size,
tp_rank=self.moe.tp_rank,
tp_size=self.moe.tp_size,
use_dp=(self.moe.dp_size > 1),
use_deepseek_fp8_block_scale=self.block_quant,
),
)
self.use_inplace = False
elif self.fp8_backend == Fp8MoeBackend.AITER:
self.kernel = mk.FusedMoEModularKernel(
# TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
AiterExperts(quant_config=self.moe_quant_config),
)
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=self.moe_quant_config),
)
else:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
),
if self.moe_quant_config:
self.kernel, self.use_inplace = make_fp8_moe_kernel(
layer=layer,
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
)
def process_weights_after_loading(self, layer: Module) -> None:
@@ -1056,78 +830,58 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return
# Allow accessing weights and scales in the standard way.
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
w13 = layer.w13_weight
w2 = layer.w2_weight
w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
w13_input_scale = layer.w13_input_scale
w2_input_scale = layer.w2_input_scale
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if current_platform.is_fp8_fnuz():
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
w13_weight, w13_weight_scale, w13_input_scale
)
w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w13,
w13_scale,
w13_input_scale,
)
w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w2_weight, w2_weight_scale, w2_input_scale
w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w2,
w2_scale,
w2_input_scale,
)
# Per-tensor kernels require a single activation scale. Use the max.
if self.quant_config.activation_scheme == "static":
assert not self.block_quant
assert w13_input_scale is not None and w2_input_scale is not None
if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
replace_parameter(layer, "w2_input_scale", w2_input_scale.max())
w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
w13_input_scale, w2_input_scale
)
replace_parameter(layer, "w13_input_scale", w13_input_scale)
replace_parameter(layer, "w2_input_scale", w2_input_scale)
# Per-tensor kernels require a single weight scale for w13 per expert, but
# on disk there is a separate scale for w1 and w3. Use the max to requantize.
if not self.block_quant:
shard_size = layer.intermediate_size_per_partition
max_w13_scales = w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
w13_weight[expert_id][start : start + shard_size, :],
w13_weight_scale[expert_id][shard_id],
)
w13_weight[expert_id][start : start + shard_size, :], _ = (
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
w13_weight_scale = max_w13_scales
w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
w13, w13_scale, shard_size, layer.local_num_experts
)
# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=w13_weight_scale,
w2_weight_scale=w2_weight_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
# Shuffle weights to runtime format and setup kernel.
self._setup_kernel(
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
# Setup modular kernel for TP case.
self._setup_kernel(layer)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if (
self.fp8_backend == Fp8MoeBackend.AITER
or self.fp8_backend == Fp8MoeBackend.MARLIN
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
if self.fp8_backend in [
Fp8MoeBackend.AITER,
Fp8MoeBackend.MARLIN,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
@@ -1184,7 +938,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
elif self.moe.is_lora_enabled:
return TritonExperts(quant_config=self.moe_quant_config)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# Select GEMM experts with block-scale when weights are block-quantized
experts = select_cutlass_fp8_gemm_impl(
self.moe,
@@ -1193,17 +947,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
else:
elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
logger.debug(
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__,
self.weight_block_size,
False,
)
return TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
return TritonOrDeepGemmExperts(self.moe_quant_config)
else:
assert self.fp8_backend == Fp8MoeBackend.TRITON
logger.debug(
"TritonExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__,
self.weight_block_size,
False,
)
return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -1212,42 +972,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
# MARLIN uses mixed precision W8A16 config.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config(
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
block_shape=self.weight_block_size,
)
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
# Flashinfer CUTLASS per-tensor uses single dq scale
# (alpha = w_scale * a_scale) and inverse a2 scale.
if (
self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
and not self.block_quant
):
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=(1.0 / a2_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
return make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
@@ -1269,7 +1000,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
@@ -1308,10 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routed_scaling=layer.routed_scaling_factor,
)
else:
assert (
not layer.renormalize and layer.custom_routing_function is not None
)
result = apply_flashinfer_per_tensor_scale_fp8(
result = apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
@@ -1327,6 +1055,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
hidden_states=x,
router_logits=router_logits,
)
assert self.kernel is not None
result = self.kernel(
x,
layer.w13_weight,
@@ -1358,7 +1088,6 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
assert not quant_config.is_checkpoint_fp8_serialized
assert quant_config.activation_scheme == "dynamic"
assert quant_config.weight_block_size is None
assert self.flashinfer_moe_backend is None
def create_weights(
self,
@@ -1447,6 +1176,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
layer.w13_input_scale = None
layer.w2_input_scale = None
@@ -1457,33 +1188,30 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight[expert, :, :])
w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
layer.w13_weight[expert, :, :]
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight[expert, :, :])
w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
layer.w2_weight[expert, :, :]
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=layer.w13_weight_scale,
w2_weight_scale=layer.w2_weight_scale,
w13_input_scale=None,
w2_input_scale=None,
# Shuffle weights to runtime format and setup kernel.
self._setup_kernel(
layer,
w13,
w2,
w13_scale,
w2_scale,
layer.w13_input_scale,
layer.w2_input_scale,
)
# Setup modular kernel for TP case.
self._setup_kernel(layer)
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""