[gpt-oss] triton kernel mxfp4 (#22421)
Signed-off-by: <zyy1102000@gmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
@@ -8,16 +8,19 @@ from torch.nn.parameter import Parameter
|
||||
from vllm import envs
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
triton_kernel_moe_forward)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
_can_support_mxfp4)
|
||||
_can_support_mxfp4, _swizzle_mxfp4)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import next_power_of_2, round_up
|
||||
|
||||
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
@@ -39,7 +42,7 @@ class Mxfp4Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
@@ -100,11 +103,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
intermediate_size_per_partition
|
||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||
# for to hold non-uniform sharded tensor as well as swizzling
|
||||
# other padding to increase performance
|
||||
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 256)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
elif current_platform.is_rocm():
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 128)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 64)
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
self.hidden_size = hidden_size
|
||||
@@ -303,7 +313,41 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
|
||||
self.num_experts, -1),
|
||||
requires_grad=False)
|
||||
return
|
||||
else:
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
w13_bias = layer.w13_bias.to(torch.float32)
|
||||
w2_bias = layer.w2_bias.to(torch.float32)
|
||||
|
||||
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
|
||||
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
|
||||
|
||||
# FIXME warp need to be adjusted based on batch size
|
||||
# only apply to batched mode
|
||||
if self.moe.use_ep:
|
||||
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
|
||||
else:
|
||||
num_warps = 8
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps)
|
||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||
layer.w2_weight, layer.w2_weight_scale, num_warps)
|
||||
|
||||
self.w13_precision_config = PrecisionConfig(
|
||||
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))
|
||||
|
||||
self.w13_weight_triton_tensor = w13_weight
|
||||
self.w2_weight_triton_tensor = w2_weight
|
||||
|
||||
# need to delete the original weights to save memory on single GPU
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
layer.w13_weight = None
|
||||
layer.w2_weight = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
|
||||
# Number of tokens in the input tensor.
|
||||
@@ -404,3 +448,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
True, # do finalize
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
else:
|
||||
return triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=self.w13_weight_triton_tensor,
|
||||
w2=self.w2_weight_triton_tensor,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_precision=self.w13_precision_config,
|
||||
w2_precision=self.w2_precision_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@@ -4,11 +4,55 @@ from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
OCP_MX_BLOCK_SIZE = 32
|
||||
|
||||
|
||||
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||
""" weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
|
||||
"""
|
||||
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.tensor_details.layout import StridedLayout
|
||||
if (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)
|
||||
and not is_torch_equal_or_newer("2.8.1")):
|
||||
logger.warning_once(
|
||||
"Mxfp4 on hopper is running on torch < 2.8.1, "
|
||||
"this cause swizling to be disabled, which may "
|
||||
"cause performance degradation. Please upgrade to torch nightly")
|
||||
value_layout, value_layout_opts = StridedLayout, dict()
|
||||
scale_layout, scale_layout_opts = StridedLayout, dict()
|
||||
else:
|
||||
value_layout, value_layout_opts = \
|
||||
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
scale_layout, scale_layout_opts = (
|
||||
layout.make_default_matmul_mxfp4_w_scale_layout(
|
||||
mx_axis=1, num_warps=num_warps))
|
||||
if current_platform.is_cuda() and \
|
||||
current_platform.is_device_capability(100):
|
||||
constraints = {
|
||||
"is_persistent": True,
|
||||
"epilogue_subtile": 1,
|
||||
}
|
||||
opt_flags.update_opt_flags_constraints(constraints)
|
||||
# transpose the tensor so that the quantization axis is on dim1
|
||||
quant_tensor = quant_tensor.transpose(-2, -1)
|
||||
scale = scale.transpose(-2, -1)
|
||||
quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
|
||||
value_layout, **value_layout_opts)
|
||||
scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
|
||||
**scale_layout_opts)
|
||||
return quant_tensor, InFlexData(), scale
|
||||
|
||||
|
||||
def _can_support_mxfp4(use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
|
||||
Reference in New Issue
Block a user