From db5d0719e1057b876ed673ae6387d940962691bb Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 1 Apr 2026 18:41:42 +0200 Subject: [PATCH] [Kernel] Add MXFP8 to Marlin GEMM/MoE and refactor Mxfp8LinearOp (#34664) Signed-off-by: mgoin --- csrc/moe/marlin_moe_wna16/generate_kernels.py | 9 + csrc/moe/marlin_moe_wna16/marlin_template.h | 18 +- csrc/moe/marlin_moe_wna16/ops.cu | 3 + csrc/quantization/marlin/generate_kernels.py | 9 + csrc/quantization/marlin/marlin.cu | 3 + csrc/quantization/marlin/marlin_template.h | 17 +- tests/kernels/moe/test_moe.py | 6 + tests/models/quantization/test_mxfp8.py | 2 +- .../layers/fused_moe/fused_marlin_moe.py | 3 +- .../layers/fused_moe/oracle/fp8.py | 28 ++- .../layers/fused_moe/oracle/mxfp8.py | 14 +- .../layers/quantization/modelopt.py | 62 +----- .../layers/quantization/mxfp8.py | 56 ++--- .../quantization/utils/marlin_utils_fp8.py | 196 ++++++++++++++++++ .../layers/quantization/utils/mxfp8_utils.py | 184 ++++++++++++++-- 15 files changed, 481 insertions(+), 129 deletions(-) diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 52f266707..bca697cae 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -108,6 +108,15 @@ QUANT_CONFIGS = [ "thread_m_blocks": THREAD_M_BLOCKS, "group_blocks": [2], }, + # MXFP8 + { + "a_type": ["kBFloat16"], + "b_type": "kFE4M3fn", + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [2], + }, # AWQ-INT4 with INT8 activation { "a_type": ["kS8"], diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index f5685b898..9858df945 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -343,6 +343,8 @@ __global__ void Marlin( if constexpr (b_type == vllm::kFE2M1f) { static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || s_type == vllm::kFE8M0fnu && group_blocks == 2); + } else if constexpr (b_type == vllm::kFE4M3fn && s_type == vllm::kFE8M0fnu) { + static_assert(group_blocks == 2); } else if constexpr (std::is_same::value) { static_assert(s_type == vllm::kBFloat16); } else if constexpr (std::is_same::value) { @@ -357,9 +359,10 @@ __global__ void Marlin( constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 || b_type == vllm::kS4 || b_type == vllm::kS8 || b_type == vllm::kU4B8 || b_type == vllm::kU8B128; + constexpr bool is_8bit_scale = s_type.size_bits() == 8; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - is_a_8bit || b_type == vllm::kFE4M3fn || + is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) || b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(b_type == vllm::kU8); @@ -373,7 +376,7 @@ __global__ void Marlin( const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; const int scales_expert_stride = - prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8); + prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8); const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); @@ -692,9 +695,8 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8); - constexpr int s_sh_stride = - 16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8); + int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8); constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks @@ -1131,7 +1133,7 @@ __global__ void Marlin( int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (b_type_id != vllm::kFE2M1f.id()) { + if constexpr (!is_8bit_scale) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } else { @@ -1140,7 +1142,7 @@ __global__ void Marlin( sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } } else if (group_blocks >= b_sh_wr_iters) { - if constexpr (b_type_id != vllm::kFE2M1f.id()) { + if constexpr (!is_8bit_scale) { reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; } else { @@ -1341,7 +1343,7 @@ __global__ void Marlin( } } - if constexpr (b_type == vllm::kFE2M1f) { + if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) { int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 60681ad93..cf97f95a8 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -599,6 +599,9 @@ torch::Tensor moe_wna16_marlin_gemm( "When b_type = float4_e2m1f, b_scale scalar type must be", "float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4)."); } + } else if (b_type_id == vllm::kFE4M3fn.id() && + b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + s_type_id = vllm::kFE8M0fnu.id(); } vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id); diff --git a/csrc/quantization/marlin/generate_kernels.py b/csrc/quantization/marlin/generate_kernels.py index 5ecbc6ac9..19a42de1f 100644 --- a/csrc/quantization/marlin/generate_kernels.py +++ b/csrc/quantization/marlin/generate_kernels.py @@ -108,6 +108,15 @@ QUANT_CONFIGS = [ "thread_m_blocks": THREAD_M_BLOCKS, "group_blocks": [2], }, + # MXFP8 + { + "a_type": ["kBFloat16"], + "b_type": "kFE4M3fn", + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [2], + }, # AWQ-INT4 with INT8 activation { "a_type": ["kS8"], diff --git a/csrc/quantization/marlin/marlin.cu b/csrc/quantization/marlin/marlin.cu index fbdb619c2..5684f272e 100644 --- a/csrc/quantization/marlin/marlin.cu +++ b/csrc/quantization/marlin/marlin.cu @@ -591,6 +591,9 @@ torch::Tensor marlin_gemm( "When b_type = float4_e2m1f, b_scale scalar type must be", "float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4)."); } + } else if (b_type_id == vllm::kFE4M3fn.id() && + b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + s_type_id = vllm::kFE8M0fnu.id(); } vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id); diff --git a/csrc/quantization/marlin/marlin_template.h b/csrc/quantization/marlin/marlin_template.h index 9e625b645..32b8f8bdd 100644 --- a/csrc/quantization/marlin/marlin_template.h +++ b/csrc/quantization/marlin/marlin_template.h @@ -327,6 +327,9 @@ __global__ void Marlin( if constexpr (b_type == vllm::kFE2M1f) { static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || s_type == vllm::kFE8M0fnu && group_blocks == 2); + } else if constexpr (s_type == vllm::kFE8M0fnu) { + // MXFP8: FP8 weights with e8m0 microscaling block scales + static_assert(b_type == vllm::kFE4M3fn && group_blocks == 2); } else if constexpr (std::is_same::value) { static_assert(s_type == vllm::kBFloat16); } else if constexpr (std::is_same::value) { @@ -334,6 +337,7 @@ __global__ void Marlin( } constexpr bool is_a_8bit = a_type.size_bits() == 8; + constexpr bool is_8bit_scale = s_type.size_bits() == 8; if constexpr (!is_a_8bit) { static_assert(std::is_same::value); } @@ -343,7 +347,7 @@ __global__ void Marlin( b_type == vllm::kU4B8 || b_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - is_a_8bit || b_type == vllm::kFE4M3fn || + is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) || b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(b_type == vllm::kU8); @@ -555,9 +559,8 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8); - constexpr int s_sh_stride = - 16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8); + int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8); constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks @@ -997,7 +1000,7 @@ __global__ void Marlin( int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (b_type_id != vllm::kFE2M1f.id()) { + if constexpr (!is_8bit_scale) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } else { @@ -1006,7 +1009,7 @@ __global__ void Marlin( sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } } else if (group_blocks >= b_sh_wr_iters) { - if constexpr (b_type_id != vllm::kFE2M1f.id()) { + if constexpr (!is_8bit_scale) { reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; } else { @@ -1207,7 +1210,7 @@ __global__ void Marlin( } } - if constexpr (b_type == vllm::kFE2M1f) { + if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) { int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 4a3941b3d..5ec2b106e 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -151,6 +151,12 @@ MOE_MARLIN_QUANT_TEST_CONFIGS = [ "b_type": scalar_types.float4_e2m1f, "group_blocks": [2], }, + # MXFP8 + { + "a_type": [scalar_types.bfloat16], + "b_type": scalar_types.float8_e4m3fn, + "group_blocks": [2], + }, # AWQ-INT4 with INT8 activation { "a_type": [scalar_types.int8], diff --git a/tests/models/quantization/test_mxfp8.py b/tests/models/quantization/test_mxfp8.py index 2cb0f2008..7c250d115 100644 --- a/tests/models/quantization/test_mxfp8.py +++ b/tests/models/quantization/test_mxfp8.py @@ -23,7 +23,7 @@ from tests.quantization.utils import is_quant_method_supported from ..utils import check_logprobs_close # A small MoE model that fits on a single GPU and has both linear + MoE layers. -MOE_MODEL = "Qwen/Qwen3-30B-A3B" +MOE_MODEL = "allenai/OLMoE-1B-7B-0125-Instruct" # A small dense model (no MoE) to validate the linear-only path. DENSE_MODEL = "Qwen/Qwen3-0.6B" diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 136a8188d..6c916cf3c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -41,6 +41,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticChannelSym, kFp8StaticTensorSym, kMxfp4Static, + kMxfp8Static, kNvfp4Static, ) from vllm.platforms import current_platform @@ -582,6 +583,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular): kFp8StaticChannelSym, kFp8StaticTensorSym, kMxfp4Static, + kMxfp8Static, kNvfp4Static, ] return weight_key in SUPPORTED_W @@ -609,7 +611,6 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular): @property def quant_type_id(self) -> int: - # uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4 if self.quant_config.use_int4_w4a16: return scalar_types.uint4b8.id elif self.quant_config.use_mxfp4_w4a16 or self.quant_config.use_nvfp4_w4a16: diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 3d9a49902..36f35ed5e 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -436,13 +436,27 @@ def convert_to_fp8_moe_kernel_format( elif fp8_backend == Fp8MoeBackend.AITER: w13, w2 = rocm_aiter_ops.shuffle_weights(w13, w2) elif fp8_backend == Fp8MoeBackend.MARLIN: - w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin( - layer, - w13, - w2, - w13_scale, - w2_scale, - ) + weight_block_size = getattr(layer, "weight_block_size", None) + if weight_block_size == [1, 32]: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + prepare_mxfp8_moe_layer_for_marlin, + ) + + w13, w2, w13_scale, w2_scale = prepare_mxfp8_moe_layer_for_marlin( + layer, + w13, + w2, + w13_scale, + w2_scale, + ) + else: + w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin( + layer, + w13, + w2, + w13_scale, + w2_scale, + ) elif fp8_backend in [ Fp8MoeBackend.FLASHINFER_CUTLASS, Fp8MoeBackend.FLASHINFER_TRTLLM, diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py index ed3af4b5a..c67def149 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py @@ -15,14 +15,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( logger = init_logger(__name__) -_SUPPORTED_BACKENDS: frozenset[Fp8MoeBackend] = frozenset( - { - Fp8MoeBackend.FLASHINFER_TRTLLM, - } +_SUPPORTED_BACKENDS = ( + Fp8MoeBackend.FLASHINFER_TRTLLM, + Fp8MoeBackend.MARLIN, ) _BACKEND_NAME_MAP: dict[str, Fp8MoeBackend] = { "flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM, + "marlin": Fp8MoeBackend.MARLIN, } @@ -81,7 +81,11 @@ def select_mxfp8_moe_backend( # Auto-select: pick the first supported backend. for backend in _SUPPORTED_BACKENDS: + try: + experts_cls = _select_kernel_cls(backend, config) + except ValueError: + continue logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value) - return backend, _select_kernel_cls(backend, config) + return backend, experts_cls raise ValueError("No MXFP8 MoE backends available.") diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index b0562ee43..7871b774e 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -67,10 +67,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( MXFP8_BLOCK_SIZE, MXFP8_SCALE_DTYPE, MXFP8_VALUE_DTYPE, - Mxfp8LinearBackend, Mxfp8LinearOp, mxfp8_e4m3_quantize, - swizzle_mxfp8_scale, ) from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( apply_nvfp4_linear, @@ -1499,8 +1497,8 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase): @classmethod def get_min_capability(cls) -> int: - # MXFP8 hardware acceleration requires Blackwell (SM100) or newer - return 100 + # Marlin kernel supports MXFP8 on SM80+ + return 80 @classmethod def override_quantization_method( @@ -1555,9 +1553,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase): "Dynamic quantization is not supported." ) - self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS - self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend) - logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value) + self.mxfp8_linear_op = Mxfp8LinearOp() def create_weights( self, @@ -1615,36 +1611,6 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase): ) layer.register_parameter("weight_scale", weight_scale) - def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None: - """Not swizzled - MXFP8 GEMM emulation""" - weight = layer.weight.data # [N, K] - N, K = weight.shape - scale_k = K // MXFP8_BLOCK_SIZE - - # Slice weight_scale to match weight dimensions (handles padding) - weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous() - - layer.weight = Parameter(weight.contiguous(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - - def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None: - """Swizzled - MXFP8 GEMM Flashinfer CUTLASS""" - weight = layer.weight.data # [N, K] - N, K = weight.shape - - # 2D weight scale - weight_scale = layer.weight_scale.data - - # Swizzle the weight scales - scale_k = K // MXFP8_BLOCK_SIZE - weight_scale_2d = weight_scale[:N, :scale_k].contiguous() - weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K) - - layer.weight = Parameter(weight.contiguous(), requires_grad=False) - layer.weight_scale = Parameter( - weight_scale_swizzled.contiguous(), requires_grad=False - ) - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Validate weight tensor if layer.weight.ndim != 2: @@ -1669,14 +1635,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase): f" got {layer.weight_scale.dtype}" ) - if self.backend == Mxfp8LinearBackend.EMULATION: - # Swizzled layout is not used - self._process_weights_after_loading_scale_2d(layer) - return - - assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS - # Swizzled layout is required for Flashinfer CUTLASS - self._process_weights_after_loading_scale_1d(layer) + self.mxfp8_linear_op.process_weights(layer) def apply( self, @@ -1684,22 +1643,15 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - if layer.weight.dtype != MXFP8_VALUE_DTYPE: - raise ValueError( - f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}" - ) - if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE: - raise ValueError( - f"Weight scale dtype {layer.weight_scale.dtype} != " - f"expected {MXFP8_SCALE_DTYPE}" - ) - return self.mxfp8_linear_op.apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale, out_dtype=x.dtype, bias=bias, + workspace=getattr(layer, "workspace", None), + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, ) diff --git a/vllm/model_executor/layers/quantization/mxfp8.py b/vllm/model_executor/layers/quantization/mxfp8.py index 07c519b2a..6e0c14143 100644 --- a/vllm/model_executor/layers/quantization/mxfp8.py +++ b/vllm/model_executor/layers/quantization/mxfp8.py @@ -34,10 +34,8 @@ from vllm.model_executor.layers.quantization.fp8 import ( ) from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( MXFP8_BLOCK_SIZE, - Mxfp8LinearBackend, Mxfp8LinearOp, mxfp8_e4m3_quantize, - swizzle_mxfp8_scale, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped, @@ -71,7 +69,8 @@ class Mxfp8Config(Fp8Config): @classmethod def get_min_capability(cls) -> int: - return 100 + # Marlin kernel supports MXFP8 on SM80+ + return 80 @classmethod def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config": @@ -128,24 +127,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod): def __init__(self, quant_config: "Mxfp8Config"): self.quant_config = quant_config self.out_dtype = torch.get_default_dtype() - self.mxfp8_linear = Mxfp8LinearOp(self._select_backend()) - logger.info_once( - "Using %s backend for MXFP8 GEMM", self.mxfp8_linear.backend.value - ) - - @staticmethod - def _select_backend() -> Mxfp8LinearBackend: - try: - from vllm.utils import flashinfer as fi - - _ = fi.mm_mxfp8 - return Mxfp8LinearBackend.FLASHINFER_CUTLASS - except Exception: - logger.warning( - "FlashInfer mm_mxfp8 not available, " - "falling back to MXFP8 emulation backend." - ) - return Mxfp8LinearBackend.EMULATION + self.mxfp8_linear = Mxfp8LinearOp() def create_weights( self, @@ -180,14 +162,12 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod): weight_fp8, weight_scale = mxfp8_e4m3_quantize(layer.weight.contiguous()) - if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS: - N, K = layer.weight.shape[0], layer.weight.shape[1] - weight_scale = swizzle_mxfp8_scale(weight_scale, N, K) - layer.input_scale = None replace_parameter(layer, "weight", weight_fp8.data) replace_parameter(layer, "weight_scale", weight_scale.data) + self.mxfp8_linear.process_weights(layer) + layer._already_called_process_weights_after_loading = True def apply( @@ -202,6 +182,9 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod): weight_scale=layer.weight_scale, out_dtype=self.out_dtype, bias=bias, + workspace=getattr(layer, "workspace", None), + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, ) @@ -255,17 +238,24 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod): self, weight: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """Batch quantization: bf16/fp16 weights -> MXFP8 (fp8 + uint8 scales).""" - num_batches = weight.size(0) - w_quant = [] - w_scales = [] - for i in range(num_batches): - mx_fp8_quant, mx_fp8_scale = mxfp8_e4m3_quantize( + E = weight.size(0) + first_q, first_s = mxfp8_e4m3_quantize(weight[0], is_sf_swizzled_layout=False) + # Pre-allocate the output tensors rather than stacking. + # This is important for consistent memory layout. + w_quant = torch.empty( + (E, *first_q.shape), dtype=first_q.dtype, device=weight.device + ) + w_scales = torch.empty( + (E, *first_s.shape), dtype=first_s.dtype, device=weight.device + ) + w_quant[0] = first_q + w_scales[0] = first_s + for i in range(1, E): + w_quant[i], w_scales[i] = mxfp8_e4m3_quantize( weight[i], is_sf_swizzled_layout=False ) - w_quant.append(mx_fp8_quant) - w_scales.append(mx_fp8_scale) - return torch.stack(w_quant), torch.stack(w_scales) + return w_quant, w_scales def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index b5a557ce9..6e2ae5c91 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -336,6 +336,202 @@ def pack_fp8_to_int32( return int32_tensor.T.contiguous() if size_k_first else int32_tensor +def mxfp8_marlin_process_scales(marlin_scales: torch.Tensor) -> torch.Tensor: + """Reorder scales for e8m0 kernel layout and convert to float8_e8m0fnu.""" + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1 + ) + marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) + return marlin_scales + + +def apply_mxfp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: torch.Tensor | None = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype, + ) + + output = ops.marlin_gemm( + a=reshaped_x, + c=None, + b_q_weight=weight, + b_bias=bias, + b_scales=weight_scale, + a_scales=None, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + ) + + return output.reshape(out_shape) + + +def prepare_mxfp8_layer_for_marlin(layer: torch.nn.Module) -> None: + """Repack MXFP8 weights and scales into Marlin kernel format. + + Expects the layer to have: + - weight: [N, K] float8_e4m3fn + - weight_scale: [N, K//32] uint8 (e8m0 encoded) + - input_size_per_partition / output_size_per_partition + """ + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + group_size = 32 # MX standard block size + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT - repack FP8 weights to Marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = pack_fp8_to_int32(layer.weight, size_k_first=False) + qweight = qweight.T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + replace_parameter(layer, "weight", marlin_qweight) + + # WEIGHT SCALES + # Convert uint8 scales -> e8m0fnu -> param_dtype for permutation + # Scales are [N, K//32], need [K//32, N] for marlin_permute_scales + param_dtype = torch.get_default_dtype() + scales = layer.weight_scale.data[:part_size_n, : part_size_k // group_size] + scales = scales.contiguous() + scales = scales.view(torch.float8_e8m0fnu).to(param_dtype) + scales = scales.T.contiguous() + + # Permute scales to Marlin layout + marlin_scales = marlin_permute_scales( + s=scales, + size_k=part_size_k, + size_n=part_size_n, + group_size=group_size, + ) + + # Reorder for e8m0 kernel layout and convert back to e8m0fnu + marlin_scales = mxfp8_marlin_process_scales(marlin_scales) + replace_parameter(layer, "weight_scale", marlin_scales) + + # BIAS + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n,) + bias = marlin_permute_bias(layer.bias) + replace_parameter(layer, "bias", bias) + + +def prepare_mxfp8_moe_layer_for_marlin( + layer: torch.nn.Module, + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Repack MXFP8 MoE weights and scales into Marlin kernel format. + + Args: + layer: MoE layer (used to read params_dtype and attach workspace). + w13: [E, 2*N, K] float8_e4m3fn weights. + w2: [E, K, N] float8_e4m3fn weights. + w13_scale: [E, 2*N, K//32] uint8 e8m0 scales. + w2_scale: [E, K, N//32] uint8 e8m0 scales. + + Returns: + (w13, w2, w13_scale, w2_scale) in Marlin format. + """ + group_size = 32 + e = w13.shape[0] + w13_n = w13.shape[1] + k = w13.shape[2] + n = w2.shape[2] + + device = w13.device + param_dtype = torch.get_default_dtype() + perm = torch.empty(0, dtype=torch.int, device=device) + + layer.workspace = marlin_make_workspace_new(device, 4) + + def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor: + if "w13" in name: + size_n, size_k = w13_n, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k) + + tensor_list = [] + for i in range(e): + qweight = pack_fp8_to_int32(weight[i], size_k_first=False) + qweight = qweight.T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=8, + ) + tensor_list.append(marlin_qweight) + return torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + + w13 = repack_weight(w13, "w13") + w2 = repack_weight(w2, "w2") + + def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor: + if "w13" in name: + size_n, size_k = w13_n, k + else: + size_n, size_k = k, n + + tensor_list = [] + for i in range(e): + s = scales[i][:size_n, : size_k // group_size].contiguous() + s = s.view(torch.float8_e8m0fnu).to(param_dtype) + s = s.T.contiguous() + marlin_s = marlin_permute_scales( + s=s, + size_k=size_k, + size_n=size_n, + group_size=group_size, + ) + marlin_s = mxfp8_marlin_process_scales(marlin_s) + tensor_list.append(marlin_s) + return torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + + w13_scale = permute_scales(w13_scale, "w13") + w2_scale = permute_scales(w2_scale, "w2") + + return w13, w2, w13_scale, w2_scale + + def marlin_quant_fp8_torch(weight, group_size, input_dtype=None): is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 if is_a_8bit: diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py index ee849b167..f10c823f5 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -4,6 +4,7 @@ from enum import Enum import torch +from torch.nn.parameter import Parameter from vllm.logger import init_logger from vllm.utils import flashinfer as vllm_flashinfer @@ -15,6 +16,7 @@ logger = init_logger(__name__) class Mxfp8LinearBackend(Enum): EMULATION = "emulation" FLASHINFER_CUTLASS = "flashinfer-cutlass" + MARLIN = "marlin" # MXFP8 constants @@ -23,6 +25,28 @@ MXFP8_SCALE_DTYPE = torch.uint8 MXFP8_BLOCK_SIZE = 32 +def select_mxfp8_linear_backend() -> Mxfp8LinearBackend: + """Select the best MXFP8 linear backend for the current device. + + - SM100+ (Blackwell): FLASHINFER_CUTLASS (native MXFP8 W8A8 GEMM) + - SM80+ (Ampere/Ada): MARLIN (MXFP8 W8A16 GEMM) + - Otherwise: EMULATION (dequant to BF16 fallback) + """ + from vllm.platforms import current_platform + + if current_platform.has_device_capability(100): + return Mxfp8LinearBackend.FLASHINFER_CUTLASS + + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + is_fp8_marlin_supported, + ) + + if is_fp8_marlin_supported(): + return Mxfp8LinearBackend.MARLIN + + return Mxfp8LinearBackend.EMULATION + + def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor: """Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout.""" scaling_vector_size = MXFP8_BLOCK_SIZE # 32 for MXFP8 @@ -47,17 +71,71 @@ def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor: return sf_swizzled.contiguous().view(-1) +def _mxfp8_e4m3_quantize_torch( + x: torch.Tensor, + is_sf_swizzled_layout: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Naive MXFP8 quantization. + For each block of 32 elements along the last dimension, compute a + shared e8m0 scale (the biased exponent of the block-wise amax) + and quantize each element to float8_e4m3fn. + + Returns (quantized_values [same shape, fp8], scales uint8). + Scale shape depends on is_sf_swizzled_layout: + False -> [..., K//32] (row-major 2D) + True -> [flat swizzled 1D] + """ + assert x.shape[-1] % MXFP8_BLOCK_SIZE == 0 + orig_shape = x.shape + num_blocks = x.shape[-1] // MXFP8_BLOCK_SIZE + + x_fp32 = x.to(torch.float32) + x_blocked = x_fp32.view(*orig_shape[:-1], num_blocks, MXFP8_BLOCK_SIZE) + + amax = x_blocked.abs().amax(dim=-1) + amax = amax.clamp(min=torch.finfo(torch.float32).tiny) + scale_biased = torch.floor(torch.log2(amax)) + 127.0 + scale_biased = scale_biased.clamp(0, 254) + scales_uint8 = scale_biased.to(torch.uint8) + + descale = torch.exp2(scale_biased - 127.0) + x_scaled = x_blocked / descale.unsqueeze(-1) + + x_fp8 = x_scaled.view(orig_shape).to(MXFP8_VALUE_DTYPE) + + if x.ndim == 2: + M, K = x.shape + scales_uint8 = scales_uint8.view(M, -1) + if is_sf_swizzled_layout: + scales_uint8 = swizzle_mxfp8_scale(scales_uint8, M=M, K=K) + elif x.ndim == 3: + B, M, K = x.shape + scales_uint8 = scales_uint8.view(B, M, -1) + if is_sf_swizzled_layout: + swizzled = [] + for i in range(B): + swizzled.append(swizzle_mxfp8_scale(scales_uint8[i], M=M, K=K)) + scales_uint8 = torch.cat(swizzled) + + return x_fp8, scales_uint8 + + def _mxfp8_e4m3_quantize_impl( x: torch.Tensor, is_sf_swizzled_layout: bool = False ) -> tuple[torch.Tensor, torch.Tensor]: - from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize + from vllm.platforms import current_platform - x_q, x_scales = flashinfer_mxfp8_quantize( - x, is_sf_swizzled_layout=is_sf_swizzled_layout - ) - if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout: - x_scales = x_scales.view(x.size(0), -1) - return x_q, x_scales + if current_platform.has_device_capability(100): + from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize + + x_q, x_scales = flashinfer_mxfp8_quantize( + x, is_sf_swizzled_layout=is_sf_swizzled_layout + ) + if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout: + x_scales = x_scales.view(x.size(0), -1) + return x_q, x_scales + + return _mxfp8_e4m3_quantize_torch(x, is_sf_swizzled_layout) def mxfp8_e4m3_quantize( @@ -128,11 +206,51 @@ direct_register_custom_op( class Mxfp8LinearOp: - def __init__(self, backend: Mxfp8LinearBackend): - if backend not in Mxfp8LinearBackend: - raise ValueError(f"Unsupported backend: {backend}") + def __init__(self): + self.backend = select_mxfp8_linear_backend() + logger.info_once("Using %s backend for MXFP8 GEMM", self.backend) - self.backend = backend + def process_weights(self, layer: torch.nn.Module) -> None: + """Process MXFP8 weights after loading into backend-specific format.""" + if self.backend == Mxfp8LinearBackend.MARLIN: + self._process_weights_marlin(layer) + elif self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS: + self._process_weights_flashinfer_cutlass(layer) + else: + self._process_weights_emulation(layer) + + def _process_weights_emulation(self, layer: torch.nn.Module) -> None: + """Keep scales as 2D uint8 for dequant-to-BF16 emulation.""" + weight = layer.weight.data # [N, K] + N, K = weight.shape + scale_k = K // MXFP8_BLOCK_SIZE + + weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous() + + layer.weight = Parameter(weight.contiguous(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + def _process_weights_flashinfer_cutlass(self, layer: torch.nn.Module) -> None: + """Swizzle scales to F8_128x4 layout for flashinfer CUTLASS.""" + weight = layer.weight.data # [N, K] + N, K = weight.shape + + scale_k = K // MXFP8_BLOCK_SIZE + weight_scale_2d = layer.weight_scale.data[:N, :scale_k].contiguous() + weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K) + + layer.weight = Parameter(weight.contiguous(), requires_grad=False) + layer.weight_scale = Parameter( + weight_scale_swizzled.contiguous(), requires_grad=False + ) + + def _process_weights_marlin(self, layer: torch.nn.Module) -> None: + """Repack MXFP8 weights and scales into Marlin kernel format.""" + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + prepare_mxfp8_layer_for_marlin, + ) + + prepare_mxfp8_layer_for_marlin(layer) def _apply_emulation( self, @@ -142,7 +260,6 @@ class Mxfp8LinearOp: out_dtype: torch.dtype, bias: torch.Tensor | None = None, ) -> torch.Tensor: - # Validate weight_scale dtype and shape (must be 2D for TORCH backend) if weight_scale.dtype != MXFP8_SCALE_DTYPE: raise ValueError( f"TORCH backend requires {MXFP8_SCALE_DTYPE} weight_scale dtype, " @@ -219,6 +336,32 @@ class Mxfp8LinearOp: output_shape = (*input_shape[:-1], N) return output.view(output_shape) + def _apply_marlin( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: torch.dtype, + bias: torch.Tensor | None = None, + *, + workspace: torch.Tensor, + size_n: int, + size_k: int, + ) -> torch.Tensor: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_mxfp8_marlin_linear, + ) + + return apply_mxfp8_marlin_linear( + input=input, + weight=weight, + weight_scale=weight_scale, + workspace=workspace, + size_n=size_n, + size_k=size_k, + bias=bias, + ) + def apply( self, input: torch.Tensor, @@ -226,10 +369,27 @@ class Mxfp8LinearOp: weight_scale: torch.Tensor, out_dtype: torch.dtype, bias: torch.Tensor | None = None, + *, + workspace: torch.Tensor | None = None, + size_n: int = 0, + size_k: int = 0, ) -> torch.Tensor: if self.backend == Mxfp8LinearBackend.EMULATION: return self._apply_emulation(input, weight, weight_scale, out_dtype, bias) + if self.backend == Mxfp8LinearBackend.MARLIN: + assert workspace is not None + return self._apply_marlin( + input, + weight, + weight_scale, + out_dtype, + bias, + workspace=workspace, + size_n=size_n, + size_k=size_k, + ) + assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS return self._apply_flashinfer_cutlass( input, weight, weight_scale, out_dtype, bias