[Quant][Feature] Support online MXFP8 quantization for MoE and dense models (#35448)
Signed-off-by: EdalatiAli <aliedalati@cohere.com>
This commit is contained in:
104
tests/models/quantization/test_mxfp8.py
Normal file
104
tests/models/quantization/test_mxfp8.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""E2E tests for online MXFP8 quantization.
|
||||
|
||||
Loads a BF16 model with ``--quantization mxfp8`` (online quantization) and
|
||||
compares log-probabilities against the same model served in BF16 without
|
||||
quantization. This exercises the full pipeline: config parsing,
|
||||
``Mxfp8OnlineLinearMethod``, ``Mxfp8OnlineMoEMethod``, weight loading,
|
||||
online quantization / shuffling, and inference through ``apply_monolithic``.
|
||||
|
||||
Layer skipping (``modules_to_not_convert``) is configured in the model's
|
||||
``config.json`` under ``quantization_config`` and is not tested here.
|
||||
|
||||
``example_prompts`` is a pytest fixture (from conftest.py) that loads 8
|
||||
diverse prompts from ``tests/prompts/example.txt``.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
from ..utils import check_logprobs_close
|
||||
|
||||
# A small MoE model that fits on a single GPU and has both linear + MoE layers.
|
||||
MOE_MODEL = "Qwen/Qwen3-30B-A3B"
|
||||
# A small dense model (no MoE) to validate the linear-only path.
|
||||
DENSE_MODEL = "Qwen/Qwen3-0.6B"
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
MAX_TOKENS = 4
|
||||
NUM_LOG_PROBS = 8
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_quant_method_supported("mxfp8"),
    reason="mxfp8 is not supported on this GPU type (requires sm_100+).",
)
@pytest.mark.quant_model
@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"])
def test_mxfp8_logprobs(
    vllm_runner,
    example_prompts,
    model: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that online MXFP8 logprobs stay close to the unquantized run.

    The same checkpoint is loaded twice with identical settings: first
    without a ``quantization`` argument (reported as "bf16"), then with
    ``quantization="mxfp8"``. Each run greedily generates ``MAX_TOKENS``
    tokens per prompt and collects ``NUM_LOG_PROBS`` log-probabilities; the
    two result lists are handed to ``check_logprobs_close``.
    """

    def _greedy_logprobs(**runner_overrides):
        # Both runs share every setting except the overrides passed in, so
        # any difference in the outputs comes from quantization alone.
        with vllm_runner(
            model,
            max_model_len=MAX_MODEL_LEN,
            enforce_eager=True,
            **runner_overrides,
        ) as vllm_model:
            return vllm_model.generate_greedy_logprobs(
                example_prompts, MAX_TOKENS, NUM_LOG_PROBS
            )

    with monkeypatch.context() as m:
        m.setenv("TOKENIZERS_PARALLELISM", "true")

        # Baseline first: the runner is closed before the quantized model
        # is constructed.
        baseline_outputs = _greedy_logprobs()
        test_outputs = _greedy_logprobs(quantization="mxfp8")

        check_logprobs_close(
            outputs_0_lst=baseline_outputs,
            outputs_1_lst=test_outputs,
            name_0="bf16",
            name_1="mxfp8",
        )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_quant_method_supported("mxfp8"),
    reason="mxfp8 is not supported on this GPU type (requires sm_100+).",
)
@pytest.mark.quant_model
@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"])
def test_mxfp8_generation(vllm_runner, model: str) -> None:
    """Smoke test: the online MXFP8 model produces output beyond the prompt.

    Only the length of the returned text is checked; the content of the
    generation is not inspected.
    """
    prompt = "1 2 3 4 5"
    runner_kwargs = {
        "enforce_eager": True,
        "quantization": "mxfp8",
        "max_model_len": MAX_MODEL_LEN,
    }
    with vllm_runner(model, **runner_kwargs) as vllm_model:
        results = vllm_model.generate_greedy([prompt], max_tokens=5)

    # NOTE(review): element [1] of each result is presumably the prompt plus
    # the completion as a single string -- confirm against the runner helper.
    generated = results[0][1]
    assert len(generated) > len(prompt), (
        f"MXFP8 model produced no new tokens. Output: {generated!r}"
    )
|
||||
@@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp8Dynamic,
|
||||
kMxfp8Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -67,11 +69,54 @@ class TrtLlmFp8ExpertsBase:
|
||||
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
|
||||
return True
|
||||
|
||||
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Return whether the (weight, activation) quantization pair is supported.

    Accepted pairs are Fp8 block, Fp8 per-tensor, and MXFP8.
    """
    requested = (weight_key, activation_key)
    supported_pairs = [
        (kFp8Static128BlockSym, kFp8Dynamic128Sym),
        (kFp8StaticTensorSym, kFp8StaticTensorSym),
        (kMxfp8Static, kMxfp8Dynamic),
    ]
    return requested in supported_pairs
|
||||
|
||||
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
    """Return True only for SiLU and the non-gated RELU^2 activation."""
    allowed_activations = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
    return activation in allowed_activations
|
||||
|
||||
@staticmethod
def _supports_routing_method(
    routing_method: RoutingMethodType,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Report whether the routing method is allowed for the given scheme.

    Monolithic kernels need to express router support, so the accepted
    routing methods depend on the (weight, activation) quantization pair.

    Raises:
        ValueError: If the quantization pair is none of Fp8 block, MXFP8,
            or Fp8 per-tensor.
    """
    scheme = (weight_key, activation_key)

    # NOTE(dbari): TopK routing could also be enabled, but need to validate models
    # NOTE(dbari): Default is not implemented and should not be enabled until it is
    block_scaled_schemes = [
        (kFp8Static128BlockSym, kFp8Dynamic128Sym),
        (kMxfp8Static, kMxfp8Dynamic),
    ]
    if scheme in block_scaled_schemes:
        # NOTE(rob): potentially allow others here. This is a conservative list.
        return routing_method in [
            RoutingMethodType.DeepSeekV3,
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        ]

    if scheme == (kFp8StaticTensorSym, kFp8StaticTensorSym):
        # NOTE(dbari): as above, potentially allow others here.
        return routing_method in [
            RoutingMethodType.DeepSeekV3,
            RoutingMethodType.Llama4,
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        ]

    raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""Monolithic kernel so only use with naive DP/EP and TP."""
|
||||
@@ -113,9 +158,10 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 block."""
|
||||
"""Supports Fp8 block and MXFP8."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kMxfp8Static, kMxfp8Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@@ -159,6 +205,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
import flashinfer
|
||||
from flashinfer.fused_moe import Fp8QuantizationType
|
||||
|
||||
# Pack topk_ids and topk_weights into single tensor
|
||||
# Format: (expert_id << 16) | (weight_bf16.view(int16))
|
||||
@@ -175,6 +222,16 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
|
||||
assert a1q_scale is not None
|
||||
|
||||
is_mxfp8 = self.quant_config.block_shape == [1, 32]
|
||||
if is_mxfp8:
|
||||
fp8_quant_type = Fp8QuantizationType.MxFp8
|
||||
use_shuffled_weight = True
|
||||
hidden_states_scale = a1q_scale
|
||||
else:
|
||||
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
|
||||
use_shuffled_weight = False
|
||||
hidden_states_scale = a1q_scale.t().contiguous()
|
||||
|
||||
# `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
|
||||
# output tensor in-place so we need to manually copy the result to the
|
||||
# output tensor
|
||||
@@ -183,7 +240,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
topk_ids=packed_topk_ids,
|
||||
routing_bias=None,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale.t().contiguous(), # type: ignore[union-attr]
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale,
|
||||
gemm2_weights=w2,
|
||||
@@ -197,8 +254,9 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=1,
|
||||
use_shuffled_weight=False,
|
||||
use_shuffled_weight=use_shuffled_weight,
|
||||
weight_layout=0,
|
||||
fp8_quantization_type=fp8_quant_type,
|
||||
# output=output,
|
||||
)
|
||||
output.copy_(result)
|
||||
@@ -240,10 +298,11 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor and Fp8 block."""
|
||||
"""Supports Fp8 per-tensor, Fp8 block, and MXFP8."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
(kMxfp8Static, kMxfp8Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@@ -256,7 +315,10 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
if (weight_key, activation_key) in [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kMxfp8Static, kMxfp8Dynamic),
|
||||
]:
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
@@ -274,7 +336,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
def _apply_per_block(
|
||||
def _apply_block_scale(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -291,32 +353,38 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Delay import for non-CUDA.
|
||||
import flashinfer
|
||||
from flashinfer.fused_moe import Fp8QuantizationType
|
||||
|
||||
assert not apply_router_weight_on_input
|
||||
assert activation == MoEActivation.SILU
|
||||
assert self.topk <= global_num_experts
|
||||
assert self.topk <= 10
|
||||
assert global_num_experts % 4 == 0
|
||||
assert self.quant_config.block_shape in [[128, 128], [1, 32]]
|
||||
# Kernel expects #experts <= #threads 512
|
||||
assert global_num_experts <= 512
|
||||
# TODO: fuse into the quant kernel.
|
||||
assert a1q_scale is not None
|
||||
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
assert self.topk <= global_num_experts
|
||||
assert self.topk <= 10
|
||||
assert global_num_experts % 4 == 0
|
||||
assert self.quant_config.block_shape == [128, 128]
|
||||
# Routing kernel expects #experts <= #threads 512
|
||||
assert global_num_experts <= 512
|
||||
|
||||
# Kernel requires transposed hidden state scales
|
||||
# TODO: fuse into the quant kernel.
|
||||
assert a1q_scale is not None
|
||||
a1q_scale_t = a1q_scale.t().contiguous()
|
||||
is_mxfp8 = self.quant_config.block_shape == [1, 32]
|
||||
if is_mxfp8:
|
||||
fp8_quant_type = Fp8QuantizationType.MxFp8
|
||||
use_shuffled_weight = True
|
||||
hidden_states_scale = a1q_scale
|
||||
else:
|
||||
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
|
||||
use_shuffled_weight = False
|
||||
hidden_states_scale = a1q_scale.t().contiguous()
|
||||
|
||||
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale_t,
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale,
|
||||
gemm2_weights=w2,
|
||||
@@ -330,7 +398,8 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
routing_method_type=self.routing_method_type,
|
||||
use_shuffled_weight=False,
|
||||
use_shuffled_weight=use_shuffled_weight,
|
||||
fp8_quantization_type=fp8_quant_type,
|
||||
)
|
||||
|
||||
def _apply_per_tensor(
|
||||
@@ -409,7 +478,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.block_shape is not None:
|
||||
return self._apply_per_block(
|
||||
return self._apply_block_scale(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
@@ -441,6 +510,6 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only per-block and per-tensor quantization are supported in "
|
||||
f"{self.__class__.__name__}."
|
||||
"Only per-block, per-tensor, and MXFP8 quantization are "
|
||||
f"supported in {self.__class__.__name__}."
|
||||
)
|
||||
|
||||
@@ -444,7 +444,7 @@ def convert_to_fp8_moe_kernel_format(
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
]:
|
||||
w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
|
||||
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_fi(
|
||||
layer=layer,
|
||||
w13=w13,
|
||||
w2=w2,
|
||||
@@ -512,6 +512,21 @@ def make_fp8_moe_quant_config(
|
||||
g1_alphas=(w1_scale * a1_scale).squeeze(),
|
||||
g2_alphas=(w2_scale * a2_scale).squeeze(),
|
||||
)
|
||||
# MXFP8 uses "mxfp8" quant_dtype so the prepare step dispatches to
|
||||
# _mxfp8_e4m3_quantize rather than standard FP8 block quantization.
|
||||
# Non-swizzled layout is required since the TRTLLM kernel expects
|
||||
# scales in (num_tokens, hidden_dim // 32) format.
|
||||
if block_shape == [1, 32]:
|
||||
return FusedMoEQuantConfig.make(
|
||||
"mxfp8",
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
is_nvfp4_scale_swizzled=False,
|
||||
)
|
||||
|
||||
# All other backends use normal config.
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
|
||||
@@ -1,44 +1,87 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
backend_to_kernel_cls,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kMxfp8Dynamic,
|
||||
kMxfp8Static,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_SUPPORTED_BACKENDS: frozenset[Fp8MoeBackend] = frozenset(
|
||||
{
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
}
|
||||
)
|
||||
|
||||
class MxFp8MoeBackend(Enum):
|
||||
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
|
||||
_BACKEND_NAME_MAP: dict[str, Fp8MoeBackend] = {
|
||||
"flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
}
|
||||
|
||||
|
||||
def _select_kernel_cls(
    backend: Fp8MoeBackend,
    config: FusedMoEConfig,
) -> type[mk.FusedMoEExperts]:
    """Return the first expert class of ``backend`` that supports MXFP8.

    Candidates come from ``backend_to_kernel_cls(backend)`` and are probed
    in order with ``is_supported_config`` using the MXFP8 weight and
    activation keys.

    Raises:
        ValueError: If no candidate reports support. The message carries
            the reason given by the last rejected candidate (``None`` when
            there were no candidates).
    """
    if config.moe_parallel_config.use_batched_activation_format:
        activation_format = mk.FusedMoEActivationFormat.BatchedExperts
    else:
        activation_format = mk.FusedMoEActivationFormat.Standard

    rejection_reason: str | None = None
    for experts_cls in backend_to_kernel_cls(backend):
        # is_supported_config is invoked through the class with the class
        # itself passed as the first argument, as in the original call.
        is_ok, why_not = experts_cls.is_supported_config(
            experts_cls,
            config,
            kMxfp8Static,
            kMxfp8Dynamic,
            activation_format,
        )
        if not is_ok:
            rejection_reason = why_not
            continue
        return experts_cls

    raise ValueError(
        f"No supported MXFP8 expert class for {backend.value}: {rejection_reason}"
    )
|
||||
|
||||
|
||||
def select_mxfp8_moe_backend(
|
||||
config: FusedMoEConfig,
|
||||
) -> MxFp8MoeBackend:
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
"""Select the MXFP8 MoE backend and the best expert class.
|
||||
|
||||
Returns:
|
||||
A tuple of (fp8_backend, experts_cls).
|
||||
"""
|
||||
if config.is_lora_enabled:
|
||||
raise NotImplementedError("LoRA is not supported for MXFP8 MoE.")
|
||||
|
||||
AVAILABLE_BACKENDS = [
|
||||
MxFp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
]
|
||||
|
||||
runner_backend = config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
mapping = {
|
||||
"flashinfer_trtllm": MxFp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
}
|
||||
if backend := mapping.get(runner_backend):
|
||||
logger.info_once(
|
||||
"Using '%s' MxFp8 MoE backend (user-requested).",
|
||||
backend.value,
|
||||
backend = _BACKEND_NAME_MAP.get(runner_backend)
|
||||
if backend is None:
|
||||
raise ValueError(
|
||||
f"moe_backend='{runner_backend}' is not supported for "
|
||||
f"MXFP8 MoE. Expected one of "
|
||||
f"{list(_BACKEND_NAME_MAP.keys())}."
|
||||
)
|
||||
return backend
|
||||
raise ValueError(
|
||||
f"moe_backend='{runner_backend}' is not supported for MXFP8 MoE. "
|
||||
f"Expected one of {list(mapping.keys())}."
|
||||
logger.info_once(
|
||||
"Using '%s' MxFp8 MoE backend (user-requested).",
|
||||
backend.value,
|
||||
)
|
||||
return backend, _select_kernel_cls(backend, config)
|
||||
|
||||
# Auto-select: only one backend available for now.
|
||||
backend = AVAILABLE_BACKENDS[0]
|
||||
logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
|
||||
return backend
|
||||
# Auto-select: pick the first supported backend.
|
||||
for backend in _SUPPORTED_BACKENDS:
|
||||
logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
|
||||
return backend, _select_kernel_cls(backend, config)
|
||||
|
||||
raise ValueError("No MXFP8 MoE backends available.")
|
||||
|
||||
@@ -199,7 +199,7 @@ def _mxfp8_e4m3_quantize(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert A_scale is None
|
||||
assert not per_act_token_quant
|
||||
assert block_shape is None
|
||||
assert block_shape is None or block_shape == [1, 32]
|
||||
return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout)
|
||||
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ QuantizationMethods = Literal[
|
||||
"torchao",
|
||||
"inc",
|
||||
"mxfp4",
|
||||
"mxfp8",
|
||||
"petit_nvfp4",
|
||||
"cpu_awq",
|
||||
]
|
||||
@@ -129,6 +130,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
)
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .mxfp4 import Mxfp4Config
|
||||
from .mxfp8 import Mxfp8Config
|
||||
from .petit import PetitNvFp4Config
|
||||
from .ptpc_fp8 import PTPCFp8Config
|
||||
from .torchao import TorchAOConfig
|
||||
@@ -156,6 +158,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"auto-round": INCConfig,
|
||||
"inc": INCConfig,
|
||||
"mxfp4": Mxfp4Config,
|
||||
"mxfp8": Mxfp8Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
"cpu_awq": CPUAWQConfig,
|
||||
}
|
||||
|
||||
@@ -25,13 +25,13 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
|
||||
MxFp8MoeBackend,
|
||||
select_mxfp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
@@ -1712,8 +1712,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
|
||||
self.quant_config = quant_config
|
||||
assert self.quant_config.is_checkpoint_mxfp8_serialized
|
||||
|
||||
# Select MXFP8 MoE backend
|
||||
self.mxfp8_backend = select_mxfp8_moe_backend(self.moe)
|
||||
self.mxfp8_backend, _ = select_mxfp8_moe_backend(self.moe)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -1943,7 +1942,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM
|
||||
return self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
@@ -1956,7 +1955,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
|
||||
Fp8QuantizationType,
|
||||
)
|
||||
|
||||
assert self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM
|
||||
assert self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
|
||||
354
vllm/model_executor/layers/quantization/mxfp8.py
Normal file
354
vllm/model_executor/layers/quantization/mxfp8.py
Normal file
@@ -0,0 +1,354 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Online MXFP8 (microscaling FP8, block-32) quantization config and methods."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
|
||||
select_mxfp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8Config,
|
||||
Fp8KVCacheMethod,
|
||||
Fp8OnlineLinearMethod,
|
||||
Fp8OnlineMoEMethod,
|
||||
_copy_missing_attrs,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
MXFP8_BLOCK_SIZE,
|
||||
Mxfp8LinearBackend,
|
||||
Mxfp8LinearOp,
|
||||
mxfp8_e4m3_quantize,
|
||||
swizzle_mxfp8_scale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
initialize_single_dummy_weight,
|
||||
)
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Mxfp8Config(Fp8Config):
    """Quantization config for online MXFP8.

    Linear layers are routed to ``Mxfp8OnlineLinearMethod``, fused-MoE
    layers to ``Mxfp8OnlineMoEMethod`` and attention layers to
    ``Fp8KVCacheMethod``. Layers matched by ``ignored_layers`` fall back to
    their unquantized methods.
    """

    def __init__(
        self,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
    ) -> None:
        """Build the config.

        Args:
            activation_scheme: Must be ``"dynamic"``.
            ignored_layers: Layer names to leave unquantized, if any.

        Raises:
            ValueError: If ``activation_scheme`` is not ``"dynamic"``.
        """
        if activation_scheme != "dynamic":
            raise ValueError("mxfp8 only supports dynamic activation scheme.")

        # The parent is configured as a non-serialized (online) FP8 config
        # with no weight block size.
        super().__init__(
            is_checkpoint_fp8_serialized=False,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=None,
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "mxfp8"

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value required (100)."""
        return 100

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config":
        """Create the config from a ``quantization_config`` dictionary.

        ``activation_scheme`` defaults to ``"dynamic"``. The ignored-layer
        list is read from ``ignored_layers``; when that is missing or empty,
        ``modules_to_not_convert`` is used instead.
        """
        scheme = cls.get_from_keys_or(config, ["activation_scheme"], "dynamic")
        skip_list = cls.get_from_keys_or(
            config, ["ignored_layers"], None
        ) or cls.get_from_keys_or(config, ["modules_to_not_convert"], None)
        return cls(
            activation_scheme=scheme,
            ignored_layers=skip_list,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quantize method for ``layer``.

        Returns ``None`` for layer types other than linear, fused MoE and
        attention.
        """

        def _is_skipped() -> bool:
            # Evaluated lazily: only linear and MoE layers consult the
            # ignore list.
            return is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
                skip_with_substr=True,
            )

        if isinstance(layer, LinearBase):
            if _is_skipped():
                return UnquantizedLinearMethod()
            return Mxfp8OnlineLinearMethod(self)

        if isinstance(layer, FusedMoE):
            if _is_skipped():
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return Mxfp8OnlineMoEMethod(self, layer)

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)

        return None
|
||||
|
||||
|
||||
class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
    """Linear method that quantizes weights to MXFP8 at load time.

    Loads bf16/fp16 checkpoints and converts the weights to MXFP8
    (microscaling FP8 with block-32 scales) in
    ``process_weights_after_loading``. The forward pass is delegated to an
    ``Mxfp8LinearOp``.

    Args:
        quant_config: The MXFP8 quantization config.
    """

    # NOTE(review): presumably signals that weights may be created on the
    # meta device and materialized later -- confirm against the base class.
    uses_meta_device: bool = True

    def __init__(self, quant_config: "Mxfp8Config"):
        self.quant_config = quant_config
        # Output dtype is captured from torch's default dtype at construction.
        self.out_dtype = torch.get_default_dtype()
        self.mxfp8_linear = Mxfp8LinearOp(self._select_backend())
        logger.info_once(
            "Using %s backend for MXFP8 GEMM", self.mxfp8_linear.backend.value
        )

    @staticmethod
    def _select_backend() -> Mxfp8LinearBackend:
        """Choose the GEMM backend.

        Returns the FlashInfer CUTLASS backend when ``mm_mxfp8`` can be
        resolved from ``vllm.utils.flashinfer``; otherwise logs a warning
        and returns the emulation backend.
        """
        try:
            from vllm.utils import flashinfer as fi

            # Attribute access is the availability probe.
            _ = fi.mm_mxfp8
        except Exception:
            logger.warning(
                "FlashInfer mm_mxfp8 not available, "
                "falling back to MXFP8 emulation backend."
            )
            return Mxfp8LinearBackend.EMULATION
        return Mxfp8LinearBackend.FLASHINFER_CUTLASS

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate the block-size constraint, then defer to the parent.

        Raises:
            ValueError: If ``input_size_per_partition`` is not a multiple of
                ``MXFP8_BLOCK_SIZE``.
        """
        remainder = input_size_per_partition % MXFP8_BLOCK_SIZE
        if remainder != 0:
            raise ValueError(
                f"MXFP8 requires input_size_per_partition "
                f"({input_size_per_partition}) to be divisible by "
                f"{MXFP8_BLOCK_SIZE}."
            )

        super().create_weights(
            layer,
            input_size_per_partition,
            output_partition_sizes,
            input_size,
            output_size,
            params_dtype,
            **extra_weight_attrs,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize ``layer.weight`` to MXFP8 and install weight and scale.

        Does nothing when the layer has already been processed.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        on_meta_device = layer.weight.device == torch.device("meta")
        if on_meta_device:
            # The weight was never materialized: allocate it on the load
            # device, carry over the loader and remaining attributes, and
            # fill it with dummy values.
            materialized = ModelWeightParameter(
                data=torch.empty_like(layer.weight, device=layer._load_device),
                input_dim=1,
                output_dim=0,
                weight_loader=layer.weight.weight_loader,
            )
            _copy_missing_attrs(layer.weight, materialized)
            layer.register_parameter("weight", materialized)
            initialize_single_dummy_weight(layer.weight)

        quantized_weight, block_scales = mxfp8_e4m3_quantize(
            layer.weight.contiguous()
        )

        if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
            # Only the FlashInfer CUTLASS backend gets swizzled scales.
            # Dimensions are read before the weight is replaced below.
            dim0 = layer.weight.shape[0]
            dim1 = layer.weight.shape[1]
            block_scales = swizzle_mxfp8_scale(block_scales, dim0, dim1)

        layer.input_scale = None
        replace_parameter(layer, "weight", quantized_weight.data)
        replace_parameter(layer, "weight_scale", block_scales.data)

        layer._already_called_process_weights_after_loading = True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MXFP8 linear op on ``x`` with the layer's weight and scale."""
        linear_op = self.mxfp8_linear
        return linear_op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            bias=bias,
        )
|
||||
|
||||
|
||||
class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
|
||||
"""MoE method for online MXFP8 (block) quantization."""
|
||||
|
||||
uses_meta_device: bool = True
|
||||
|
||||
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
    """Set up the MXFP8 MoE method for ``layer``.

    ``FusedMoEMethodBase.__init__`` is called directly with the layer's
    MoE config; the immediate parent's ``__init__`` is not invoked.
    """
    FusedMoEMethodBase.__init__(self, layer.moe_config)
    self.quant_config = quant_config
    # Online quantization only: the checkpoint must not be FP8-serialized
    # and activations must be dynamically scaled.
    assert not quant_config.is_checkpoint_fp8_serialized
    assert quant_config.activation_scheme == "dynamic"

    self.weight_block_size = [1, MXFP8_BLOCK_SIZE]
    self.block_quant = True
    self.weight_scale_name = "weight_scale"

    # NOTE(review): self.moe is presumably set by FusedMoEMethodBase.__init__
    # from layer.moe_config -- confirm.
    selected_backend, selected_experts_cls = select_mxfp8_moe_backend(
        config=self.moe
    )
    self.fp8_backend = selected_backend
    self.experts_cls = selected_experts_cls
|
||||
|
||||
def create_weights(
    self,
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Create the parent's weights plus uint8 MXFP8 block-scale parameters.

    Registers ``w13_weight_scale`` and ``w2_weight_scale`` on ``layer`` and
    sets ``layer.weight_block_size`` to ``[1, MXFP8_BLOCK_SIZE]``.

    Raises:
        ValueError: If ``hidden_size`` or ``intermediate_size_per_partition``
            is not a multiple of ``MXFP8_BLOCK_SIZE``.
    """
    sizes_are_aligned = (
        hidden_size % MXFP8_BLOCK_SIZE == 0
        and intermediate_size_per_partition % MXFP8_BLOCK_SIZE == 0
    )
    if not sizes_are_aligned:
        raise ValueError(
            "Online MXFP8 MoE requires hidden/intermediate sizes divisible "
            f"by {MXFP8_BLOCK_SIZE}."
        )

    super().create_weights(
        layer=layer,
        num_experts=num_experts,
        hidden_size=hidden_size,
        intermediate_size_per_partition=intermediate_size_per_partition,
        params_dtype=params_dtype,
        **extra_weight_attrs,
    )

    def _zero_scale(rows: int, scale_cols: int) -> torch.nn.Parameter:
        # One uint8 scale per MXFP8_BLOCK_SIZE elements along the last dim.
        return torch.nn.Parameter(
            torch.zeros(num_experts, rows, scale_cols, dtype=torch.uint8),
            requires_grad=False,
        )

    scale_params = {
        "w13_weight_scale": _zero_scale(
            2 * intermediate_size_per_partition,
            hidden_size // MXFP8_BLOCK_SIZE,
        ),
        "w2_weight_scale": _zero_scale(
            hidden_size,
            intermediate_size_per_partition // MXFP8_BLOCK_SIZE,
        ),
    }
    # Register both parameters first, then attach the loader attributes,
    # in the same order as before.
    for param_name, param in scale_params.items():
        layer.register_parameter(param_name, param)
    for param in scale_params.values():
        set_weight_attrs(param, extra_weight_attrs)

    layer.weight_block_size = [1, MXFP8_BLOCK_SIZE]
|
||||
|
||||
def _quantize_mxfp8_moe_weight(
    self, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a batch of weights to MXFP8 (fp8 data + uint8 scales).

    Each slice ``weight[i]`` along the leading dimension is quantized
    independently with ``mxfp8_e4m3_quantize`` using the non-swizzled
    scale-factor layout, and the per-slice results are stacked back along
    a new leading dimension.

    Returns:
        ``(quantized_weights, scales)`` as two stacked tensors.
    """
    per_slice = [
        mxfp8_e4m3_quantize(weight[idx], is_sf_swizzled_layout=False)
        for idx in range(weight.size(0))
    ]
    quantized = torch.stack([data for data, _ in per_slice])
    scales = torch.stack([scale for _, scale in per_slice])
    return quantized, scales
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
    """Quantize the loaded MoE weights to MXFP8 and hand them to the kernel.

    Steps:
      1. Return immediately if this layer was already processed (guarded by
         ``layer._already_called_process_weights_after_loading``).
      2. For ``w13_weight`` / ``w2_weight`` still on the ``meta`` device,
         re-register a real parameter on ``layer._load_device`` (carrying
         over ``weight_loader`` and any other missing attributes) and pass
         it to ``initialize_single_dummy_weight``.
      3. Quantize both weights with ``_quantize_mxfp8_moe_weight``.
      4. Call ``_setup_kernel`` with the quantized weights, their scales and
         the layer's input scales, then set the "already processed" flag.

    The quantized tensors and scales come straight from
    ``_quantize_mxfp8_moe_weight``; no output buffers are pre-allocated,
    which avoids a transient extra FP8 copy of every expert weight.
    """
    if getattr(layer, "_already_called_process_weights_after_loading", False):
        return

    meta_device = torch.device("meta")
    for weight_name in ("w13_weight", "w2_weight"):
        placeholder = getattr(layer, weight_name)
        if placeholder.device != meta_device:
            continue
        # The parameter has no storage yet: allocate it on the load device
        # and keep the loader/attributes the placeholder carried.
        materialized = torch.nn.Parameter(
            torch.empty_like(placeholder, device=layer._load_device),
            requires_grad=False,
        )
        set_weight_attrs(
            materialized, {"weight_loader": placeholder.weight_loader}
        )
        _copy_missing_attrs(placeholder, materialized)
        layer.register_parameter(weight_name, materialized)
        initialize_single_dummy_weight(getattr(layer, weight_name))

    w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight)
    w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight)

    self._setup_kernel(
        layer,
        w13,
        w2,
        w13_scale,
        w2_scale,
        layer.w13_input_scale,
        layer.w2_input_scale,
    )

    layer._already_called_process_weights_after_loading = True
|
||||
@@ -305,6 +305,81 @@ def align_fp8_moe_weights_for_fi(
|
||||
return padded_w13, padded_w2, padded_intermediate
|
||||
|
||||
|
||||
def _shuffle_mxfp8_moe_weights(
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    is_gated: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Lay out MXFP8 MoE weights and scales for the FlashInfer TRT-LLM kernel.

    Mirrors flashinfer/tests/moe/test_trtllm_gen_fused_moe.py. Per expert:

    1. ``reorder_rows_for_gated_act_gemm`` on w13 and its scales (gated
       activations only) to interleave the gate/up rows.
    2. ``shuffle_matrix_a`` on the w13 and w2 data, viewed as uint8.
    3. ``shuffle_matrix_sf_a`` on the w13 and w2 scale factors.

    Returns:
        ``(w13, w2, w13_scale, w2_scale)``: the weights stacked over experts
        and viewed as ``torch.float8_e4m3fn``, and the scales stacked and
        reshaped back to the shapes of the input scale tensors.
    """
    from flashinfer import (
        reorder_rows_for_gated_act_gemm,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
    )

    tile_m = 128  # epilogue tile size passed to every shuffle call
    expert_count = w13.shape[0]
    gate_up_rows = 2 * (w13.shape[1] // 2)
    hidden_rows = w13.shape[2]

    w13_parts: list[torch.Tensor] = []
    w2_parts: list[torch.Tensor] = []
    w13_sf_parts: list[torch.Tensor] = []
    w2_sf_parts: list[torch.Tensor] = []

    for expert in range(expert_count):
        w13_e = w13[expert]
        w13_sf_e = w13_scale[expert]
        if is_gated:
            w13_e = reorder_rows_for_gated_act_gemm(
                w13_e.reshape(gate_up_rows, -1)
            )
            w13_sf_e = reorder_rows_for_gated_act_gemm(
                w13_sf_e.reshape(gate_up_rows, -1)
            )

        # The shuffle helpers are fed raw bytes, hence the uint8 views.
        w13_parts.append(shuffle_matrix_a(w13_e.view(torch.uint8), tile_m))
        w2_parts.append(shuffle_matrix_a(w2[expert].view(torch.uint8), tile_m))

        w13_sf_bytes = w13_sf_e.view(torch.uint8).reshape(gate_up_rows, -1)
        w13_sf_parts.append(shuffle_matrix_sf_a(w13_sf_bytes, tile_m))

        w2_sf_bytes = w2_scale[expert].view(torch.uint8).reshape(hidden_rows, -1)
        w2_sf_parts.append(shuffle_matrix_sf_a(w2_sf_bytes, tile_m))

    return (
        torch.stack(w13_parts).view(torch.float8_e4m3fn),
        torch.stack(w2_parts).view(torch.float8_e4m3fn),
        torch.stack(w13_sf_parts).reshape(w13_scale.shape),
        torch.stack(w2_sf_parts).reshape(w2_scale.shape),
    )
|
||||
|
||||
|
||||
def prepare_fp8_moe_layer_for_fi(
|
||||
layer: torch.nn.Module,
|
||||
w13: torch.Tensor,
|
||||
@@ -314,7 +389,7 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor | None,
|
||||
is_trtllm: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Convert Fp8 MoE weights to flashinfer kernel format
|
||||
|
||||
@@ -329,10 +404,33 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
block_quant = (
|
||||
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
|
||||
)
|
||||
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
|
||||
is_gated = layer.activation.is_gated
|
||||
|
||||
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
|
||||
if is_mxfp8 and is_trtllm:
|
||||
# FlashInfer TRT-LLM SwiGLU expects [up; gate] but vLLM stores
|
||||
# [gate; up]. Swap both weights and scales before interleaving.
|
||||
if layer.moe_config.is_act_and_mul:
|
||||
w13 = swap_w13_to_w31(w13)
|
||||
# Scales may be 2D [E, flat] from _quantize_mxfp8_moe_weight;
|
||||
# reshape to 3D so swap_w13_to_w31 can flip the two halves,
|
||||
# then flatten back.
|
||||
if w13_scale.ndim == 2:
|
||||
num_rows = w13.shape[1] # 2 * intermediate_size
|
||||
w13_scale = w13_scale.reshape(w13_scale.shape[0], num_rows, -1)
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
w13_scale = w13_scale.reshape(w13_scale.shape[0], -1)
|
||||
else:
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
|
||||
w13, w2, w13_scale, w2_scale = _shuffle_mxfp8_moe_weights(
|
||||
w13, w2, w13_scale, w2_scale, is_gated
|
||||
)
|
||||
return w13, w2, w13_scale, w2_scale
|
||||
|
||||
# Some FI MoE kernels require internal alignment of 16
|
||||
# for the gate-up proj. Pad the weights to respect this.
|
||||
is_gated = layer.activation.is_gated
|
||||
if not block_quant:
|
||||
min_alignment = 16 if is_gated else 128
|
||||
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
|
||||
@@ -369,4 +467,4 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
w13_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
|
||||
w2_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
|
||||
|
||||
return w13, w2, w13_scale
|
||||
return w13, w2, w13_scale, w2_scale
|
||||
|
||||
@@ -149,6 +149,12 @@ kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)
|
||||
kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128))
|
||||
kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True)
|
||||
|
||||
# MXFP8 quantization keys. Both variants pair FP8_DTYPE values with uint8
# scales over GroupShape(1, 32) groups; they differ only in ScaleDesc's
# second positional argument (True vs. False).
# NOTE(review): that argument is presumably the static/dynamic flag, as the
# constant names suggest — confirm against the ScaleDesc definition.
kMxfp8StaticScale = ScaleDesc(torch.uint8, True, GroupShape(1, 32))
|
||||
kMxfp8Static = QuantKey(FP8_DTYPE, kMxfp8StaticScale, symmetric=True)
|
||||
|
||||
kMxfp8DynamicScale = ScaleDesc(torch.uint8, False, GroupShape(1, 32))
|
||||
kMxfp8Dynamic = QuantKey(FP8_DTYPE, kMxfp8DynamicScale, symmetric=True)
|
||||
|
||||
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
|
||||
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user