[ROCm][Quantization] GPT OSS Upstream MoE wmxfp4_afp8 with static scales (#30357)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
This commit is contained in:
Aleksandr Malyshev
2026-02-26 14:50:16 -08:00
committed by GitHub
parent 31fb6f43da
commit 01923eec70
3 changed files with 315 additions and 37 deletions

View File

@@ -6,6 +6,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
@@ -178,7 +179,40 @@ def triton_kernel_moe_forward(
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
unpadded_N_w1=None,
unpadded_K_w1=None,
unpadded_N_w2=None,
unpadded_K_w2=None,
) -> torch.Tensor:
if (
quant_config is not None
and quant_config.use_mxfp4_w4a8
and rocm_aiter_ops.is_enabled()
):
from aiter.ops.triton.moe_routing.routing import routing as aiter_routing
routing_data, gather_idx, scatter_idx = aiter_routing(
gating_output, topk, sm_first=not renormalize
)
return triton_kernel_fused_mxfp4_w4a8_experts(
None,
hidden_states,
w1,
w2,
routing_data,
gather_idx,
scatter_idx,
activation=activation.value,
quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
unpadded_N_w1=unpadded_N_w1,
unpadded_K_w1=unpadded_K_w1,
unpadded_N_w2=unpadded_N_w2,
unpadded_K_w2=unpadded_K_w2,
)
if expert_map is not None:
# With expert parallelism, legacy_routing produces routing data
# using global expert IDs which don't correspond to local weight
@@ -210,6 +244,9 @@ def triton_kernel_moe_forward(
effective_global_num_experts = global_num_experts
output = torch.empty_like(hidden_states)
effective_quant_config = (
quant_config if quant_config is not None else FUSED_MOE_UNQUANTIZED_CONFIG
)
return triton_kernel_fused_experts(
output,
@@ -221,7 +258,7 @@ def triton_kernel_moe_forward(
scatter_idx,
topk=topk,
activation=activation,
quant_config=quant_config,
quant_config=effective_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=effective_global_num_experts,
expert_map=effective_expert_map,
@@ -252,8 +289,7 @@ def triton_kernel_fused_experts(
assert activation == MoEActivation.SWIGLUOAI, (
"Only SWIGLUOAI activation is supported"
)
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
assert quant_config is not None
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
@@ -330,6 +366,98 @@ def triton_kernel_fused_experts(
return output_tensor
# This is a triton implementation of the fused_experts function
def triton_kernel_fused_mxfp4_w4a8_experts(
output_tensor: torch.Tensor,
hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor
w2, # Tensor or triton_kernels.Tensor
routing_data, # RoutingData
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
activation: str = "silu",
quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
a1q_scale: torch.Tensor | None = None,
unpadded_N_w1=None,
unpadded_K_w1=None,
unpadded_N_w2=None,
unpadded_K_w2=None,
) -> torch.Tensor:
assert quant_config is not None
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]
E, _, N = w1.shape
if global_num_experts == -1:
global_num_experts = E
gammas = routing_data.gate_scal if routing_data else None
from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4
from aiter.ops.triton.quant_moe import downcast_to_static_fp8
assert quant_config.w1_precision is not None, (
"w1_precision in quant config can't be None"
)
assert quant_config.w2_precision is not None, (
"w2_precision in quant config can't be None"
)
hidden_states = downcast_to_static_fp8(
hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale
)
intermediate_cache1 = moe_gemm_a8w4(
hidden_states,
w1.storage.data,
None,
quant_config.w1_precision.weight_scale.storage.data,
quant_config.w1_precision.flex_ctx.lhs_data.scale,
quant_config.w2_precision.flex_ctx.lhs_data.scale,
quant_config.w1_bias,
routing_data,
gather_indx=gather_indx,
gammas=gammas if apply_router_weight_on_input else None,
swizzle_mx_scale="CDNA4_SCALE",
out_dtype=torch.float8_e4m3fn,
apply_swiglu=True,
alpha=swiglu_alpha,
limit=swiglu_limit,
unpadded_N=unpadded_N_w1,
unpadded_K=unpadded_K_w1,
)
intermediate_cache3 = moe_gemm_a8w4(
intermediate_cache1,
w2.storage.data,
None,
quant_config.w2_precision.weight_scale.storage.data,
quant_config.w2_precision.flex_ctx.lhs_data.scale,
None,
quant_config.w2_bias,
routing_data,
scatter_indx=scatter_indx,
gammas=None if apply_router_weight_on_input else gammas,
swizzle_mx_scale="CDNA4_SCALE",
unpadded_N=unpadded_N_w2,
unpadded_K=unpadded_K_w2,
)
return intermediate_cache3
def make_routing_data(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
@@ -520,6 +648,9 @@ class OAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
if self.quant_config is None:
self.quant_config: FusedMoEQuantConfig = FUSED_MOE_UNQUANTIZED_CONFIG
if expert_map is not None:
topk_ids = expert_map[topk_ids]

View File

@@ -525,16 +525,18 @@ class FusedMoE(CustomOp):
# Round up hidden size before creating moe_config.
# This way moe_config is created with the correct hidden_size from the start.
unpadded_hidden_size = hidden_size
self.model_type = (
self.vllm_config.model_config.hf_config.model_type
if self.vllm_config.model_config is not None
else None
)
hidden_size = maybe_roundup_hidden_size(
hidden_size=hidden_size,
act_dtype=moe_in_dtype,
moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None,
model_type=(
self.vllm_config.model_config.hf_config.model_type
if self.vllm_config.model_config is not None
else None
),
model_type=self.model_type,
is_mxfp4_quant=(
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
),
@@ -610,6 +612,7 @@ class FusedMoE(CustomOp):
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
"unpadded_hidden_size": unpadded_hidden_size,
"intermediate_size_per_partition": self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,

View File

@@ -5,8 +5,8 @@ from typing import Any
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
@@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
__all__ = [
"QuarkMoEMethod",
"QuarkOCP_MX_MoEMethod",
"QuarkOCP_MX_MoEMethod_OSS",
]
class QuarkMoEMethod(FusedMoEMethodBase):
@@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase):
"output_tensors and bias "
"quantized are not supported"
)
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w4a8(weight_config, input_config):
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
emulate = not current_platform.supports_mx() or not (
rocm_aiter_ops.is_fused_moe_enabled()
)
if (
input_config.get("dtype") == "fp8_e4m3"
and not input_config.get("is_dynamic")
and not emulate
):
return QuarkOCP_MX_MoEMethod_OSS(
weight_config, input_config, module.moe_config
)
else:
return QuarkOCP_MX_MoEMethod(
weight_config, input_config, module.moe_config
)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
@@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
get_current_vllm_config().model_config.hf_config, "model_type", None
)
self._emulate = (
self.emulate = (
not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
self.emulate = True if self.model_type == "gpt_oss" else self._emulate
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
@@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
params_dtype = torch.uint8
self.intermediate_size_per_partition = intermediate_size_per_partition
if self.model_type == "gpt_oss":
if current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
@@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
else:
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
self.unpadded_hidden_size = extra_weight_attrs.get(
"unpadded_hidden_size", hidden_size
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
@@ -991,30 +1015,20 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate:
if (
self.model_type == "gpt_oss"
and self.mxfp4_backend == Mxfp4Backend.TRITON
):
raise NotImplementedError(
"Triton kernel implemented fused MoE for GPT_OSS model "
"in Quark(MoE) format is not integrated or provided yet."
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
else:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -1031,3 +1045,133 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
def __init__(
self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(weight_config, input_config, moe)
def process_weights_after_loading(self, layer):
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_bias = layer.w13_bias.to(torch.float32)
w2_bias = layer.w2_bias.to(torch.float32)
layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
# FIXME warp need to be adjusted based on batch size
# only apply to batched mode
if self.moe.use_ep:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps
)
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
# need to delete the original weights to save memory on single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
if self.static_input_scales:
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max().to(torch.float32), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max().to(torch.float32), requires_grad=False
)
from triton_kernels.numerics import InFlexData
lhs_data13 = InFlexData(scale=layer.w13_input_scale)
lhs_data2 = InFlexData(scale=layer.w2_input_scale)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale,
flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13),
)
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale,
flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2),
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return mxfp4_w4a8_moe_quant_config(
w1_scale=self.w13_precision_config,
w2_scale=self.w2_precision_config,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
block_shape=None,
)
@property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet."
)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
)
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w2=self.w2_weight_triton_tensor,
gating_output=router_logits,
topk=layer.top_k,
renormalize=layer.renormalize,
global_num_experts=layer.global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
unpadded_N_w1=self.intermediate_size_per_partition * 2,
unpadded_K_w1=self.unpadded_hidden_size,
unpadded_N_w2=self.unpadded_hidden_size,
unpadded_K_w2=self.intermediate_size_per_partition,
)