[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from enum import Enum
from typing import TYPE_CHECKING, Any, Optional

import torch
@@ -27,13 +26,17 @@ from vllm.model_executor.layers.fused_moe import (
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
    fp8_w8a8_moe_quant_config,
    fp8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
@@ -46,25 +49,20 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
    apply_fi_trtllm_fp8_per_tensor_moe,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
    get_flashinfer_moe_backend,
    make_fp8_moe_alpha_scales_for_fi,
    register_scales_for_trtllm_fp8_per_tensor_moe,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    deepgemm_post_process_fp8_weight_block,
    maybe_post_process_fp8_weight_block,
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    process_fp8_weight_tensor_strategy_moe,
    validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -73,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
@@ -81,12 +78,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
@@ -96,11 +91,8 @@ from vllm.model_executor.parameter import (
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper
@@ -110,107 +102,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)


class Fp8MoeBackend(Enum):
    NONE = 0
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    DEEPGEMM = 3
    MARLIN = 4
    TRITON = 5
    AITER = 6


def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend | None:
    """
    Select the primary FP8 MoE backend
    Note: Shape-specific fallbacks may still occur at runtime.
    """
    if current_platform.is_xpu():
        return None
    if with_lora_support:
        return Fp8MoeBackend.TRITON
    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    if (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        else:
            if block_quant and current_platform.is_device_capability_family(100):
                raise ValueError(
                    "FlashInfer FP8 MoE throughput backend does not "
                    "support block quantization on SM100. Please use "
                    "VLLM_FLASHINFER_MOE_BACKEND=latency "
                    "instead."
                )
            logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
            return Fp8MoeBackend.FLASHINFER_CUTLASS

    # weight-only path for older GPUs without native FP8
    use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if current_platform.is_rocm():
        use_marlin = False
    if use_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # Determine if we should use DeepGEMM with block-quantized weights:
    # - If explicitly set by user, respect their choice
    # - If not explicitly set (default), disable when TP size is >= 8
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
    if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
        moe_use_deep_gemm = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    # Determine if we should use DeepGEMM (top-level enable switch)
    # - If explicitly set by user, respect their choice
    # - If not platform supports DeepGEMM, disable it
    # This helps avoid warning messages on unsupported platforms.
    use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
    if not is_deep_gemm_supported():
        use_deep_gemm = False
        logger.info_once(
            "DeepGEMM is disabled because the platform does not support it.",
            scope="local",
        )

    if use_deep_gemm and moe_use_deep_gemm and block_quant:
        if not has_deep_gemm():
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
            return Fp8MoeBackend.DEEPGEMM

    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
        logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
        return Fp8MoeBackend.AITER

    # default to Triton
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

@@ -348,7 +239,6 @@ class Fp8Config(QuantizationConfig):
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
@@ -736,40 +626,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        self.fp8_backend = select_fp8_moe_backend(
            block_quant=self.block_quant,
            tp_size=layer.moe_parallel_config.tp_size,
            with_lora_support=self.moe.is_lora_enabled,
        )

        self.marlin_input_dtype = None
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if not self.block_quant:
                if layer.renormalize or layer.custom_routing_function is not None:
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend does custom routing "
                        f"function or renormalization, but got {layer.renormalize} and "
                        f"{layer.custom_routing_function}."
                    )
                if layer.scoring_func != "sigmoid":
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend only supports "
                        f"'sigmoid' scoring function, but got {layer.scoring_func}."
                    )
            if layer.activation != "silu":
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
@@ -778,12 +652,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        dynamic_per_token = (
            not self.block_quant and self.quant_config.activation_scheme != "static"
        )
        if self.flashinfer_moe_backend is not None and dynamic_per_token:
        if dynamic_per_token and self.fp8_backend in [
            Fp8MoeBackend.FLASHINFER_TRTLLM,
            Fp8MoeBackend.FLASHINFER_CUTLASS,
        ]:
            raise NotImplementedError(
                "FlashInfer FP8 MoE backend does not support dynamic per token "
                "activation quantization."
            )

        self.kernel: mk.FusedMoEModularKernel | None = None

    def create_weights(
        self,
        layer: Module,
@@ -907,148 +786,43 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def _convert_weights_to_kernel_format(
    def _setup_kernel(
        self,
        layer: Module,
        w13_weight: torch.Tensor,
        w2_weight: torch.Tensor,
        w13_weight_scale: torch.Tensor,
        w2_weight_scale: torch.Tensor,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            assert self.block_quant
            w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block(
                wq=w13_weight,
                ws=w13_weight_scale,
                quant_block_shape=tuple(layer.weight_block_size),
                use_e8m0=is_deep_gemm_e8m0_used(),
            )
            w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block(
                wq=w2_weight,
                ws=w2_weight_scale,
                quant_block_shape=tuple(layer.weight_block_size),
                use_e8m0=is_deep_gemm_e8m0_used(),
            )
        elif self.fp8_backend == Fp8MoeBackend.AITER:
            w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
                w13_weight, w2_weight
            )
        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
            (
                workspace,
                w13_weight,
                w2_weight,
                w13_weight_scale,
                w2_weight_scale,
            ) = prepare_moe_fp8_layer_for_marlin(
                layer,
                w13_weight,
                w2_weight,
                w13_weight_scale,
                w2_weight_scale,
                input_dtype=self.marlin_input_dtype,
            )
            layer.workspace = workspace

        elif self.fp8_backend in [
            Fp8MoeBackend.FLASHINFER_CUTLASS,
            Fp8MoeBackend.FLASHINFER_TRTLLM,
        ]:
            w13_weight = swap_w13_to_w31(w13_weight)
            if self.block_quant:
                w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
            else:
                if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
                    rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
                    register_scales_for_trtllm_fp8_per_tensor_moe(
                        layer=layer,
                        w13_weight_scale=w13_weight,
                        w13_input_scale=w13_input_scale,
                        w2_weight_scale=w2_weight,
                        w2_input_scale=w2_input_scale,
                    )

        elif self.fp8_backend == Fp8MoeBackend.AITER:
            w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
                w13_weight, w2_weight
            )
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)

    def _setup_kernel(self, layer: Module) -> None:
        """Setup Modular Kernel for TP Case"""
        # NOTE(rob): this is a WIP refactor. We are first migrating
        # all of the kernels in the TP case to use mk. Once this is
        # done, then we will initialzie the TP case and DP/EP case
        # via the same code path (i.e. via maybe_init_modular_kernel).
        # NOTE(rob): in progress migrating all into this format.

        from vllm.model_executor.layers.fused_moe import (
            TritonOrDeepGemmExperts,
        )
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )
        from vllm.model_executor.layers.fused_moe.prepare_finalize import (
            MoEPrepareAndFinalizeNoEP,
        )
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        # Flashinfer TRTLLM does not use the modular kernel abstraction.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Setup modular kernel for TP case.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.moe_quant_config is not None
        self.use_inplace = True

        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.kernel = mk.FusedMoEModularKernel(
                # TODO: make defer_input_quant an attr of the FlashInferExperts
                MoEPrepareAndFinalizeNoEP(defer_input_quant=self.block_quant),
                FlashInferExperts(
                    out_dtype=layer.orig_dtype,
                    quant_config=self.moe_quant_config,
                    ep_rank=self.moe.ep_rank,
                    ep_size=self.moe.ep_size,
                    tp_rank=self.moe.tp_rank,
                    tp_size=self.moe.tp_size,
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
                ),
            )
            self.use_inplace = False

        elif self.fp8_backend == Fp8MoeBackend.AITER:
            self.kernel = mk.FusedMoEModularKernel(
                # TODO: make defer_input_quant an attr of the AiterExperts
                MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
                AiterExperts(quant_config=self.moe_quant_config),
            )
        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
            self.kernel = mk.FusedMoEModularKernel(
                MoEPrepareAndFinalizeNoEP(),
                MarlinExperts(quant_config=self.moe_quant_config),
            )
        else:
            self.kernel = mk.FusedMoEModularKernel(
                MoEPrepareAndFinalizeNoEP(),
                TritonOrDeepGemmExperts(
                    quant_config=self.moe_quant_config,
                    allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
                ),
        if self.moe_quant_config:
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                layer=layer,
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
@@ -1056,78 +830,58 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            return

        # Allow for accessing weights and scales in standard way.
        w13_weight = layer.w13_weight
        w2_weight = layer.w2_weight
        w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13_weight, w13_weight_scale, w13_input_scale = (
                normalize_e4m3fn_to_e4m3fnuz(
                    w13_weight, w13_weight_scale, w13_input_scale
                )
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2_weight, w2_weight_scale, w2_input_scale
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer."
                )
            replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
            replace_parameter(layer, "w2_input_scale", w2_input_scale.max())
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        w13_weight[expert_id][start : start + shard_size, :],
                        w13_weight_scale[expert_id][shard_id],
                    )
                    w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size
            w13_weight_scale = max_w13_scales
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights into the runtime format.
        self._convert_weights_to_kernel_format(
            layer=layer,
            w13_weight=w13_weight,
            w2_weight=w2_weight,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

        # Setup modular kernel for TP case.
        self._setup_kernel(layer)

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        if (
            self.fp8_backend == Fp8MoeBackend.AITER
            or self.fp8_backend == Fp8MoeBackend.MARLIN
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
        if self.fp8_backend in [
            Fp8MoeBackend.AITER,
            Fp8MoeBackend.MARLIN,
            Fp8MoeBackend.FLASHINFER_TRTLLM,
        ]:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
@@ -1184,7 +938,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
@@ -1193,17 +947,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
        elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
            return TritonOrDeepGemmExperts(self.moe_quant_config)
        else:
            assert self.fp8_backend == Fp8MoeBackend.TRITON
            logger.debug(
                "TritonExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonExperts(self.moe_quant_config)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
@@ -1212,42 +972,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        # MARLIN uses mixed precision W8A16 config.
        if self.fp8_backend == Fp8MoeBackend.MARLIN:
            return fp8_w8a16_moe_quant_config(
                w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
                w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
                block_shape=self.weight_block_size,
            )

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        # Flashinfer CUTLASS per-tensor uses single dq scale
        # (alpha = w_scale * a_scale) and inverse a2 scale.
        if (
            self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
            and not self.block_quant
        ):
            g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
                w1_scale,
                a1_scale,
                w2_scale,
                a2_scale,
            )
            return fp8_w8a8_moe_quant_config(
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=(1.0 / a2_scale),
                g1_alphas=g1_alphas,
                g2_alphas=g2_alphas,
            )

        # All other backends use normal config.
        return fp8_w8a8_moe_quant_config(
        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
@@ -1269,7 +1000,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
@@ -1308,10 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                result = apply_flashinfer_per_tensor_scale_fp8(
                result = apply_fi_trtllm_fp8_per_tensor_moe(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
@@ -1327,6 +1055,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                hidden_states=x,
                router_logits=router_logits,
            )

        assert self.kernel is not None
        result = self.kernel(
            x,
            layer.w13_weight,
@@ -1358,7 +1088,6 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
@@ -1447,6 +1176,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        layer.w13_input_scale = None
        layer.w2_input_scale = None
@@ -1457,33 +1188,30 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):

        # If checkpoint is fp16, quantize in place.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight[expert, :, :])
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight[expert, :, :])
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Shuffle weights into the runtime format.
        self._convert_weights_to_kernel_format(
            layer=layer,
            w13_weight=w13_weight,
            w2_weight=w2_weight,
            w13_weight_scale=layer.w13_weight_scale,
            w2_weight_scale=layer.w2_weight_scale,
            w13_input_scale=None,
            w2_input_scale=None,
        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )

        # Setup modular kernel for TP case.
        self._setup_kernel(layer)


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
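For readers skimming the diff: the backend-selection logic deleted here (the old `get_fp8_moe_backend`) is what the new `select_fp8_moe_backend` oracle in `vllm/model_executor/layers/fused_moe/oracle/fp8.py` now covers. A minimal, self-contained sketch of that decision order, using hypothetical stand-in flags instead of vLLM's platform and env-var checks, and collapsing the FlashInfer TRTLLM/CUTLASS split into a single branch:

```python
from enum import Enum


class Fp8MoeBackend(Enum):
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    DEEPGEMM = 3
    MARLIN = 4
    TRITON = 5
    AITER = 6


def select_backend_sketch(
    block_quant: bool,
    tp_size: int,
    with_lora_support: bool,
    has_native_fp8: bool = True,
    flashinfer_enabled: bool = False,
    deep_gemm_available: bool = False,
    rocm_aiter_enabled: bool = False,
) -> Fp8MoeBackend:
    """Mirror the priority order visible in the removed get_fp8_moe_backend:
    LoRA forces Triton, FlashInfer wins on supported GPUs, Marlin covers
    pre-FP8 hardware, DeepGEMM handles block-quantized weights (skipped by
    default when TP >= 8), AITER covers ROCm, and Triton is the fallback."""
    if with_lora_support:
        return Fp8MoeBackend.TRITON
    if flashinfer_enabled:
        return Fp8MoeBackend.FLASHINFER_CUTLASS
    if not has_native_fp8:
        return Fp8MoeBackend.MARLIN
    if block_quant and deep_gemm_available and tp_size < 8:
        return Fp8MoeBackend.DEEPGEMM
    if rocm_aiter_enabled:
        return Fp8MoeBackend.AITER
    return Fp8MoeBackend.TRITON


# Example: block-quantized weights, TP=4, DeepGEMM installed -> DEEPGEMM.
print(select_backend_sketch(block_quant=True, tp_size=4,
                            with_lora_support=False,
                            deep_gemm_available=True))
```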
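Similarly, the removed `get_fused_moe_quant_config` branches show the three config shapes the new `make_fp8_moe_quant_config` helper has to cover: Marlin gets a weight-only W8A16 config, FlashInfer CUTLASS per-tensor fuses weight and activation scales into per-GEMM alphas and inverts the second activation scale, and every other backend gets plain W8A8 scales. A hedged, self-contained sketch of that dispatch, using stand-in enums and dicts rather than vLLM's `FusedMoEQuantConfig`:

```python
from enum import Enum, auto


class Backend(Enum):
    MARLIN = auto()
    FLASHINFER_CUTLASS = auto()
    OTHER = auto()


def quant_config_sketch(
    backend: Backend,
    block_quant: bool,
    w1_scale: float,
    w2_scale: float,
    a1_scale: float,
    a2_scale: float,
) -> dict:
    # Marlin runs weight-only FP8 (W8A16): activation scales are unused.
    if backend is Backend.MARLIN:
        return {"scheme": "w8a16", "w1_scale": w1_scale, "w2_scale": w2_scale}
    # FlashInfer CUTLASS per-tensor folds weight * activation scales into
    # per-GEMM alphas and expects the second activation scale inverted.
    if backend is Backend.FLASHINFER_CUTLASS and not block_quant:
        return {
            "scheme": "w8a8",
            "g1_alpha": w1_scale * a1_scale,
            "g2_alpha": w2_scale * a2_scale,
            "a1_scale": a1_scale,
            "a2_scale": 1.0 / a2_scale,
        }
    # Everything else: plain W8A8 weight and activation scales.
    return {
        "scheme": "w8a8",
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
    }


print(quant_config_sketch(Backend.FLASHINFER_CUTLASS, block_quant=False,
                          w1_scale=0.02, w2_scale=0.03,
                          a1_scale=0.5, a2_scale=0.25))
```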