Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -9,32 +9,45 @@ from torch.nn.parameter import Parameter
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
mxfp4_w4a16_moe_quant_config)
FusedMoEQuantConfig,
mxfp4_w4a4_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts)
OAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin)
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
_can_support_mxfp4, _swizzle_mxfp4)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
_can_support_mxfp4,
_swizzle_mxfp4,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
next_power_of_2, round_up)
from vllm.utils import (
has_triton_kernels,
is_torch_equal_or_newer,
next_power_of_2,
round_up,
)
from vllm.utils.flashinfer import has_flashinfer
logger = init_logger(__name__)
@@ -60,42 +73,57 @@ class Mxfp4Backend(Enum):
def get_mxfp4_backend():
# Backend Selection
if current_platform.is_cuda():
if (current_platform.is_device_capability(90) and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
if (
current_platform.is_device_capability(90)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
):
logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
return Mxfp4Backend.SM90_FI_MXFP4_BF16
elif (current_platform.is_device_capability(100) and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS):
logger.info_once(
"Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
elif (
current_platform.is_device_capability(100)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
):
logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
elif (current_platform.is_device_capability(100) and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
elif (
current_platform.is_device_capability(100)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
):
logger.info_once(
"Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, "
"for high concurrency throughput workloads consider setting "
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
"performance")
"performance"
)
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
elif current_platform.is_device_capability(100) and has_flashinfer():
logger.info_once(
"Using FlashInfer MXFP4 BF16 backend for SM100, "
"For faster performance on SM100, consider setting "
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
"accuracy.")
"accuracy."
)
return Mxfp4Backend.SM100_FI_MXFP4_BF16
elif ((current_platform.is_device_capability(100)
or current_platform.is_device_capability(90))
and not has_flashinfer()):
elif (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(90)
) and not has_flashinfer():
logger.warning_once(
"MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
"is not available. This may result in degraded performance. "
"Please `pip install vllm[flashinfer]` for best results.")
"Please `pip install vllm[flashinfer]` for best results."
)
# If FlashInfer is not available, try either Marlin or Triton
if envs.VLLM_MXFP4_USE_MARLIN or current_platform.get_device_capability(
)[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
"2.8.0"):
if (
envs.VLLM_MXFP4_USE_MARLIN
or current_platform.get_device_capability()[0] < 9
or not has_triton_kernels()
or not is_torch_equal_or_newer("2.8.0")
):
logger.info_once("Using Marlin backend")
return Mxfp4Backend.MARLIN
else:
@@ -109,7 +137,6 @@ def get_mxfp4_backend():
class Mxfp4Config(QuantizationConfig):
def __init__(self, ignored_layers: Optional[list[str]] = None):
super().__init__()
self.ignored_layers = ignored_layers
@@ -134,43 +161,51 @@ class Mxfp4Config(QuantizationConfig):
def get_config_filenames(cls) -> list[str]:
return []
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping):
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
raise NotImplementedError("Mxfp4 linear layer is not implemented")
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
raise NotImplementedError(
"Mxfp4 attention layer is not implemented")
raise NotImplementedError("Mxfp4 attention layer is not implemented")
return None
class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.topk_indices_dtype = None
self.moe = moe
self.mxfp4_backend = get_mxfp4_backend()
self.max_capture_size = get_current_vllm_config(
).compilation_config.max_capture_size
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_capture_size
)
assert self.mxfp4_backend != Mxfp4Backend.NONE, (
"No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available."
"Please check your environment and try again.")
"Please check your environment and try again."
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
self.num_experts = num_experts
weight_dtype = torch.uint8
scale_dtype = torch.uint8
@@ -185,8 +220,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
mxfp4_block = 32
intermediate_size_per_partition_after_pad = \
intermediate_size_per_partition
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
# The moe marlin kernel requires that for each linear
# n % 256 == 0 and k % 128 == 0.
@@ -197,34 +231,44 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# n = hidden_size
# k = intermediate_size_per_partition_after_pad
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128)
intermediate_size_per_partition, 128
)
hidden_size = round_up(hidden_size, 256)
layer.params_dtype = params_dtype
layer.num_experts = num_experts
layer.hidden_size = hidden_size
layer.intermediate_size_per_partition = \
layer.intermediate_size_per_partition = (
intermediate_size_per_partition_after_pad
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
# pad the intermediate size to be a multiple of 2 * mxfp4_block
# for to hold non-uniform sharded tensor as well as swizzling
# other padding to increase performance
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
intermediate_size_per_partition, 256
)
hidden_size = round_up(hidden_size, 256)
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128)
intermediate_size_per_partition, 128
)
hidden_size = round_up(hidden_size, 128)
elif current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
intermediate_size_per_partition, 256
)
hidden_size = round_up(hidden_size, 256)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64)
intermediate_size_per_partition, 64
)
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
@@ -303,47 +347,61 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(layer)
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
from flashinfer.fp4_quantization import (
nvfp4_block_scale_interleave)
from flashinfer.fused_moe.core import (
_maybe_get_cached_w2_permute_indices)
layer.gemm1_alpha = Parameter(torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
layer.gemm1_beta = Parameter(torch.tensor(
[1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
layer.gemm1_clamp_limit = Parameter(torch.tensor(
[7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_beta = Parameter(
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_clamp_limit = Parameter(
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
sf_block_size = 32 # mxfp4 block size
assert (layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2)
assert (layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1]
== self.intermediate_size * 2
and layer.w13_weight_scale.shape[2]
== self.hidden_size // sf_block_size)
assert (layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size and
layer.w2_weight.shape[2] == self.intermediate_size // 2)
assert (layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size)
assert (layer.w13_bias.dim() == 2
and layer.w13_bias.shape[0] == self.num_experts
and layer.w13_bias.shape[1] == self.intermediate_size * 2)
assert (layer.w2_bias.dim() == 2
and layer.w2_bias.shape[0] == self.num_experts
and layer.w2_bias.shape[1] == self.hidden_size)
assert (
layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2
)
assert (
layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
)
assert (
layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size
and layer.w2_weight.shape[2] == self.intermediate_size // 2
)
assert (
layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size
)
assert (
layer.w13_bias.dim() == 2
and layer.w13_bias.shape[0] == self.num_experts
and layer.w13_bias.shape[1] == self.intermediate_size * 2
)
assert (
layer.w2_bias.dim() == 2
and layer.w2_bias.shape[0] == self.num_experts
and layer.w2_bias.shape[1] == self.hidden_size
)
w13_weight_scale = layer.w13_weight_scale.data
w2_weight_scale = layer.w2_weight_scale.data
@@ -391,9 +449,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_mxfp4_shuffled.append(w13_weight[i].view(
torch.uint8)[permute_indices.to(
w13_weight.device)].contiguous())
gemm1_weights_mxfp4_shuffled.append(
w13_weight[i]
.view(torch.uint8)[permute_indices.to(w13_weight.device)]
.contiguous()
)
# w13 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
@@ -402,27 +462,37 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
num_elts_per_sf=16,
)
gemm1_scales_mxfp4_shuffled.append(
nvfp4_block_scale_interleave(w13_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w13_weight_scale.device)].contiguous()))
nvfp4_block_scale_interleave(
w13_weight_scale[i]
.view(torch.uint8)[
permute_sf_indices.to(w13_weight_scale.device)
]
.contiguous()
)
)
# w13 bias shuffling
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
-1,
1)[permute_bias_indices.to(w13_bias.device)].contiguous())
gemm1_bias_shuffled.append(
w13_bias[i]
.clone()
.reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
.contiguous()
)
# w2 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_mxfp4_shuffled.append(w2_weight[i].view(
torch.uint8)[permute_indices.to(
w2_weight.device)].contiguous())
gemm2_weights_mxfp4_shuffled.append(
w2_weight[i]
.view(torch.uint8)[permute_indices.to(w2_weight.device)]
.contiguous()
)
# w2 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
@@ -431,81 +501,115 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
num_elts_per_sf=16,
)
gemm2_scales_mxfp4_shuffled.append(
nvfp4_block_scale_interleave(w2_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w2_weight_scale.device)].contiguous()))
nvfp4_block_scale_interleave(
w2_weight_scale[i]
.view(torch.uint8)[
permute_sf_indices.to(w2_weight_scale.device)
]
.contiguous()
)
)
# w2 bias shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
-1, 1)[permute_indices.to(w2_bias.device)].contiguous())
gemm2_bias_shuffled.append(
w2_bias[i]
.clone()
.reshape(-1, 1)[permute_indices.to(w2_bias.device)]
.contiguous()
)
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
w13_weight_scale = torch.stack(
gemm1_scales_mxfp4_shuffled).reshape(
self.num_experts, 2 * self.intermediate_size,
self.hidden_size // sf_block_size).view(
torch.float8_e4m3fn)
w13_weight_scale = (
torch.stack(gemm1_scales_mxfp4_shuffled)
.reshape(
self.num_experts,
2 * self.intermediate_size,
self.hidden_size // sf_block_size,
)
.view(torch.float8_e4m3fn)
)
w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape(
self.num_experts, self.hidden_size, self.intermediate_size //
sf_block_size).view(torch.float8_e4m3fn)
w2_weight_scale = (
torch.stack(gemm2_scales_mxfp4_shuffled)
.reshape(
self.num_experts,
self.hidden_size,
self.intermediate_size // sf_block_size,
)
.view(torch.float8_e4m3fn)
)
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(w13_weight_scale,
requires_grad=False)
layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = Parameter(w2_weight_scale,
requires_grad=False)
layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
layer.w13_bias = Parameter(
torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
requires_grad=False)
layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
self.num_experts, -1),
requires_grad=False)
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
layer.gemm1_alpha = Parameter(torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
layer.gemm1_beta = Parameter(torch.tensor(
[1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
layer.gemm1_clamp_limit = Parameter(torch.tensor(
[7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
requires_grad=False,
)
layer.w2_bias = Parameter(
torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
requires_grad=False,
)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
):
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_beta = Parameter(
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_clamp_limit = Parameter(
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
sf_block_size = 32 # mxfp4 block size
# Common shape assertions
assert (layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2)
assert (layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1]
== self.intermediate_size * 2
and layer.w13_weight_scale.shape[2]
== self.hidden_size // sf_block_size)
assert (layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size and
layer.w2_weight.shape[2] == self.intermediate_size // 2)
assert (layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size)
assert (layer.w13_bias.dim() == 2
and layer.w13_bias.shape[0] == self.num_experts
and layer.w13_bias.shape[1] == self.intermediate_size * 2)
assert (layer.w2_bias.dim() == 2
and layer.w2_bias.shape[0] == self.num_experts
and layer.w2_bias.shape[1] == self.hidden_size)
assert (
layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2
)
assert (
layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
)
assert (
layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size
and layer.w2_weight.shape[2] == self.intermediate_size // 2
)
assert (
layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size
)
assert (
layer.w13_bias.dim() == 2
and layer.w13_bias.shape[0] == self.num_experts
and layer.w13_bias.shape[1] == self.intermediate_size * 2
)
assert (
layer.w2_bias.dim() == 2
and layer.w2_bias.shape[0] == self.num_experts
and layer.w2_bias.shape[1] == self.hidden_size
)
# De-interleave and swap for w13 weight, bias, and scales
w13_w = layer.w13_weight.data
@@ -531,51 +635,55 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
orig_shape = w13_scale_swapped.shape
w13_scale_interleaved = block_scale_interleave(
w13_scale_swapped.view(torch.uint8)).reshape(orig_shape)
w13_scale_swapped.view(torch.uint8)
).reshape(orig_shape)
w2_s = layer.w2_weight_scale.data
orig_shape = w2_s.shape
w2_scale_interleaved = block_scale_interleave(
w2_s.view(torch.uint8)).reshape(orig_shape)
w2_s.view(torch.uint8)
).reshape(orig_shape)
layer.w13_weight = Parameter(w13_weight_swapped,
requires_grad=False)
layer.w13_weight_scale = Parameter(w13_scale_interleaved,
requires_grad=False)
layer.w13_bias = Parameter(w13_bias_swapped,
requires_grad=False)
layer.w2_weight_scale = Parameter(w2_scale_interleaved,
requires_grad=False)
layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False)
layer.w13_weight_scale = Parameter(
w13_scale_interleaved, requires_grad=False
)
layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False)
layer.w2_weight_scale = Parameter(
w2_scale_interleaved, requires_grad=False
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
def _interleave_mxfp4_cutlass_sm90(w):
w_shape = w.shape
w_interleaved = w.reshape(w_shape[0], w_shape[1],
(w_shape[2] // 4), 4)
w_interleaved = w.reshape(
w_shape[0], w_shape[1], (w_shape[2] // 4), 4
)
w_interleaved = w_interleaved.permute(0, 2, 1, 3)
w_interleaved = w_interleaved.reshape(
w_shape[0], w_shape[2] // 4, w_shape[1] * 4)
w_shape[0], w_shape[2] // 4, w_shape[1] * 4
)
return w_interleaved
w31_scales = w13_scale_swapped.to(torch.uint8).view(
torch.uint8)
w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
w31_scales)
w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8)
w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
w2_weight_scale = layer.w2_weight_scale.data
w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
w2_scales)
w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales)
layer.w13_weight = torch.nn.Parameter(torch.cat([w3_w, w1_w],
dim=1),
requires_grad=False)
layer.w13_bias = torch.nn.Parameter(w13_bias_swapped,
requires_grad=False)
layer.w13_weight = torch.nn.Parameter(
torch.cat([w3_w, w1_w], dim=1), requires_grad=False
)
layer.w13_bias = torch.nn.Parameter(
w13_bias_swapped, requires_grad=False
)
layer.w13_weight_scale = torch.nn.Parameter(
w31_scales_interleaved, requires_grad=False)
w31_scales_interleaved, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales_interleaved, requires_grad=False)
w2_scales_interleaved, requires_grad=False
)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
@@ -590,22 +698,25 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
is_batched_moe = (self.moe.use_pplx_kernels
or self.moe.use_deepep_ll_kernels)
is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps)
layer.w13_weight, layer.w13_weight_scale, num_warps
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps)
layer.w2_weight, layer.w2_weight_scale, num_warps
)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
)
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
@@ -644,8 +755,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return tile_tokens_dim
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
@@ -677,14 +788,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format ==
mk.FusedMoEActivationFormat.BatchedExperts):
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
):
raise NotImplementedError(
"Mxfp4 does not support batched experts format for EP")
"Mxfp4 does not support batched experts format for EP"
)
else:
assert self.moe_quant_config is not None
if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
# B200 code-path
kwargs = {
"gemm1_alpha": layer.gemm1_alpha,
@@ -693,36 +809,34 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# TODO(bnell): part of quant_config
"max_capture_size": self.max_capture_size,
}
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
**kwargs)
elif (self.mxfp4_backend == Mxfp4Backend.MARLIN):
return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
return MarlinExperts(self.moe_quant_config)
else:
return OAITritonExperts(self.moe_quant_config)
def _route_and_experts(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
@@ -741,12 +855,17 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count)
logical_replica_count=logical_replica_count,
)
w13_weight = (self.w13_weight_triton_tensor
if layer.w13_weight is None else layer.w13_weight)
w2_weight = (self.w2_weight_triton_tensor
if layer.w2_weight is None else layer.w2_weight)
w13_weight = (
self.w13_weight_triton_tensor
if layer.w13_weight is None
else layer.w13_weight
)
w2_weight = (
self.w2_weight_triton_tensor if layer.w2_weight is None else layer.w2_weight
)
assert all([w is not None for w in [w13_weight, w2_weight]])
return self.fused_experts(
@@ -785,7 +904,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
@@ -824,7 +942,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
)
return torch.ops.vllm.fused_marlin_moe(
x,
@@ -843,28 +962,39 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map)
expert_map=expert_map,
)
assert _can_support_mxfp4(
use_grouped_topk, topk_group, num_expert_group, expert_map,
custom_routing_function, e_score_correction_bias,
apply_router_weight_on_input, scoring_func, activation,
expert_load_view, logical_to_physical_map,
logical_replica_count), (
"MXFP4 are not supported with this configuration.")
use_grouped_topk,
topk_group,
num_expert_group,
expert_map,
custom_routing_function,
e_score_correction_bias,
apply_router_weight_on_input,
scoring_func,
activation,
expert_load_view,
logical_to_physical_map,
logical_replica_count,
), "MXFP4 are not supported with this configuration."
if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
from flashinfer import trtllm_fp4_block_scale_moe
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
x_quant = x
x_scale = None
elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
*x.shape[:-1], -1)
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16),
@@ -897,8 +1027,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tune_max_num_tokens=self.max_capture_size,
)[0]
return trtllm_gen_output
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
topk_weights, topk_ids, _ = FusedMoE.select_experts(
@@ -916,13 +1048,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# Backend-specific preparation
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, True, 32)
fake_input_scale = torch.ones(self.num_experts,
device=x.device)
fake_input_scale = torch.ones(self.num_experts, device=x.device)
quant_scales = [
layer.w13_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
@@ -934,10 +1064,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
extra_kwargs = dict(
use_mxfp8_act_scaling=True,
input_sf=x_scale,
fc1_expert_weights=layer.w13_weight.contiguous().view(
torch.long),
fc2_expert_weights=layer.w2_weight.contiguous().view(
torch.long),
fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
@@ -978,7 +1106,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return output
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward)
triton_kernel_moe_forward,
)
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,