[Kernel][B200] mxfp4 fused cutlass moe (#23696)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from enum import Enum
 from typing import Callable, Optional, Union
 
 import torch
@@ -33,33 +34,72 @@ from vllm.utils.flashinfer import has_flashinfer
 logger = init_logger(__name__)
 
 
-def _should_use_flashinfer_mxfp4_bf16():
-    """Determine if FlashInfer MXFP4 BF16 should be used."""
-    # If explicitly set, respect the setting
-    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
-        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
-
-    # Enable by default on SM100 if MXFP8 is not explicitly enabled
-    if (current_platform.is_device_capability(100) and has_flashinfer()
-            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
-        logger.info_once(
-            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
-            "For faster performance, consider setting "
-            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
-            "though this may impact accuracy.")
-        return True
-
-    return False
-
-
-def _should_use_flashinfer_mxfp4_mxfp8():
-    """Determine if FlashInfer MXFP4 MXFP8 should be used."""
-    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-
-
-def should_use_flashinfer_mxfp4():
-    return (_should_use_flashinfer_mxfp4_mxfp8()
-            or _should_use_flashinfer_mxfp4_bf16())
+# Enum for the MXFP4 MoE backend
+class Mxfp4Backend(Enum):
+    NONE = 0
+
+    # FlashInfer backends
+    SM100_FI_MXFP4_MXFP8_TRTLLM = 1
+    SM100_FI_MXFP4_MXFP8_CUTLASS = 2
+    SM100_FI_MXFP4_BF16 = 3
+    SM90_FI_MXFP4_BF16 = 4
+
+    # Marlin backend
+    MARLIN = 5
+
+    # Triton backend
+    TRITON = 6
+
+
+def get_mxfp4_backend():
+    # Backend selection
+    if current_platform.is_cuda():
+        if (current_platform.is_device_capability(90) and has_flashinfer()
+                and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+            logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
+            return Mxfp4Backend.SM90_FI_MXFP4_BF16
+        elif (current_platform.is_device_capability(100) and has_flashinfer()
+              and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS):
+            logger.info_once(
+                "Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
+            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+        elif (current_platform.is_device_capability(100) and has_flashinfer()
+              and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
+            logger.info_once(
+                "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100. "
+                "For high-concurrency throughput workloads, consider setting "
+                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
+                "performance.")
+            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+        elif current_platform.is_device_capability(100) and has_flashinfer():
+            logger.info_once(
+                "Using FlashInfer MXFP4 BF16 backend for SM100. "
+                "For faster performance on SM100, consider setting "
+                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may "
+                "impact accuracy.")
+            return Mxfp4Backend.SM100_FI_MXFP4_BF16
+        elif ((current_platform.is_device_capability(100)
+               or current_platform.is_device_capability(90))
+              and not has_flashinfer()):
+            logger.warning_once(
+                "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
+                "is not available. This may result in degraded performance. "
+                "Please `pip install vllm[flashinfer]` for best results.")
+
+        # If FlashInfer is not available, try either Marlin or Triton
+        if current_platform.get_device_capability(
+        )[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
+                "2.8.0"):
+            logger.info_once("Using Marlin backend")
+            return Mxfp4Backend.MARLIN
+        else:
+            logger.info_once("Using Triton backend")
+            return Mxfp4Backend.TRITON
+    elif current_platform.is_rocm() and has_triton_kernels():
+        logger.info_once("Using Triton backend")
+        return Mxfp4Backend.TRITON
+
+    return Mxfp4Backend.NONE
 
 
 class Mxfp4Config(QuantizationConfig):
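Editor's note: the hunk above replaces the old boolean helpers with a single enum resolved once at startup, which every later code path switches on. A minimal, self-contained sketch of that dispatch style follows; the helper name `uses_trtllm_weight_layout` is hypothetical, but the backend grouping mirrors the conditions used later in this diff.

```python
from enum import Enum


class Mxfp4Backend(Enum):  # mirrors the enum added above
    NONE = 0
    SM100_FI_MXFP4_MXFP8_TRTLLM = 1
    SM100_FI_MXFP4_MXFP8_CUTLASS = 2
    SM100_FI_MXFP4_BF16 = 3
    SM90_FI_MXFP4_BF16 = 4
    MARLIN = 5
    TRITON = 6


def uses_trtllm_weight_layout(backend: Mxfp4Backend) -> bool:
    # The TRTLLM and SM100 BF16 paths share weight preprocessing in this PR,
    # so the diff repeatedly tests for this pair of enum values.
    return backend in (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
                       Mxfp4Backend.SM100_FI_MXFP4_BF16)


assert uses_trtllm_weight_layout(Mxfp4Backend.SM100_FI_MXFP4_BF16)
assert not uses_trtllm_weight_layout(Mxfp4Backend.MARLIN)
```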
@@ -113,31 +153,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         super().__init__(moe)
         self.topk_indices_dtype = None
         self.moe = moe
-        self.use_marlin = self._should_use_marlin()
+        self.mxfp4_backend = get_mxfp4_backend()
         self.max_capture_size = get_current_vllm_config(
         ).compilation_config.max_capture_size
 
-        if current_platform.is_device_capability(100) and not has_flashinfer():
-            logger.warning_once(
-                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
-                "is not available. This may result in degraded performance. "
-                "Please `pip install vllm[flashinfer]` for best results.")
+        assert self.mxfp4_backend != Mxfp4Backend.NONE, (
+            "No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available. "
+            "Please check your environment and try again.")
         self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
 
-    def _should_use_marlin(self):
-        if envs.VLLM_MXFP4_USE_MARLIN is not None:
-            return envs.VLLM_MXFP4_USE_MARLIN
-        if current_platform.is_cuda() and \
-                not current_platform.is_device_capability(100):
-            if not current_platform.has_device_capability(90):
-                # The Marlin kernel has better performance on Ampere
-                return True
-            if not has_triton_kernels():
-                return True
-            if not is_torch_equal_or_newer("2.8.0"):
-                return True
-        return False
-
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -157,7 +181,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         intermediate_size_per_partition_after_pad = \
             intermediate_size_per_partition
-        if self.use_marlin:
+        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
             # The MoE Marlin kernel requires that for each linear
             # n % 256 == 0 and k % 128 == 0.
             # In gate_up_proj:
@@ -175,16 +199,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif should_use_flashinfer_mxfp4():
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+              or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
             # pad the intermediate size to be a multiple of 2 * mxfp4_block
             # to hold non-uniformly sharded tensors as well as for swizzling;
             # other padding is to increase performance
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 256)
             hidden_size = round_up(hidden_size, 256)
-        elif current_platform.is_rocm():
+        elif current_platform.is_rocm() or (
+                self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+                or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128)
             hidden_size = round_up(hidden_size, 128)
         else:
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 64)
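Editor's note: a small, runnable sketch of the padding arithmetic in this hunk, assuming `round_up` behaves like `vllm.utils.round_up` (round up to the nearest multiple). The sample size is illustrative only.

```python
def round_up(x: int, m: int) -> int:
    # round x up to the nearest multiple of m
    return (x + m - 1) // m * m


# TRTLLM / SM100 BF16 paths pad both dims to multiples of 256; the
# CUTLASS SM100 / SM90 BF16 paths use 128; everything else uses 64.
hidden_size = 2880  # illustrative size
assert round_up(hidden_size, 256) == 3072
assert round_up(hidden_size, 128) == 2944
assert round_up(hidden_size, 64) == 2880
```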
@@ -264,9 +292,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_bias, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer):
-        if self.use_marlin:
+        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
             prepare_moe_fp4_layer_for_marlin(layer)
-        elif should_use_flashinfer_mxfp4():
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+              or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
             from flashinfer.fp4_quantization import (
                 nvfp4_block_scale_interleave)
             from flashinfer.fused_moe.core import (
@@ -429,7 +458,116 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
                 self.num_experts, -1),
                                       requires_grad=False)
-        else:
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+              or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
+            layer.gemm1_alpha = Parameter(torch.tensor(
+                [1.702] * self.num_experts, dtype=torch.float32).cuda(),
+                                          requires_grad=False)
+            layer.gemm1_beta = Parameter(torch.tensor(
+                [1.0] * self.num_experts, dtype=torch.float32).cuda(),
+                                         requires_grad=False)
+            layer.gemm1_clamp_limit = Parameter(torch.tensor(
+                [7.0] * self.num_experts, dtype=torch.float32).cuda(),
+                                                requires_grad=False)
+
+            sf_block_size = 32  # mxfp4 block size
+
+            # Common shape assertions
+            assert (layer.w13_weight.dim() == 3
+                    and layer.w13_weight.shape[0] == self.num_experts
+                    and layer.w13_weight.shape[1] == self.intermediate_size * 2
+                    and layer.w13_weight.shape[2] == self.hidden_size // 2)
+            assert (layer.w13_weight_scale.dim() == 3
+                    and layer.w13_weight_scale.shape[0] == self.num_experts
+                    and layer.w13_weight_scale.shape[1]
+                    == self.intermediate_size * 2
+                    and layer.w13_weight_scale.shape[2]
+                    == self.hidden_size // sf_block_size)
+            assert (layer.w2_weight.dim() == 3
+                    and layer.w2_weight.shape[0] == self.num_experts
+                    and layer.w2_weight.shape[1] == self.hidden_size and
+                    layer.w2_weight.shape[2] == self.intermediate_size // 2)
+            assert (layer.w2_weight_scale.dim() == 3
+                    and layer.w2_weight_scale.shape[1] == self.hidden_size
+                    and layer.w2_weight_scale.shape[2]
+                    == self.intermediate_size // sf_block_size)
+            assert (layer.w13_bias.dim() == 2
+                    and layer.w13_bias.shape[0] == self.num_experts
+                    and layer.w13_bias.shape[1] == self.intermediate_size * 2)
+            assert (layer.w2_bias.dim() == 2
+                    and layer.w2_bias.shape[0] == self.num_experts
+                    and layer.w2_bias.shape[1] == self.hidden_size)
+
+            # De-interleave and swap for w13 weight, bias, and scales
+            w13_w = layer.w13_weight.data
+            gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
+            deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
+            w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
+            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
+
+            w13_b = layer.w13_bias.data.to(torch.float32)
+            gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
+            deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
+            b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
+            w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
+
+            w13_s = layer.w13_weight_scale.data
+            gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
+            deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
+            s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
+            w13_scale_swapped = torch.cat([s3, s1], dim=1)
+
+            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
+                from flashinfer import block_scale_interleave
+
+                orig_shape = w13_scale_swapped.shape
+                w13_scale_interleaved = block_scale_interleave(
+                    w13_scale_swapped.view(torch.uint8)).reshape(orig_shape)
+
+                w2_s = layer.w2_weight_scale.data
+                orig_shape = w2_s.shape
+                w2_scale_interleaved = block_scale_interleave(
+                    w2_s.view(torch.uint8)).reshape(orig_shape)
+
+                layer.w13_weight = Parameter(w13_weight_swapped,
+                                             requires_grad=False)
+                layer.w13_weight_scale = Parameter(w13_scale_interleaved,
+                                                   requires_grad=False)
+                layer.w13_bias = Parameter(w13_bias_swapped,
+                                           requires_grad=False)
+                layer.w2_weight_scale = Parameter(w2_scale_interleaved,
+                                                  requires_grad=False)
+            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
+
+                def _interleave_mxfp4_cutlass_sm90(w):
+                    w_shape = w.shape
+                    w_interleaved = w.reshape(w_shape[0], w_shape[1],
+                                              (w_shape[2] // 4), 4)
+                    w_interleaved = w_interleaved.permute(0, 2, 1, 3)
+                    w_interleaved = w_interleaved.reshape(
+                        w_shape[0], w_shape[2] // 4, w_shape[1] * 4)
+                    return w_interleaved
+
+                w31_scales = w13_scale_swapped.to(torch.uint8).view(
+                    torch.uint8)
+                w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
+                    w31_scales)
+
+                w2_weight_scale = layer.w2_weight_scale.data
+                w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
+                w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
+                    w2_scales)
+
+                layer.w13_weight = torch.nn.Parameter(torch.cat([w3_w, w1_w],
+                                                                dim=1),
+                                                      requires_grad=False)
+                layer.w13_bias = torch.nn.Parameter(w13_bias_swapped,
+                                                    requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w31_scales_interleaved, requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_scales_interleaved, requires_grad=False)
+        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
             from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
 
             w13_bias = layer.w13_bias.to(torch.float32)
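Editor's note: two small, self-contained sketches of the tensor shuffles introduced in this hunk. First, the w13 de-interleave and swap: checkpoint rows alternate gate/up, and the fused kernel expects the two halves stacked as [w3 (up), w1 (gate)].

```python
import torch

E, N, K = 1, 8, 4  # experts, 2 * intermediate rows, packed hidden
w13 = torch.arange(E * N * K).reshape(E, N, K)

gate, up = w13[:, ::2, :], w13[:, 1::2, :]    # de-interleave alternating rows
deinterleaved = torch.cat([gate, up], dim=1)  # [gate | up]
w1, w3 = torch.chunk(deinterleaved, 2, dim=1)
swapped = torch.cat([w3, w1], dim=1)          # [up | gate]

assert torch.equal(swapped[:, :N // 2], up)
assert torch.equal(swapped[:, N // 2:], gate)
```

Second, a shape check for the SM90 scale interleave defined above, which regroups every 4 scale bytes and transposes the row/group axes, taking (E, M, K) to (E, K // 4, M * 4).

```python
import torch


def interleave_sm90(w: torch.Tensor) -> torch.Tensor:
    # same reshape/permute/reshape as _interleave_mxfp4_cutlass_sm90 above
    e, m, k = w.shape
    return w.reshape(e, m, k // 4, 4).permute(0, 2, 1, 3).reshape(
        e, k // 4, m * 4)


w = torch.arange(2 * 6 * 8, dtype=torch.uint8).reshape(2, 6, 8)
out = interleave_sm90(w)
assert out.shape == (2, 2, 24)
assert torch.equal(out[0, 0, :4], w[0, 0, :4])  # 4-byte groups stay intact
```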
@@ -464,6 +602,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.w13_weight = None
             layer.w2_weight = None
             torch.cuda.empty_cache()
+        else:
+            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
 
     def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
         # Number of tokens in the input tensor.
@@ -500,7 +640,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             raise NotImplementedError(
                 "Mxfp4 does not support batched experts format for EP")
         else:
-            if should_use_flashinfer_mxfp4():
+            if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+                    or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
                 # B200 code-path
                 kwargs = {
                     "gemm1_alpha": layer.gemm1_alpha,
@@ -601,7 +742,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if enable_eplb:
             raise NotImplementedError("EPLB is not supported for mxfp4")
 
-        if self.use_marlin:
+        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
             topk_weights, topk_ids = FusedMoE.select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
@@ -665,16 +806,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 logical_replica_count), (
                     "MXFP4 is not supported with this configuration.")
 
-        if should_use_flashinfer_mxfp4():
-            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
-            if _should_use_flashinfer_mxfp4_bf16():
+        if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+                or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
+            from flashinfer import trtllm_fp4_block_scale_moe
+            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None
-            else:
+            elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
+                from flashinfer import mxfp8_quantize
                 x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                 x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
                     *x.shape[:-1], -1)
 
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
                 None,  # routing_bias
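Editor's note: a condensed sketch of the input preparation above for the two SM100 TRTLLM-path variants, rearranging the diff's own calls rather than introducing new API. BF16 activations pass through unquantized; the MXFP8 variant quantizes and exposes one E4M3 scale per 32-element block.

```python
import torch


def prepare_trtllm_inputs(x: torch.Tensor, use_mxfp8: bool):
    # Sketch mirroring the branch above; mxfp8_quantize is the FlashInfer
    # call used in the diff.
    if not use_mxfp8:
        assert x.dtype == torch.bfloat16
        return x, None  # BF16 path: no activation scales
    from flashinfer import mxfp8_quantize
    x_quant, x_scale = mxfp8_quantize(x, False)  # quantize to MXFP8
    # flatten the block scales to one row of E4M3 scales per token
    x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
    return x_quant, x_scale
```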
@@ -706,7 +850,86 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 tune_max_num_tokens=self.max_capture_size,
             )[0]
             return trtllm_gen_output
-        else:
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+              or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
+            from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
+
+            topk_weights, topk_ids = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias,
+            )
+
+            # Backend-specific preparation
+            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
+                from flashinfer import mxfp8_quantize
+
+                x_quant, x_scale = mxfp8_quantize(x, True, 32)
+
+                fake_input_scale = torch.ones(self.num_experts,
+                                              device=x.device)
+                quant_scales = [
+                    layer.w13_weight_scale.contiguous().view(torch.int32),
+                    fake_input_scale,
+                    layer.w2_weight_scale.contiguous().view(torch.int32),
+                    fake_input_scale,
+                ]
+
+                fi_input = x_quant
+                extra_kwargs = dict(
+                    use_mxfp8_act_scaling=True,
+                    input_sf=x_scale,
+                    fc1_expert_weights=layer.w13_weight.contiguous().view(
+                        torch.long),
+                    fc2_expert_weights=layer.w2_weight.contiguous().view(
+                        torch.long),
+                )
+            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
+                assert x.dtype == torch.bfloat16
+
+                quant_scales = [
+                    layer.w13_weight_scale,
+                    layer.w2_weight_scale,
+                ]
+
+                fi_input = x
+                extra_kwargs = dict(
+                    use_w4_group_scaling=True,
+                    fc1_expert_weights=layer.w13_weight,
+                    fc2_expert_weights=layer.w2_weight,
+                )
+
+            output = torch.empty_like(x, dtype=torch.bfloat16)
+            _ = flashinfer_cutlass_fused_moe(
+                input=fi_input,
+                token_selected_experts=topk_ids.to(torch.int).contiguous(),
+                token_final_scales=topk_weights,
+                output_dtype=torch.bfloat16,
+                output=output,
+                quant_scales=quant_scales,
+                fc1_expert_biases=layer.w13_bias,
+                fc2_expert_biases=layer.w2_bias,
+                swiglu_alpha=layer.gemm1_alpha,
+                swiglu_beta=layer.gemm1_beta,
+                swiglu_limit=layer.gemm1_clamp_limit,
+                tp_size=self.moe.tp_size,
+                tp_rank=self.moe.tp_rank,
+                ep_size=self.moe.ep_size,
+                ep_rank=self.moe.ep_rank,
+                tune_max_num_tokens=self.max_capture_size,
+                **extra_kwargs,
+            )
+
+            return output
+        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
             from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                 triton_kernel_moe_forward)
             return triton_kernel_moe_forward(
@@ -724,3 +947,5 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             w2_precision=self.w2_precision_config,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
+        else:
+            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
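Editor's note: for reviewers, a summary of the environment knobs consulted by `get_mxfp4_backend()` and the backend each one selects, condensed from the selection logic added above. The mapping below is illustrative, not code from the PR.

```python
# Which env var requests which backend (assuming FlashInfer is installed
# and the device capability matches):
MXFP4_BACKEND_KNOBS = {
    "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": "SM90_FI_MXFP4_BF16 (SM90 only)",
    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS":
    "SM100_FI_MXFP4_MXFP8_CUTLASS",
    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": "SM100_FI_MXFP4_MXFP8_TRTLLM",
}
# Defaults: SM100 + FlashInfer with no knob set -> SM100_FI_MXFP4_BF16;
# without FlashInfer -> MARLIN (capability < 9, missing triton_kernels, or
# torch < 2.8.0), otherwise TRITON; ROCm with triton_kernels -> TRITON.
```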