Add support for ModelOpt MXFP8 MoE models (#35986)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
This commit is contained in:
@@ -20,6 +20,8 @@ TRTLLM_GEN_MXFP4_AVAILABLE = (
|
|||||||
current_platform.is_cuda() and current_platform.is_device_capability_family(100)
|
current_platform.is_cuda() and current_platform.is_device_capability_family(100)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TRTLLM_GEN_MXFP8_AVAILABLE = TRTLLM_GEN_MXFP4_AVAILABLE
|
||||||
|
|
||||||
HOPPER_MXFP4_BF16_AVAILABLE = (
|
HOPPER_MXFP4_BF16_AVAILABLE = (
|
||||||
current_platform.is_cuda()
|
current_platform.is_cuda()
|
||||||
and current_platform.is_device_capability(90)
|
and current_platform.is_device_capability(90)
|
||||||
@@ -34,9 +36,15 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
|
|||||||
shuffle_matrix_a,
|
shuffle_matrix_a,
|
||||||
shuffle_matrix_sf_a,
|
shuffle_matrix_sf_a,
|
||||||
trtllm_fp4_block_scale_moe,
|
trtllm_fp4_block_scale_moe,
|
||||||
|
trtllm_fp8_block_scale_moe,
|
||||||
)
|
)
|
||||||
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
|
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
|
||||||
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
|
|
||||||
|
if TRTLLM_GEN_MXFP8_AVAILABLE:
|
||||||
|
from flashinfer.fused_moe.core import (
|
||||||
|
Fp8QuantizationType,
|
||||||
|
get_w2_permute_indices_with_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -160,6 +168,7 @@ def reference_moe(
|
|||||||
beta,
|
beta,
|
||||||
limit,
|
limit,
|
||||||
act_type,
|
act_type,
|
||||||
|
is_gated,
|
||||||
):
|
):
|
||||||
# renormalize routing
|
# renormalize routing
|
||||||
experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True)
|
experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True)
|
||||||
@@ -170,7 +179,12 @@ def reference_moe(
|
|||||||
mlp1_weight = w13[expert_indices, ...]
|
mlp1_weight = w13[expert_indices, ...]
|
||||||
mlp1_bias = bias13[expert_indices, ...]
|
mlp1_bias = bias13[expert_indices, ...]
|
||||||
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
|
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
|
||||||
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
|
if is_gated:
|
||||||
|
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
|
||||||
|
else:
|
||||||
|
# RELU2_NO_MUL: relu(x)^2
|
||||||
|
t = torch.relu(t)
|
||||||
|
t = t * t
|
||||||
|
|
||||||
if act_type == "mxfp8":
|
if act_type == "mxfp8":
|
||||||
t_quantized, t_scale = mxfp8_quantize(
|
t_quantized, t_scale = mxfp8_quantize(
|
||||||
@@ -569,6 +583,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
|
|||||||
beta,
|
beta,
|
||||||
limit,
|
limit,
|
||||||
act_type,
|
act_type,
|
||||||
|
is_gated=True,
|
||||||
)
|
)
|
||||||
ref_result[start_idx:end_idx].copy_(chunk_result)
|
ref_result[start_idx:end_idx].copy_(chunk_result)
|
||||||
|
|
||||||
@@ -705,6 +720,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
|||||||
beta,
|
beta,
|
||||||
limit,
|
limit,
|
||||||
"bf16",
|
"bf16",
|
||||||
|
is_gated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||||
@@ -890,6 +906,7 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
|||||||
beta,
|
beta,
|
||||||
limit,
|
limit,
|
||||||
"mxfp8",
|
"mxfp8",
|
||||||
|
is_gated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare inputs for FlashInfer CUTLASS fused MoE
|
# Prepare inputs for FlashInfer CUTLASS fused MoE
|
||||||
@@ -965,3 +982,169 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
|||||||
|
|
||||||
# Allow some mismatch due to MXFP4 quantization
|
# Allow some mismatch due to MXFP4 quantization
|
||||||
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
|
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("is_gated", [True], ids=["gated"])
@pytest.mark.skipif(
    not TRTLLM_GEN_MXFP8_AVAILABLE,
    reason="nvidia gpu and compute capability sm100 is required for this test",
)
def test_trtllm_gen_mxfp8_block_scale_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    is_gated: bool,
):
    """Compare ``trtllm_fp8_block_scale_moe`` (MXFP8) against ``reference_moe``.

    Random bf16 activations and expert weights are quantized with
    ``mxfp8_quantize``. The reference runs on the dequantized tensors, while the
    kernel receives the quantized tensors after the per-expert row reorder and
    shuffle performed below. The two outputs are compared with
    ``check_accuracy``.
    """
    torch.manual_seed(42)
    device = "cuda:0"

    # Row count of the first projection: doubled when the activation is gated.
    inter_size = intermediate_size * (2 if is_gated else 1)

    # NOTE: the order of the random draws below is kept fixed so the seeded
    # inputs stay reproducible.
    hidden_states = (
        torch.randn(num_tokens, hidden_size, device=device, dtype=torch.bfloat16) / 20
    )
    w13 = (
        torch.randn(
            num_experts,
            inter_size,
            hidden_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 20
    )
    w2 = (
        torch.randn(
            num_experts,
            hidden_size,
            intermediate_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 20
    )
    router_logits = torch.rand(
        num_tokens, num_experts, dtype=torch.float32, device=device
    )
    router_logits_kernel = router_logits.to(torch.bfloat16)

    # Quantize weights to MXFP8 and normalize scales to [E, M, K//32].
    w13_q, w13_scale = mxfp8_quantize(w13, is_sf_swizzled_layout=False)
    w2_q, w2_scale = mxfp8_quantize(w2, is_sf_swizzled_layout=False)
    if w13_scale.ndim == 1:
        w13_scale = w13_scale.view(num_experts, inter_size, hidden_size // 32)
    if w2_scale.ndim == 1:
        w2_scale = w2_scale.view(num_experts, hidden_size, intermediate_size // 32)

    # Quantize activations to MXFP8.
    hidden_states_q, hidden_states_scale = mxfp8_quantize(
        hidden_states, is_sf_swizzled_layout=False
    )
    if hidden_states_scale.ndim == 1:
        hidden_states_scale = hidden_states_scale.view(num_tokens, hidden_size // 32)

    # Reference output using dequantized tensors + MXFP8 intermediate quantization.
    w13_ref = mxfp8_dequantize(w13_q, w13_scale).to(torch.float32)
    w2_ref = mxfp8_dequantize(w2_q, w2_scale).to(torch.float32)
    hidden_states_ref = mxfp8_dequantize(hidden_states_q, hidden_states_scale).to(
        torch.float32
    )
    # Zero biases: the kernel call below is given no bias tensors.
    bias13 = torch.zeros(num_experts, inter_size, device=device)
    bias2 = torch.zeros(num_experts, hidden_size, device=device)
    ref = reference_moe(
        # The reference sees the same bf16-rounded logits as the kernel.
        router_logits_kernel.to(torch.float32),
        topk,
        num_experts,
        hidden_states_ref,
        w13_ref,
        bias13,
        w2_ref,
        bias2,
        alpha=1.0,
        beta=0.0,
        limit=None,
        act_type="mxfp8",
        is_gated=is_gated,
    )

    # Shuffle weights/scales with the same indexed layout used by TRTLLM kernels.
    epilogue_tile_m = 128
    gemm1_weights_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_shuffled = []
    gemm2_scales_shuffled = []
    for i in range(num_experts):
        w13_interleaved = w13_q[i].clone().reshape(inter_size, -1)
        w13_scale_interleaved = w13_scale[i].clone().reshape(inter_size, -1)
        if is_gated:
            # Weights and their scales get the same row reordering.
            w13_interleaved = reorder_rows_for_gated_act_gemm(w13_interleaved)
            w13_scale_interleaved = reorder_rows_for_gated_act_gemm(
                w13_scale_interleaved
            )

        # The shuffle helpers are fed uint8 views; results are viewed back to
        # the dtype of the quantized tensors.
        gemm1_weights_shuffled.append(
            shuffle_matrix_a(w13_interleaved.view(torch.uint8), epilogue_tile_m)
            .contiguous()
            .view(w13_q.dtype)
        )
        gemm2_weights_shuffled.append(
            shuffle_matrix_a(w2_q[i].view(torch.uint8), epilogue_tile_m)
            .contiguous()
            .view(w2_q.dtype)
        )

        gemm1_scales_shuffled.append(
            shuffle_matrix_sf_a(
                w13_scale_interleaved.view(torch.uint8).reshape(inter_size, -1),
                epilogue_tile_m,
            )
            .contiguous()
            .view(w13_scale.dtype)
        )
        gemm2_scales_shuffled.append(
            shuffle_matrix_sf_a(
                w2_scale[i].view(torch.uint8).reshape(hidden_size, -1), epilogue_tile_m
            )
            .contiguous()
            .view(w2_scale.dtype)
        )

    out = trtllm_fp8_block_scale_moe(
        routing_logits=router_logits_kernel,
        routing_bias=None,
        hidden_states=hidden_states_q,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=torch.stack(gemm1_weights_shuffled),
        gemm1_weights_scale=torch.stack(gemm1_scales_shuffled),
        gemm2_weights=torch.stack(gemm2_weights_shuffled),
        gemm2_weights_scale=torch.stack(gemm2_scales_shuffled),
        num_experts=num_experts,
        top_k=topk,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        routing_method_type=1,  # renormalize routing
        use_shuffled_weight=True,
        weight_layout=0,  # MajorK
        fp8_quantization_type=Fp8QuantizationType.MxFp8,
    )

    # Block-scale MXFP8 kernels are approximate; require majority close.
    check_accuracy(ref, out, atol=0.1, rtol=0.85, percent=0.8)
|
||||||
|
|||||||
@@ -1204,17 +1204,26 @@ class FusedMoE(CustomOp):
|
|||||||
# Determine per-tensor weight scale patterns based on variant
|
# Determine per-tensor weight scale patterns based on variant
|
||||||
# Use the dedicated method instead of brittle string matching
|
# Use the dedicated method instead of brittle string matching
|
||||||
uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern()
|
uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern()
|
||||||
|
quant_method = getattr(param, "quant_method", None)
|
||||||
|
|
||||||
# Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
|
# Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
|
||||||
# weights scales.
|
# weights scales.
|
||||||
# Input scales are always per-tensor.
|
# Input scales are always per-tensor.
|
||||||
# Weight scales: FP4 uses "weight_scale_2" and FP8 uses
|
# Weight scales: FP4 uses "weight_scale_2" and FP8 uses
|
||||||
# "weight_scale" for per-tensor scales.
|
# "weight_scale" for per-tensor scales.
|
||||||
|
# NOTE: ModelOpt MXFP8 MoE uses block scales in weight_scale
|
||||||
|
# tensors (quant_method=BLOCK), so those must not be treated
|
||||||
|
# as per-tensor scalars here.
|
||||||
|
is_block_weight_scale = (
|
||||||
|
"weight_scale" in weight_name
|
||||||
|
and quant_method == FusedMoeWeightScaleSupported.BLOCK.value
|
||||||
|
)
|
||||||
is_per_tensor = (
|
is_per_tensor = (
|
||||||
"weight_scale_2" in weight_name
|
"weight_scale_2" in weight_name
|
||||||
if uses_weight_scale_2
|
if uses_weight_scale_2
|
||||||
else "weight_scale" in weight_name
|
else "weight_scale" in weight_name
|
||||||
) or "input_scale" in weight_name
|
) or "input_scale" in weight_name
|
||||||
|
is_per_tensor = is_per_tensor and not is_block_weight_scale
|
||||||
if is_per_tensor:
|
if is_per_tensor:
|
||||||
self._load_per_tensor_weight_scale(
|
self._load_per_tensor_weight_scale(
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
|
|||||||
44
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
Normal file
44
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MxFp8MoeBackend(Enum):
    """Backends that can execute MXFP8 MoE layers."""

    # Currently the only member; its value is the name written to the logs.
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
|
||||||
|
|
||||||
|
|
||||||
|
def select_mxfp8_moe_backend(
    config: FusedMoEConfig,
) -> MxFp8MoeBackend:
    """Pick the MXFP8 MoE backend for ``config``.

    Args:
        config: MoE configuration; ``is_lora_enabled`` and ``moe_backend`` are
            the only fields read here.

    Returns:
        The backend named by ``config.moe_backend``, or the default backend
        when ``config.moe_backend`` is ``"auto"``.

    Raises:
        NotImplementedError: If LoRA is enabled.
        ValueError: If ``config.moe_backend`` is neither ``"auto"`` nor a
            supported backend name.
    """
    if config.is_lora_enabled:
        raise NotImplementedError("LoRA is not supported for MXFP8 MoE.")

    runner_backend = config.moe_backend
    if runner_backend == "auto":
        # Auto-select: only one backend available for now.
        backend = MxFp8MoeBackend.FLASHINFER_TRTLLM
        logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
        return backend

    # Explicit user request: translate the config string to an enum member.
    mapping = {
        "flashinfer_trtllm": MxFp8MoeBackend.FLASHINFER_TRTLLM,
    }
    backend = mapping.get(runner_backend)
    if backend is None:
        raise ValueError(
            f"moe_backend='{runner_backend}' is not supported for MXFP8 MoE. "
            f"Expected one of {list(mapping.keys())}."
        )

    logger.info_once(
        "Using '%s' MxFp8 MoE backend (user-requested).",
        backend.value,
    )
    return backend
|
||||||
@@ -9,17 +9,19 @@ from torch.nn.parameter import Parameter
|
|||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.kernels.linear import (
|
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
|
||||||
init_fp8_linear_kernel,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
||||||
|
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FusedMoEConfig,
|
FusedMoEConfig,
|
||||||
FusedMoEQuantConfig,
|
FusedMoEQuantConfig,
|
||||||
|
RoutingMethodType,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
FusedMoE,
|
FusedMoE,
|
||||||
FusedMoEMethodBase,
|
|
||||||
FusedMoeWeightScaleSupported,
|
FusedMoeWeightScaleSupported,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||||
@@ -28,6 +30,10 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
|||||||
make_fp8_moe_quant_config,
|
make_fp8_moe_quant_config,
|
||||||
select_fp8_moe_backend,
|
select_fp8_moe_backend,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
|
||||||
|
MxFp8MoeBackend,
|
||||||
|
select_mxfp8_moe_backend,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||||
convert_to_nvfp4_moe_kernel_format,
|
convert_to_nvfp4_moe_kernel_format,
|
||||||
is_global_sf_supported_for_nvfp4_backend,
|
is_global_sf_supported_for_nvfp4_backend,
|
||||||
@@ -46,6 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||||
|
swap_w13_to_w31,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
W8A8BlockFp8LinearOp,
|
W8A8BlockFp8LinearOp,
|
||||||
process_fp8_input_tensor_strategy_moe,
|
process_fp8_input_tensor_strategy_moe,
|
||||||
@@ -60,6 +69,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
|||||||
MXFP8_VALUE_DTYPE,
|
MXFP8_VALUE_DTYPE,
|
||||||
Mxfp8LinearBackend,
|
Mxfp8LinearBackend,
|
||||||
Mxfp8LinearOp,
|
Mxfp8LinearOp,
|
||||||
|
mxfp8_e4m3_quantize,
|
||||||
swizzle_mxfp8_scale,
|
swizzle_mxfp8_scale,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||||
@@ -86,7 +96,8 @@ from vllm.model_executor.parameter import (
|
|||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.utils import replace_parameter
|
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||||
|
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.model_executor.models.utils import WeightsMapper
|
from vllm.model_executor.models.utils import WeightsMapper
|
||||||
@@ -1487,17 +1498,6 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
|
|||||||
# MXFP8 hardware acceleration requires Blackwell (SM100) or newer
|
# MXFP8 hardware acceleration requires Blackwell (SM100) or newer
|
||||||
return 100
|
return 100
|
||||||
|
|
||||||
def get_quant_method(
|
|
||||||
self, layer: torch.nn.Module, prefix: str
|
|
||||||
) -> "QuantizeMethodBase | None":
|
|
||||||
# MXFP8 does not yet support MoE models
|
|
||||||
if isinstance(layer, FusedMoE):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"MXFP8 quantization does not yet support MoE models. "
|
|
||||||
"Please use FP8 or NVFP4 quantization for MoE models."
|
|
||||||
)
|
|
||||||
return super().get_quant_method(layer, prefix)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def override_quantization_method(
|
def override_quantization_method(
|
||||||
cls, hf_quant_cfg, user_quant
|
cls, hf_quant_cfg, user_quant
|
||||||
@@ -1699,8 +1699,351 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
    """FlashInfer TRTLLM MXFP8 block-scale MoE for ModelOpt checkpoints.

    Expert weights are registered as ``MXFP8_VALUE_DTYPE`` tensors with one
    ``MXFP8_SCALE_DTYPE`` scale per ``MXFP8_BLOCK_SIZE`` elements of the input
    dimension. After loading they are re-laid-out once for the FlashInfer
    TRTLLM kernel, which is then invoked through ``apply_monolithic``.
    """

    def __init__(
        self,
        quant_config: ModelOptMxFp8Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        """Store the quantization config and select the MXFP8 MoE backend."""
        super().__init__(moe_config)
        self.quant_config = quant_config
        assert self.quant_config.is_checkpoint_mxfp8_serialized

        # Select MXFP8 MoE backend
        self.mxfp8_backend = select_mxfp8_moe_backend(self.moe)

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the MXFP8 expert weights and block scales on ``layer``.

        Raises:
            ValueError: If ``hidden_size`` or ``intermediate_size_per_partition``
                is not a multiple of ``MXFP8_BLOCK_SIZE``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.orig_dtype = params_dtype

        if hidden_size % MXFP8_BLOCK_SIZE != 0:
            raise ValueError(
                f"MXFP8 MoE requires hidden_size divisible by {MXFP8_BLOCK_SIZE}, "
                f"got {hidden_size}."
            )
        if intermediate_size_per_partition % MXFP8_BLOCK_SIZE != 0:
            raise ValueError(
                "MXFP8 MoE requires intermediate_size_per_partition divisible by "
                f"{MXFP8_BLOCK_SIZE}, got {intermediate_size_per_partition}."
            )

        layer.num_experts = num_experts
        weight_loader = extra_weight_attrs.get("weight_loader")
        # Gated activations pack two projections into w13; otherwise only one.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1
        w13_rows = w13_num_shards * intermediate_size_per_partition

        # GEMM 1 weights: [E, (2I or I), H]
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_rows,
                hidden_size,
                dtype=MXFP8_VALUE_DTYPE,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2 weights: [E, H, I]
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=MXFP8_VALUE_DTYPE,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # Per-block (K=32) E8M0 scales.
        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_rows,
                hidden_size // MXFP8_BLOCK_SIZE,
                dtype=MXFP8_SCALE_DTYPE,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // MXFP8_BLOCK_SIZE,
                dtype=MXFP8_SCALE_DTYPE,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Ensure the generic MoE weight-loader treats these as block scales.
        for scale_param in (layer.w13_weight_scale, layer.w2_weight_scale):
            set_weight_attrs(
                scale_param,
                {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value},
            )

    @staticmethod
    def _check_weight_dtypes(layer: torch.nn.Module) -> None:
        """Validate weight and scale dtypes before processing.

        Raises:
            ValueError: If any of the four expert tensors has an unexpected dtype.
        """
        expected = {
            "w13_weight": MXFP8_VALUE_DTYPE,
            "w2_weight": MXFP8_VALUE_DTYPE,
            "w13_weight_scale": MXFP8_SCALE_DTYPE,
            "w2_weight_scale": MXFP8_SCALE_DTYPE,
        }
        for name, expected_dtype in expected.items():
            actual = getattr(layer, name).dtype
            if actual != expected_dtype:
                raise ValueError(
                    f"Expected {name} dtype {expected_dtype}, got {actual}."
                )

    def _shuffle_weights_for_trtllm(self, layer: torch.nn.Module) -> None:
        """Shuffle weights and scales into FlashInfer TRTLLM MXFP8 layout.

        Each expert is processed independently and the four parameters on
        ``layer`` are replaced with the stacked, shuffled tensors.
        """
        from flashinfer import (
            reorder_rows_for_gated_act_gemm,
            shuffle_matrix_a,
            shuffle_matrix_sf_a,
        )

        epilogue_tile_m = 128
        num_experts = layer.w13_weight.shape[0]
        is_gated = self.moe.is_act_and_mul
        # Rows of the first projection; loop-invariant, so computed once.
        w13_rows = (2 if is_gated else 1) * layer.intermediate_size_per_partition

        w13_weight = layer.w13_weight.data
        w13_scale = layer.w13_weight_scale.data
        w2_weight = layer.w2_weight.data
        w2_scale = layer.w2_weight_scale.data
        if is_gated:
            # FI TRTLLM gated kernels use W31 ordering. Model checkpoints store
            # gated projection as W13, so convert once before shuffling.
            w13_weight = swap_w13_to_w31(w13_weight)
            w13_scale = swap_w13_to_w31(w13_scale)

        w13_weight_shuffled = []
        w2_weight_shuffled = []
        w13_scale_shuffled = []
        w2_scale_shuffled = []
        for i in range(num_experts):
            w13_i = w13_weight[i].reshape(w13_rows, -1)
            w13_sf_i = w13_scale[i].reshape(w13_rows, -1)
            if is_gated:
                # Reorder rows for gated activation layout expected by TRTLLM.
                w13_i = reorder_rows_for_gated_act_gemm(w13_i.clone())
                w13_sf_i = reorder_rows_for_gated_act_gemm(w13_sf_i.clone())

            # The shuffle helpers are fed uint8 views of the tensors.
            w13_shuffled_i = shuffle_matrix_a(w13_i.view(torch.uint8), epilogue_tile_m)
            w2_shuffled_i = shuffle_matrix_a(
                w2_weight[i].view(torch.uint8), epilogue_tile_m
            )
            w13_weight_shuffled.append(
                w13_shuffled_i.contiguous().view(MXFP8_VALUE_DTYPE)
            )
            w2_weight_shuffled.append(
                w2_shuffled_i.contiguous().view(MXFP8_VALUE_DTYPE)
            )

            w13_sf_shuffled_i = shuffle_matrix_sf_a(
                w13_sf_i.view(torch.uint8).reshape(w13_rows, -1),
                epilogue_tile_m,
            )
            w2_sf_shuffled_i = shuffle_matrix_sf_a(
                w2_scale[i].view(torch.uint8).reshape(layer.hidden_size, -1),
                epilogue_tile_m,
            )
            w13_scale_shuffled.append(
                w13_sf_shuffled_i.contiguous().view(MXFP8_SCALE_DTYPE)
            )
            w2_scale_shuffled.append(
                w2_sf_shuffled_i.contiguous().view(MXFP8_SCALE_DTYPE)
            )

        replace_parameter(
            layer, "w13_weight", torch.stack(w13_weight_shuffled).contiguous()
        )
        replace_parameter(
            layer, "w2_weight", torch.stack(w2_weight_shuffled).contiguous()
        )
        replace_parameter(
            layer,
            "w13_weight_scale",
            torch.stack(w13_scale_shuffled).contiguous(),
        )
        replace_parameter(
            layer,
            "w2_weight_scale",
            torch.stack(w2_scale_shuffled).contiguous(),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Validate dtypes and shuffle the loaded weights, at most once per layer."""
        # The flag prevents shuffling already-shuffled weights on a repeated call.
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self._check_weight_dtypes(layer)
        self._shuffle_weights_for_trtllm(layer)
        layer._already_called_process_weights_after_loading = True

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        """Always raise: this method is not part of this class's init path."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEExpertsModular:
        """Always raise: this method is not part of this class's init path."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return ``None``; no modular-kernel quant config is built here."""
        # TRTLLM MXFP8 path is monolithic and does not use modular kernel config.
        return None

    @property
    def is_monolithic(self) -> bool:
        """Whether the selected backend is the FlashInfer TRTLLM one."""
        return self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``x`` to MXFP8 and run the FlashInfer TRTLLM MoE kernel.

        Raises:
            NotImplementedError: If EPLB is enabled on ``layer`` or its
                activation is not ``MoEActivation.SILU``.
        """
        from flashinfer.fused_moe.core import Fp8QuantizationType

        assert self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM

        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB is not supported for FlashInfer TRTLLM MXFP8 MoE backend."
            )

        # The kernel call below takes no activation argument, so this guard is
        # the single place where the activation is validated.
        supported_activations = [MoEActivation.SILU]
        if layer.activation not in supported_activations:
            raise NotImplementedError(
                "FlashInfer TRTLLM MXFP8 MoE supports only "
                f"{supported_activations}, got {layer.activation}."
            )

        # DeepSeekV3 routing requires float32 logits; others expect bfloat16.
        if layer.routing_method_type == RoutingMethodType.DeepSeekV3:
            assert router_logits.dtype == torch.float32, (
                "DeepSeekV3 routing requires float32 router_logits, "
                f"got {router_logits.dtype}."
            )
        else:
            router_logits = router_logits.to(torch.bfloat16)

        # Treat 0 as "unset" for compatibility with ungrouped routing configs.
        n_group = layer.num_expert_group or None
        topk_group = layer.topk_group or None

        hidden_states_mxfp8, hidden_states_scale = mxfp8_e4m3_quantize(
            x,
            is_sf_swizzled_layout=False,
        )

        return flashinfer_trtllm_fp8_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=layer.e_score_correction_bias,
            hidden_states=hidden_states_mxfp8,
            hidden_states_scale=hidden_states_scale,
            gemm1_weights=layer.w13_weight,
            gemm1_weights_scale=layer.w13_weight_scale,
            gemm2_weights=layer.w2_weight,
            gemm2_weights_scale=layer.w2_weight_scale,
            num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            # Keep Optional semantics: FlashInfer expects None for non-grouped
            # routing (e.g. Qwen3 Renormalize), not 0.
            n_group=n_group,
            topk_group=topk_group,
            intermediate_size=layer.intermediate_size_per_partition,
            local_expert_offset=layer.ep_rank * layer.local_num_experts,
            local_num_experts=layer.local_num_experts,
            routed_scaling_factor=layer.routed_scaling_factor,
            routing_method_type=layer.routing_method_type,
            use_shuffled_weight=True,
            weight_layout=0,
            fp8_quantization_type=Fp8QuantizationType.MxFp8,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Non-monolithic entry point; not implemented for MXFP8 MoE."""
        assert not self.is_monolithic
        raise NotImplementedError(
            "Non-monolithic MXFP8 MoE path is not yet implemented."
        )
|
||||||
|
|
||||||
|
|
||||||
# Register the method classes for ModelOptMxFp8Config
|
# Register the method classes for ModelOptMxFp8Config
|
||||||
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
|
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
|
||||||
|
ModelOptMxFp8Config.FusedMoEMethodCls = ModelOptMxFp8FusedMoE
|
||||||
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
|
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user