diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index 5617156bf..2fcb7f193 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -6,6 +6,7 @@ import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.fused_moe.config import (
@@ -178,7 +179,40 @@ def triton_kernel_moe_forward(
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
+    unpadded_N_w1=None,
+    unpadded_K_w1=None,
+    unpadded_N_w2=None,
+    unpadded_K_w2=None,
 ) -> torch.Tensor:
+    if (
+        quant_config is not None
+        and quant_config.use_mxfp4_w4a8
+        and rocm_aiter_ops.is_enabled()
+    ):
+        from aiter.ops.triton.moe_routing.routing import routing as aiter_routing
+
+        routing_data, gather_idx, scatter_idx = aiter_routing(
+            gating_output, topk, sm_first=not renormalize
+        )
+        return triton_kernel_fused_mxfp4_w4a8_experts(
+            None,
+            hidden_states,
+            w1,
+            w2,
+            routing_data,
+            gather_idx,
+            scatter_idx,
+            activation=activation.value,
+            quant_config=quant_config,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            unpadded_N_w1=unpadded_N_w1,
+            unpadded_K_w1=unpadded_K_w1,
+            unpadded_N_w2=unpadded_N_w2,
+            unpadded_K_w2=unpadded_K_w2,
+        )
+
     if expert_map is not None:
         # With expert parallelism, legacy_routing produces routing data
         # using global expert IDs which don't correspond to local weight
@@ -210,6 +244,9 @@ def triton_kernel_moe_forward(
         effective_global_num_experts = global_num_experts
 
     output = torch.empty_like(hidden_states)
+    effective_quant_config = (
+        quant_config if quant_config is not None else FUSED_MOE_UNQUANTIZED_CONFIG
+    )
 
     return triton_kernel_fused_experts(
         output,
@@ -221,7 +258,7 @@ def triton_kernel_moe_forward(
         scatter_idx,
         topk=topk,
         activation=activation,
-        quant_config=quant_config,
+        quant_config=effective_quant_config,
         apply_router_weight_on_input=apply_router_weight_on_input,
         global_num_experts=effective_global_num_experts,
         expert_map=effective_expert_map,
@@ -252,8 +289,7 @@ def triton_kernel_fused_experts(
     assert activation == MoEActivation.SWIGLUOAI, (
         "Only SWIGLUOAI activation is supported"
     )
-    if quant_config is None:
-        quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
+    assert quant_config is not None
 
     # type check, uint8 means mxfp4
     assert hidden_states.dtype == torch.bfloat16
@@ -330,6 +366,98 @@ def triton_kernel_fused_experts(
     return output_tensor
 
 
+# This is a triton implementation of the fused_experts function
+def triton_kernel_fused_mxfp4_w4a8_experts(
+    output_tensor: torch.Tensor,
+    hidden_states: torch.Tensor,
+    w1,  # Tensor or triton_kernels.Tensor
+    w2,  # Tensor or triton_kernels.Tensor
+    routing_data,  # RoutingData
+    gather_indx,  # GatherIndx
+    scatter_indx,  # ScatterIndx
+    activation: str = "silu",
+    quant_config: FusedMoEQuantConfig | None = None,
+    swiglu_alpha: float = 1.702,
+    swiglu_limit: float = 7.0,
+    apply_router_weight_on_input: bool = False,
+    global_num_experts: int = -1,
+    expert_map: torch.Tensor | None = None,
+    a1q_scale: torch.Tensor | None = None,
+    unpadded_N_w1=None,
+    unpadded_K_w1=None,
+    unpadded_N_w2=None,
+    unpadded_K_w2=None,
+) -> torch.Tensor:
+    assert quant_config is not None
+    # type check, uint8 means mxfp4
+    assert hidden_states.dtype == torch.bfloat16
+    assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
+    assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
+
+    # Shape check, only check non-mxfp4
+    assert hidden_states.shape[-1] == w1.shape[-2]
+    assert w2.shape[-1] == w1.shape[1]
+
+    E, _, N = w1.shape
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    gammas = routing_data.gate_scal if routing_data else None
+
+    from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4
+    from aiter.ops.triton.quant_moe import downcast_to_static_fp8
+
+    assert quant_config.w1_precision is not None, (
+        "w1_precision in quant config can't be None"
+    )
+    assert quant_config.w2_precision is not None, (
+        "w2_precision in quant config can't be None"
+    )
+
+    hidden_states = downcast_to_static_fp8(
+        hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale
+    )
+
+    intermediate_cache1 = moe_gemm_a8w4(
+        hidden_states,
+        w1.storage.data,
+        None,
+        quant_config.w1_precision.weight_scale.storage.data,
+        quant_config.w1_precision.flex_ctx.lhs_data.scale,
+        quant_config.w2_precision.flex_ctx.lhs_data.scale,
+        quant_config.w1_bias,
+        routing_data,
+        gather_indx=gather_indx,
+        gammas=gammas if apply_router_weight_on_input else None,
+        swizzle_mx_scale="CDNA4_SCALE",
+        out_dtype=torch.float8_e4m3fn,
+        apply_swiglu=True,
+        alpha=swiglu_alpha,
+        limit=swiglu_limit,
+        unpadded_N=unpadded_N_w1,
+        unpadded_K=unpadded_K_w1,
+    )
+
+    intermediate_cache3 = moe_gemm_a8w4(
+        intermediate_cache1,
+        w2.storage.data,
+        None,
+        quant_config.w2_precision.weight_scale.storage.data,
+        quant_config.w2_precision.flex_ctx.lhs_data.scale,
+        None,
+        quant_config.w2_bias,
+        routing_data,
+        scatter_indx=scatter_indx,
+        gammas=None if apply_router_weight_on_input else gammas,
+        swizzle_mx_scale="CDNA4_SCALE",
+        unpadded_N=unpadded_N_w2,
+        unpadded_K=unpadded_K_w2,
+    )
+
+    return intermediate_cache3
+
+
 def make_routing_data(
     topk_ids: torch.Tensor,
     topk_weights: torch.Tensor,
@@ -520,6 +648,9 @@ class OAITritonExperts(BaseOAITritonExperts):
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
         apply_router_weight_on_input: bool,
     ):
+        if self.quant_config is None:
+            self.quant_config: FusedMoEQuantConfig = FUSED_MOE_UNQUANTIZED_CONFIG
+
         if expert_map is not None:
             topk_ids = expert_map[topk_ids]
 
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 679b79ce9..a7dee7004 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -525,16 +525,18 @@ class FusedMoE(CustomOp):
 
         # Round up hidden size before creating moe_config.
         # This way moe_config is created with the correct hidden_size from the start.
+        unpadded_hidden_size = hidden_size
+        self.model_type = (
+            self.vllm_config.model_config.hf_config.model_type
+            if self.vllm_config.model_config is not None
+            else None
+        )
         hidden_size = maybe_roundup_hidden_size(
             hidden_size=hidden_size,
             act_dtype=moe_in_dtype,
             moe_parallel_config=self.moe_parallel_config,
             is_lora_enabled=vllm_config.lora_config is not None,
-            model_type=(
-                self.vllm_config.model_config.hf_config.model_type
-                if self.vllm_config.model_config is not None
-                else None
-            ),
+            model_type=self.model_type,
             is_mxfp4_quant=(
                 quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
             ),
@@ -610,6 +612,7 @@ class FusedMoE(CustomOp):
         moe_quant_params = {
             "num_experts": self.local_num_experts,
             "hidden_size": hidden_size,
+            "unpadded_hidden_size": unpadded_hidden_size,
             "intermediate_size_per_partition": self.intermediate_size_per_partition,
             "params_dtype": params_dtype,
             "weight_loader": self.weight_loader,
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 8394857cf..b2abbce1a 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -5,8 +5,8 @@ from typing import Any
 
 import torch
 
-import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_fp8_moe_layer_for_marlin,
 )
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
     OCP_MX_BLOCK_SIZE,
     OCP_MX_Scheme,
 )
@@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up
 
 logger = init_logger(__name__)
 
-__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
+__all__ = [
+    "QuarkMoEMethod",
+    "QuarkOCP_MX_MoEMethod",
+    "QuarkOCP_MX_MoEMethod_OSS",
+]
 
 
 class QuarkMoEMethod(FusedMoEMethodBase):
@@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase):
                 "output_tensors and bias "
                 "quantized are not supported"
             )
+
         weight_config = layer_quant_config.get("weight")
         input_config = layer_quant_config.get("input_tensors")
+
        if quant_config._is_fp8_w4a8(weight_config, input_config):
             return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
         elif quant_config._is_fp8_w8a8(weight_config, input_config):
             return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
         elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
-            return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
+            emulate = not current_platform.supports_mx() or not (
+                rocm_aiter_ops.is_fused_moe_enabled()
+            )
+            if (
+                input_config.get("dtype") == "fp8_e4m3"
+                and not input_config.get("is_dynamic")
+                and not emulate
+            ):
+                return QuarkOCP_MX_MoEMethod_OSS(
+                    weight_config, input_config, module.moe_config
+                )
+            else:
+                return QuarkOCP_MX_MoEMethod(
+                    weight_config, input_config, module.moe_config
+                )
         else:
             raise RuntimeError("Unsupported FusedMoe scheme")
 
@@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
 
             get_current_vllm_config().model_config.hf_config, "model_type", None
         )
 
-        self._emulate = (
+        self.emulate = (
             not current_platform.supports_mx()
             or not self.ocp_mx_scheme.startswith("w_mxfp4")
         ) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
-        self.emulate = True if self.model_type == "gpt_oss" else self._emulate
-
         if self.emulate:
             logger.warning_once(
                 f"The current mode (supports_mx={current_platform.supports_mx()}, "
@@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         )
 
         params_dtype = torch.uint8
+        self.intermediate_size_per_partition = intermediate_size_per_partition
         if self.model_type == "gpt_oss":
             if current_platform.is_rocm():
                 intermediate_size_per_partition_after_pad = round_up(
@@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         else:
             intermediate_size_per_partition_after_pad = intermediate_size_per_partition
 
+        self.unpadded_hidden_size = extra_weight_attrs.get(
+            "unpadded_hidden_size", hidden_size
+        )
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
@@ -991,30 +1015,20 @@
         shared_experts_input: torch.Tensor | None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if not self.emulate:
-            if (
-                self.model_type == "gpt_oss"
-                and self.mxfp4_backend == Mxfp4Backend.TRITON
-            ):
-                raise NotImplementedError(
-                    "Triton kernel implemented fused MoE for GPT_OSS model "
-                    "in Quark(MoE) format is not integrated or provided yet."
-                )
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+                rocm_aiter_fused_experts,
+            )
-            else:
-                from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-                    rocm_aiter_fused_experts,
-                )
-
-                return rocm_aiter_fused_experts(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights=topk_weights,
-                    topk_ids=topk_ids,
-                    activation=layer.activation,
-                    quant_config=self.moe_quant_config,
-                    expert_map=layer.expert_map,
-                )
+            return rocm_aiter_fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=layer.activation,
+                quant_config=self.moe_quant_config,
+                expert_map=layer.expert_map,
+            )
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
@@ -1031,3 +1045,133 @@
             expert_map=layer.expert_map,
             quant_config=self.moe_quant_config,
         )
+
+
+class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
+    def __init__(
+        self,
+        weight_config: dict[str, Any],
+        input_config: dict[str, Any],
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(weight_config, input_config, moe)
+
+    def process_weights_after_loading(self, layer):
+        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+
+        w13_bias = layer.w13_bias.to(torch.float32)
+        w2_bias = layer.w2_bias.to(torch.float32)
+
+        layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
+        layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
+
+        # FIXME warp need to be adjusted based on batch size
+        # only apply to batched mode
+        if self.moe.use_ep:
+            num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
+        else:
+            num_warps = 8
+
+        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
+            layer.w13_weight, layer.w13_weight_scale, num_warps
+        )
+        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
+            layer.w2_weight, layer.w2_weight_scale, num_warps
+        )
+
+        self.w13_weight_triton_tensor = w13_weight
+        self.w2_weight_triton_tensor = w2_weight
+
+        # need to delete the original weights to save memory on single GPU
+        del layer.w13_weight
+        del layer.w2_weight
+        layer.w13_weight = None
+        layer.w2_weight = None
+        torch.cuda.empty_cache()
+
+        if self.static_input_scales:
+            if layer.w13_input_scale is None or layer.w2_input_scale is None:
+                raise ValueError(
"QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer." + ) + + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max().to(torch.float32), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max().to(torch.float32), requires_grad=False + ) + + from triton_kernels.numerics import InFlexData + + lhs_data13 = InFlexData(scale=layer.w13_input_scale) + lhs_data2 = InFlexData(scale=layer.w2_input_scale) + + self.w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, + flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13), + ) + + self.w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, + flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2), + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return mxfp4_w4a8_moe_quant_config( + w1_scale=self.w13_precision_config, + w2_scale=self.w2_precision_config, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + block_shape=None, + ) + + @property + def is_monolithic(self) -> bool: + return True + + def apply_monolithic( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + expert_map: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet." + ) + + from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 + triton_kernel_moe_forward, + ) + + return triton_kernel_moe_forward( + hidden_states=x, + w1=self.w13_weight_triton_tensor, + w2=self.w2_weight_triton_tensor, + gating_output=router_logits, + topk=layer.top_k, + renormalize=layer.renormalize, + global_num_experts=layer.global_num_experts, + expert_map=expert_map, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + unpadded_N_w1=self.intermediate_size_per_partition * 2, + unpadded_K_w1=self.unpadded_hidden_size, + unpadded_N_w2=self.unpadded_hidden_size, + unpadded_K_w2=self.intermediate_size_per_partition, + )