[gpt-oss] triton kernel mxfp4 (#22421)

Signed-off-by: <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Yongye Zhu
2025-08-08 08:24:07 -07:00
committed by GitHub
parent e5ebeeba53
commit e789cad6b8
8 changed files with 755 additions and 9 deletions

View File

@@ -8,16 +8,19 @@ from torch.nn.parameter import Parameter
from vllm import envs
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
triton_kernel_moe_forward)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
_can_support_mxfp4)
_can_support_mxfp4, _swizzle_mxfp4)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import next_power_of_2, round_up
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
@@ -39,7 +42,7 @@ class Mxfp4Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 100
return 90
@classmethod
def get_name(cls) -> QuantizationMethods:
@@ -100,11 +103,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition
# pad the intermediate size to be a multiple of 2 * mxfp4_block
# so it can hold non-uniformly sharded tensors and support swizzling;
# any further padding is there to improve performance
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
hidden_size = round_up(hidden_size, 256)
elif current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64)
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
@@ -303,7 +313,41 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
self.num_experts, -1),
requires_grad=False)
return
else:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_bias = layer.w13_bias.to(torch.float32)
w2_bias = layer.w2_bias.to(torch.float32)
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# FIXME: the number of warps needs to be adjusted based on batch size
# (this only applies to batched mode)
if self.moe.use_ep:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
# delete the original weights to save memory on a single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
# Number of tokens in the input tensor.
@@ -404,3 +448,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
True, # do finalize
)[0]
return trtllm_gen_output
else:
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w2=self.w2_weight_triton_tensor,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_precision=self.w13_precision_config,
w2_precision=self.w2_precision_config,
apply_router_weight_on_input=apply_router_weight_on_input,
)