[ROCm][Quantization] GPT OSS Upstream MoE wmxfp4_afp8 with static scales (#30357)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
This commit is contained in:
committed by
GitHub
parent
31fb6f43da
commit
01923eec70
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -178,7 +179,40 @@ def triton_kernel_moe_forward(
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
unpadded_N_w1=None,
|
||||
unpadded_K_w1=None,
|
||||
unpadded_N_w2=None,
|
||||
unpadded_K_w2=None,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
quant_config is not None
|
||||
and quant_config.use_mxfp4_w4a8
|
||||
and rocm_aiter_ops.is_enabled()
|
||||
):
|
||||
from aiter.ops.triton.moe_routing.routing import routing as aiter_routing
|
||||
|
||||
routing_data, gather_idx, scatter_idx = aiter_routing(
|
||||
gating_output, topk, sm_first=not renormalize
|
||||
)
|
||||
return triton_kernel_fused_mxfp4_w4a8_experts(
|
||||
None,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
routing_data,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
activation=activation.value,
|
||||
quant_config=quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
unpadded_N_w1=unpadded_N_w1,
|
||||
unpadded_K_w1=unpadded_K_w1,
|
||||
unpadded_N_w2=unpadded_N_w2,
|
||||
unpadded_K_w2=unpadded_K_w2,
|
||||
)
|
||||
|
||||
if expert_map is not None:
|
||||
# With expert parallelism, legacy_routing produces routing data
|
||||
# using global expert IDs which don't correspond to local weight
|
||||
@@ -210,6 +244,9 @@ def triton_kernel_moe_forward(
|
||||
effective_global_num_experts = global_num_experts
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
effective_quant_config = (
|
||||
quant_config if quant_config is not None else FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
)
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
output,
|
||||
@@ -221,7 +258,7 @@ def triton_kernel_moe_forward(
|
||||
scatter_idx,
|
||||
topk=topk,
|
||||
activation=activation,
|
||||
quant_config=quant_config,
|
||||
quant_config=effective_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=effective_global_num_experts,
|
||||
expert_map=effective_expert_map,
|
||||
@@ -252,8 +289,7 @@ def triton_kernel_fused_experts(
|
||||
assert activation == MoEActivation.SWIGLUOAI, (
|
||||
"Only SWIGLUOAI activation is supported"
|
||||
)
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
assert quant_config is not None
|
||||
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
@@ -330,6 +366,98 @@ def triton_kernel_fused_experts(
|
||||
return output_tensor
|
||||
|
||||
|
||||
# Fused-experts path for MXFP4 weights with statically scaled FP8 activations,
# built on the AITER Triton `moe_gemm_a8w4` kernel.
def triton_kernel_fused_mxfp4_w4a8_experts(
    output_tensor: torch.Tensor,
    hidden_states: torch.Tensor,
    w1,  # Tensor or triton_kernels.Tensor
    w2,  # Tensor or triton_kernels.Tensor
    routing_data,  # RoutingData
    gather_indx,  # GatherIndx
    scatter_indx,  # ScatterIndx
    activation: str = "silu",
    quant_config: FusedMoEQuantConfig | None = None,
    swiglu_alpha: float = 1.702,
    swiglu_limit: float = 7.0,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    a1q_scale: torch.Tensor | None = None,
    unpadded_N_w1=None,
    unpadded_K_w1=None,
    unpadded_N_w2=None,
    unpadded_K_w2=None,
) -> torch.Tensor:
    """Run both expert GEMMs of a fused MoE layer with FP8 activations.

    The bf16 ``hidden_states`` are first downcast to FP8 using the static
    activation scale stored in ``quant_config.w1_precision``. The first
    ``moe_gemm_a8w4`` call gathers tokens via ``gather_indx``, applies SwiGLU
    (``swiglu_alpha`` / ``swiglu_limit``) and emits ``float8_e4m3fn``; the
    second call consumes that result and scatters it via ``scatter_indx``.

    Router weights (``routing_data.gate_scal``) are handed to the first GEMM
    when ``apply_router_weight_on_input`` is true, otherwise to the second.

    ``output_tensor``, ``activation``, ``expert_map`` and ``a1q_scale`` are
    accepted but not read by this function; ``global_num_experts`` is only
    defaulted from ``w1.shape[0]`` and not used further.

    Args:
        hidden_states: bf16 activations (asserted).
        w1, w2: expert weights; ``.storage.data`` is passed to the kernel.
        quant_config: must be non-None and carry ``w1_precision`` and
            ``w2_precision``; biases, if present, must be float32.
        unpadded_N_w1, unpadded_K_w1, unpadded_N_w2, unpadded_K_w2: forwarded
            to the kernel as ``unpadded_N`` / ``unpadded_K`` of each GEMM.

    Returns:
        The output of the second ``moe_gemm_a8w4`` call.
    """
    assert quant_config is not None

    # type check, uint8 means mxfp4
    assert hidden_states.dtype == torch.bfloat16
    w1_bias = quant_config.w1_bias
    assert w1_bias is None or w1_bias.dtype == torch.float32
    w2_bias = quant_config.w2_bias
    assert w2_bias is None or w2_bias.dtype == torch.float32

    # Shape check, only check non-mxfp4
    assert hidden_states.shape[-1] == w1.shape[-2]
    assert w2.shape[-1] == w1.shape[1]

    # Three-way unpack doubles as a rank check on w1.
    num_experts_in_w1, _, _ = w1.shape
    if global_num_experts == -1:
        global_num_experts = num_experts_in_w1

    if routing_data:
        gammas = routing_data.gate_scal
    else:
        gammas = None

    from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4
    from aiter.ops.triton.quant_moe import downcast_to_static_fp8

    w1_precision = quant_config.w1_precision
    assert w1_precision is not None, "w1_precision in quant config can't be None"
    w2_precision = quant_config.w2_precision
    assert w2_precision is not None, "w2_precision in quant config can't be None"

    # Static activation scale for the first GEMM's input.
    w1_input_scale = w1_precision.flex_ctx.lhs_data.scale
    fp8_hidden_states = downcast_to_static_fp8(hidden_states, w1_input_scale)

    # Static activation scale for the second GEMM's input; it is also handed
    # to the first GEMM, which produces that input in FP8.
    w2_input_scale = w2_precision.flex_ctx.lhs_data.scale

    # Router weights go to exactly one of the two GEMMs.
    if apply_router_weight_on_input:
        first_gemm_gammas, second_gemm_gammas = gammas, None
    else:
        first_gemm_gammas, second_gemm_gammas = None, gammas

    swiglu_out = moe_gemm_a8w4(
        fp8_hidden_states,
        w1.storage.data,
        None,
        w1_precision.weight_scale.storage.data,
        w1_input_scale,
        w2_input_scale,
        w1_bias,
        routing_data,
        gather_indx=gather_indx,
        gammas=first_gemm_gammas,
        swizzle_mx_scale="CDNA4_SCALE",
        out_dtype=torch.float8_e4m3fn,
        apply_swiglu=True,
        alpha=swiglu_alpha,
        limit=swiglu_limit,
        unpadded_N=unpadded_N_w1,
        unpadded_K=unpadded_K_w1,
    )

    return moe_gemm_a8w4(
        swiglu_out,
        w2.storage.data,
        None,
        w2_precision.weight_scale.storage.data,
        w2_input_scale,
        None,
        w2_bias,
        routing_data,
        scatter_indx=scatter_indx,
        gammas=second_gemm_gammas,
        swizzle_mx_scale="CDNA4_SCALE",
        unpadded_N=unpadded_N_w2,
        unpadded_K=unpadded_K_w2,
    )
|
||||
|
||||
|
||||
def make_routing_data(
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
@@ -520,6 +648,9 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
if self.quant_config is None:
|
||||
self.quant_config: FusedMoEQuantConfig = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
|
||||
|
||||
@@ -525,16 +525,18 @@ class FusedMoE(CustomOp):
|
||||
|
||||
# Round up hidden size before creating moe_config.
|
||||
# This way moe_config is created with the correct hidden_size from the start.
|
||||
unpadded_hidden_size = hidden_size
|
||||
self.model_type = (
|
||||
self.vllm_config.model_config.hf_config.model_type
|
||||
if self.vllm_config.model_config is not None
|
||||
else None
|
||||
)
|
||||
hidden_size = maybe_roundup_hidden_size(
|
||||
hidden_size=hidden_size,
|
||||
act_dtype=moe_in_dtype,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
is_lora_enabled=vllm_config.lora_config is not None,
|
||||
model_type=(
|
||||
self.vllm_config.model_config.hf_config.model_type
|
||||
if self.vllm_config.model_config is not None
|
||||
else None
|
||||
),
|
||||
model_type=self.model_type,
|
||||
is_mxfp4_quant=(
|
||||
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
|
||||
),
|
||||
@@ -610,6 +612,7 @@ class FusedMoE(CustomOp):
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"unpadded_hidden_size": unpadded_hidden_size,
|
||||
"intermediate_size_per_partition": self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_fp8_moe_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
@@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
|
||||
__all__ = [
|
||||
"QuarkMoEMethod",
|
||||
"QuarkOCP_MX_MoEMethod",
|
||||
"QuarkOCP_MX_MoEMethod_OSS",
|
||||
]
|
||||
|
||||
|
||||
class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
@@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
"output_tensors and bias "
|
||||
"quantized are not supported"
|
||||
)
|
||||
|
||||
weight_config = layer_quant_config.get("weight")
|
||||
input_config = layer_quant_config.get("input_tensors")
|
||||
|
||||
if quant_config._is_fp8_w4a8(weight_config, input_config):
|
||||
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
|
||||
emulate = not current_platform.supports_mx() or not (
|
||||
rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
if (
|
||||
input_config.get("dtype") == "fp8_e4m3"
|
||||
and not input_config.get("is_dynamic")
|
||||
and not emulate
|
||||
):
|
||||
return QuarkOCP_MX_MoEMethod_OSS(
|
||||
weight_config, input_config, module.moe_config
|
||||
)
|
||||
else:
|
||||
return QuarkOCP_MX_MoEMethod(
|
||||
weight_config, input_config, module.moe_config
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Unsupported FusedMoe scheme")
|
||||
|
||||
@@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
get_current_vllm_config().model_config.hf_config, "model_type", None
|
||||
)
|
||||
|
||||
self._emulate = (
|
||||
self.emulate = (
|
||||
not current_platform.supports_mx()
|
||||
or not self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
|
||||
|
||||
self.emulate = True if self.model_type == "gpt_oss" else self._emulate
|
||||
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
@@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
)
|
||||
|
||||
params_dtype = torch.uint8
|
||||
self.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
if self.model_type == "gpt_oss":
|
||||
if current_platform.is_rocm():
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
@@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
|
||||
|
||||
self.unpadded_hidden_size = extra_weight_attrs.get(
|
||||
"unpadded_hidden_size", hidden_size
|
||||
)
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
@@ -991,30 +1015,20 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if not self.emulate:
|
||||
if (
|
||||
self.model_type == "gpt_oss"
|
||||
and self.mxfp4_backend == Mxfp4Backend.TRITON
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Triton kernel implemented fused MoE for GPT_OSS model "
|
||||
"in Quark(MoE) format is not integrated or provided yet."
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@@ -1031,3 +1045,133 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
    """Quark OCP-MX MoE method that runs through the monolithic Triton path.

    After weight loading, the MXFP4 expert weights are swizzled into
    triton_kernels tensors held on this object, and the per-layer input
    scales are collapsed to a single scalar each. The forward pass is
    delegated to ``triton_kernel_moe_forward``.
    """

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(weight_config, input_config, moe)

    def process_weights_after_loading(self, layer):
        """Swizzle weights, reduce input scales and build precision configs.

        Side effects on ``layer``: ``w13_bias`` / ``w2_bias`` are replaced by
        float32 parameters, ``w13_weight`` / ``w2_weight`` are set to ``None``,
        and (when static input scales are in use) ``w13_input_scale`` /
        ``w2_input_scale`` are replaced by their maximum as float32 scalars.

        Side effects on ``self``: stores the swizzled weights as
        ``w13_weight_triton_tensor`` / ``w2_weight_triton_tensor`` and the
        ``PrecisionConfig`` objects as ``w13_precision_config`` /
        ``w2_precision_config``.

        Raises:
            ValueError: if static input scales are expected but either
                activation scale on ``layer`` is ``None``.
        """
        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        w13_bias = layer.w13_bias.to(torch.float32)
        w2_bias = layer.w2_bias.to(torch.float32)

        layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
        layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)

        # FIXME warp need to be adjusted based on batch size
        # only apply to batched mode
        if self.moe.use_ep:
            num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
        else:
            num_warps = 8

        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            layer.w13_weight, layer.w13_weight_scale, num_warps
        )
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            layer.w2_weight, layer.w2_weight_scale, num_warps
        )

        self.w13_weight_triton_tensor = w13_weight
        self.w2_weight_triton_tensor = w2_weight

        # need to delete the original weights to save memory on single GPU
        del layer.w13_weight
        del layer.w2_weight
        # Re-create the attributes as None so later attribute access on the
        # layer does not fail.
        layer.w13_weight = None
        layer.w2_weight = None
        torch.cuda.empty_cache()

        if self.static_input_scales:
            if layer.w13_input_scale is None or layer.w2_input_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )
            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                layer.w2_input_scale
            ):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer."
                )

            # Collapse the scales to one scalar per projection.
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max().to(torch.float32), requires_grad=False
            )
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max().to(torch.float32), requires_grad=False
            )

        from triton_kernels.numerics import InFlexData

        # The activation scales travel with the precision configs as the
        # left-hand-side flex data of each GEMM.
        lhs_data13 = InFlexData(scale=layer.w13_input_scale)
        lhs_data2 = InFlexData(scale=layer.w2_input_scale)

        self.w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale,
            flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13),
        )

        self.w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale,
            flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2),
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the MXFP4-W4A8 quant config from the stored precision configs.

        The precision configs created in ``process_weights_after_loading`` are
        passed as the weight "scales"; activation scales and biases are read
        from ``layer``.
        """
        return mxfp4_w4a8_moe_quant_config(
            w1_scale=self.w13_precision_config,
            w2_scale=self.w2_precision_config,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            w1_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            block_shape=None,
        )

    @property
    def is_monolithic(self) -> bool:
        """Always ``True`` for this method."""
        return True

    def apply_monolithic(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        expert_map: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE layer through ``triton_kernel_moe_forward``.

        Args:
            layer: the MoE layer; ``top_k``, ``renormalize``,
                ``global_num_experts``, ``apply_router_weight_on_input`` and
                ``enable_eplb`` are read from it.
            x: input hidden states.
            router_logits: gating output passed on to the kernel wrapper.
            expert_map: optional expert map forwarded unchanged.

        Raises:
            NotImplementedError: if ``layer.enable_eplb`` is set.
        """
        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `QuarkOCP_MX_MoEMethod_OSS` yet."
            )

        from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
            triton_kernel_moe_forward,
        )

        # The unpadded sizes are passed alongside the (possibly padded)
        # swizzled weights; w1 produces 2x the intermediate size per partition.
        return triton_kernel_moe_forward(
            hidden_states=x,
            w1=self.w13_weight_triton_tensor,
            w2=self.w2_weight_triton_tensor,
            gating_output=router_logits,
            topk=layer.top_k,
            renormalize=layer.renormalize,
            global_num_experts=layer.global_num_experts,
            expert_map=expert_map,
            quant_config=self.moe_quant_config,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            unpadded_N_w1=self.intermediate_size_per_partition * 2,
            unpadded_K_w1=self.unpadded_hidden_size,
            unpadded_N_w2=self.unpadded_hidden_size,
            unpadded_K_w2=self.intermediate_size_per_partition,
        )
|
||||
|
||||
Reference in New Issue
Block a user