[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)

This commit is contained in:
Robert Shaw
2026-01-07 19:42:33 -05:00
committed by GitHub
parent ffc0a2798b
commit 5dcd7ef1f2
38 changed files with 1439 additions and 1528 deletions

View File

@@ -73,7 +73,6 @@ if HAS_TRITON:
CutlassExpertsFp8,
CutlassExpertsW4A8Fp8,
cutlass_moe_fp4,
cutlass_moe_fp8,
cutlass_moe_w4a8_fp8,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
@@ -96,7 +95,6 @@ if HAS_TRITON:
"fused_experts",
"get_config_file_name",
"GroupedTopk",
"cutlass_moe_fp8",
"cutlass_moe_fp4",
"cutlass_moe_w4a8_fp8",
"CutlassExpertsFp8",

View File

@@ -249,20 +249,28 @@ def run_cutlass_moe_fp8(
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
e: int,
n: int,
k: int,
out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
device: torch.device,
):
assert quant_config.use_fp8_w8a8
super().__init__(quant_config)
# E: num_experts
# N: intermediate size per partition
# K: hidden dim
ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides1 = ab_strides1_c_strides2
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
self.c_strides2 = ab_strides1_c_strides2
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
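For reference (not part of the diff), a minimal sketch of how the base class now derives the per-expert stride tensors from (E, N, K); ab_strides1 and c_strides2 are both filled with the hidden size K, so a single tensor is shared between them:

import torch

def make_cutlass_fp8_strides(e: int, n: int, k: int, device: torch.device):
    # E: num_experts, N: intermediate size per partition, K: hidden dim.
    ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
    ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
    c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
    return ab_strides1_c_strides2, ab_strides2, c_strides1

# e.g. 8 experts, intermediate size 2048, hidden size 4096
ab1_c2, ab2, c1 = make_cutlass_fp8_strides(8, 2048, 4096, torch.device("cpu"))
assert ab1_c2.tolist() == [4096] * 8 and c1.tolist() == [2 * 2048] * 8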
@@ -329,24 +337,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
out_dtype,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
quant_config,
)
@property
def activation_formats(
self,
@@ -390,21 +380,10 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
self,
max_experts_per_worker: int,
num_dispatchers: int,
out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
*args,
**kwargs,
):
super().__init__(
out_dtype,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
quant_config,
)
super().__init__(*args, **kwargs)
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
@@ -445,113 +424,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
return (workspace1, workspace2, output)
def cutlass_moe_fp8(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
activation: str = "silu",
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
) -> torch.Tensor:
"""
This function computes an a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and a top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mappings.
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
is -1, it means that this Rank is not responsible for global
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
- global_num_experts (int): The total number of experts.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert quant_config is not None
if quant_config.a1_scale is not None:
assert quant_config.per_act_token_quant == (quant_config.a1_scale.numel() != 1)
if quant_config.a2_scale is not None:
assert quant_config.per_act_token_quant == (quant_config.a2_scale.numel() != 1)
if quant_config.w1_scale is not None:
if quant_config.per_out_ch_quant:
assert quant_config.w1_scale.dim() > 1 and quant_config.w1_scale.size(
1
) == w1_q.size(1)
else:
assert (
quant_config.w1_scale.dim() == 1 or quant_config.w1_scale.size(1) == 1
)
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
quant_config=quant_config,
),
)
return fn(
a,
w1_q,
w2_q,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
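With the standalone cutlass_moe_fp8() wrapper removed above, callers construct the modular kernel directly. A hedged sketch (not part of this commit) of the equivalent call, mirroring the body of the deleted function but using the new CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device) constructor; the shape assumption follows the old docstring (w1_q is [num_experts, K, 2N]):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.prepare_finalize import MoEPrepareAndFinalizeNoEP

def cutlass_moe_fp8_sketch(a, w1_q, w2_q, topk_weights, topk_ids, quant_config,
                           activation="silu"):
    # Derive (E, N, K) from the transposed weight layout described above.
    e, k, two_n = w1_q.size(0), w1_q.size(1), w1_q.size(2)
    fn = mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        CutlassExpertsFp8(e, two_n // 2, k, a.dtype, quant_config, a.device),
    )
    return fn(
        a, w1_q, w2_q, topk_weights, topk_ids,
        activation=activation,
        global_num_experts=e,
    )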

View File

@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
"""Base class for runtime dispatching of expert implementations."""
def __init__(
self,
experts: mk.FusedMoEPermuteExpertsUnpermute,
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
):
super().__init__(experts.quant_config)
self.fallback_experts = fallback_experts
self.experts = experts
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
assert (
self.fallback_experts.activation_formats == self.experts.activation_formats
)
return self.fallback_experts.activation_formats
def supports_chunking(self) -> bool:
assert (
self.experts.supports_chunking()
== self.fallback_experts.supports_chunking()
)
return (
self.experts.supports_chunking()
and self.fallback_experts.supports_chunking()
)
def supports_expert_map(self) -> bool:
assert (
self.experts.supports_expert_map()
== self.fallback_experts.supports_expert_map()
)
return (
self.experts.supports_expert_map()
and self.fallback_experts.supports_expert_map()
)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
e_war = self.experts.finalize_weight_and_reduce_impl()
fbe_war = self.fallback_experts.finalize_weight_and_reduce_impl()
is_dge_war = e_war is not None
is_fbe_war = fbe_war is not None
if is_dge_war and is_fbe_war:
assert e_war == fbe_war, (
"Both implementations should agree on WeightAndReduce impls. "
f"Got e_war: {e_war}, and fbe_war: {fbe_war}"
)
if e_war is not None:
return e_war
assert fbe_war is not None
return fbe_war
@abstractmethod
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
raise NotImplementedError
@abstractmethod
def _select_experts_impl(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
raise NotImplementedError
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
experts = self._select_experts_impl(hidden_states, w1, w2)
experts.apply(
output,
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
activation,
global_num_experts,
expert_map,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_tokens_meta,
apply_router_weight_on_input,
)
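A minimal sketch (not part of this commit) of how a concrete subclass is expected to wire the two abstract hooks: workspace_shapes picks an allocation that covers whichever implementation may run, and _select_experts_impl makes the per-call dispatch. The class name and the 8-token threshold below are illustrative.

from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts


class SmallBatchFallbackExperts(FallbackExperts):
    """Use `experts` for large batches, `fallback_experts` for small ones."""

    SMALL_BATCH = 8  # illustrative threshold, in tokens

    def workspace_shapes(self, M, N, K, topk, global_num_experts,
                         local_num_experts, expert_tokens_meta):
        # Dispatch on M, which is known at workspace-allocation time.
        impl = self.fallback_experts if M <= self.SMALL_BATCH else self.experts
        return impl.workspace_shapes(M, N, K, topk, global_num_experts,
                                     local_num_experts, expert_tokens_meta)

    def _select_experts_impl(self, hidden_states, w1, w2):
        # hidden_states.shape[0] is the number of tokens handled by this rank.
        if hidden_states.shape[0] <= self.SMALL_BATCH:
            return self.fallback_experts
        return self.experts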

View File

@@ -100,7 +100,7 @@ direct_register_custom_op(
)
def flashinfer_fused_moe_per_tensor_scale_fp8(
def fi_trtllm_fp8_per_tensor_moe(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
@@ -158,7 +158,7 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
)
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
def fi_trtllm_fp8_per_tensor_moe_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
@@ -184,9 +184,9 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op(
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
op_name="fi_trtllm_fp8_per_tensor_moe",
op_func=fi_trtllm_fp8_per_tensor_moe,
mutates_args=["hidden_states"],
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,358 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
make_fp8_moe_alpha_scales_for_fi,
prepare_fp8_moe_layer_for_fi,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
prepare_fp8_moe_layer_for_deepgemm,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_group_gemm_supported,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
logger = init_logger(__name__)
class Fp8MoeBackend(Enum):
NONE = 0
FLASHINFER_TRTLLM = 1
FLASHINFER_CUTLASS = 2
DEEPGEMM = 3
MARLIN = 4
TRITON = 5
AITER = 6
VLLM_CUTLASS = 7
def select_fp8_moe_backend(
block_quant: bool,
tp_size: int,
with_lora_support: bool,
is_act_and_mul: bool = True,
allow_vllm_cutlass: bool = False,
) -> Fp8MoeBackend:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# TODO(rob): in a future PR, we will query each mk for
# supported features and return the mk directly, just like
# we do for the Attention Backend.
if with_lora_support:
return Fp8MoeBackend.TRITON
def _make_log_backend(backend_name: str):
return f"Using {backend_name} backend for FP8 MoE"
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
and (
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
):
backend = get_flashinfer_moe_backend()
if backend == FlashinferMoeBackend.TENSORRT_LLM:
logger.info_once(_make_log_backend("FlashInfer TRTLLM"))
if not is_act_and_mul:
raise ValueError(
"FlashInfer TRTLLM FP8 MoE backend only supports "
"act_and_mul gate_up_project fusion. Please set "
"VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the "
"FlashInfer CUTLASS backend instead."
)
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant and current_platform.is_device_capability_family(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization on SM100. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency to use the "
"FlashInfer TRTLLM backend instead."
)
logger.info_once(_make_log_backend("FlashInfer CUTLASS"))
return Fp8MoeBackend.FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
if (
current_platform.is_cuda() and not current_platform.has_device_capability(89)
) or envs.VLLM_TEST_FORCE_FP8_MARLIN:
logger.info_once(_make_log_backend("Marlin"), scope="local")
return Fp8MoeBackend.MARLIN
# Determine if we should use DeepGEMM with block-quantized weights:
# - If explicitly set by user, respect their choice
# - If not explicitly set (default), disable when TP size is >= 8
moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8:
moe_use_deep_gemm = False
logger.info_once(
"DeepGEMM MoE is disabled by default when TP size is >= 8. "
"Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
scope="local",
)
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
if not is_deep_gemm_supported():
use_deep_gemm = False
logger.info_once(
"DeepGEMM is disabled because the platform does not support it.",
scope="local",
)
if use_deep_gemm and moe_use_deep_gemm and block_quant:
if not has_deep_gemm():
logger.warning_once(
"DeepGEMM backend requested but not available.", scope="local"
)
elif is_deep_gemm_supported():
logger.info_once(_make_log_backend("DeepGEMM"), scope="local")
return Fp8MoeBackend.DEEPGEMM
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once(_make_log_backend("ROCm AITER"), scope="local")
return Fp8MoeBackend.AITER
if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported():
logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local")
return Fp8MoeBackend.VLLM_CUTLASS
# default to Triton
logger.info_once(_make_log_backend("Triton"), scope="local")
return Fp8MoeBackend.TRITON
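A hedged usage sketch (not part of this file): selection happens once at weight-loading time and simply walks the priority order above (LoRA forces Triton, then FlashInfer, Marlin, DeepGEMM, AITER, vLLM CUTLASS, and finally Triton).

from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    select_fp8_moe_backend,
)

# The result depends on env vars (VLLM_USE_FLASHINFER_MOE_FP8,
# VLLM_MOE_USE_DEEP_GEMM, VLLM_ROCM_USE_AITER, ...) and device capability,
# so it is deployment-specific.
backend = select_fp8_moe_backend(
    block_quant=True,         # layer uses block-quantized fp8 weights
    tp_size=4,                # DeepGEMM is off by default once tp_size >= 8
    with_lora_support=False,  # LoRA would force the Triton backend
    allow_vllm_cutlass=False,
)
assert isinstance(backend, Fp8MoeBackend)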
def convert_to_fp8_moe_kernel_format(
fp8_backend: Fp8MoeBackend,
layer: torch.nn.Module,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
block_quant = hasattr(layer, "weight_block_size")
if fp8_backend == Fp8MoeBackend.DEEPGEMM:
assert block_quant
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_deepgemm(
w13,
w2,
w13_scale,
w2_scale,
tuple(layer.weight_block_size),
)
elif fp8_backend == Fp8MoeBackend.AITER:
w13, w2 = rocm_aiter_ops.shuffle_weights(w13, w2)
elif fp8_backend == Fp8MoeBackend.MARLIN:
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
layer,
w13,
w2,
w13_scale,
w2_scale,
)
elif fp8_backend in [
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
layer=layer,
w13=w13,
w2=w2,
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
is_trtllm=(fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM),
)
return w13, w2, w13_scale, w2_scale
def make_fp8_moe_quant_config(
fp8_backend: Fp8MoeBackend,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig | None:
"""
Create FusedMoEQuantConfig for the specified FP8 backend.
The FusedMoEQuantConfig holds the scales that are used
at runtime by the Modular Kernel abstraction.
Note that certain kernels (e.g. Flashinfer CUTLASS) need
special Quant configs to handle non-standard inputs to
their kernel interfaces.
In a future PR, this function should become a method
of the modular kernel itself.
"""
# TRTLLM does not use Modular Kernel abstraction yet.
if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
# MARLIN is mixed precision W8A16 config.
if fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
block_shape=block_shape,
)
# Flashinfer CUTLASS per-tensor uses single dq scale
# (alpha = w_scale * a_scale) and inverse a2 scale.
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
assert a1_scale is not None and a2_scale is not None
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
a1_gscale=(1.0 / a1_scale),
a2_gscale=(1.0 / a2_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
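For the per-tensor FlashInfer CUTLASS case above, a minimal sketch of the scale arithmetic, assuming make_fp8_moe_alpha_scales_for_fi folds the scales as the comment describes (alpha = w_scale * a_scale); the tensors here are illustrative.

import torch

# Illustrative per-tensor scales for a 4-expert layer.
w1_scale = torch.rand(4)       # fp32 dequant scale per expert, first gemm
w2_scale = torch.rand(4)       # fp32 dequant scale per expert, second gemm
a1_scale = torch.tensor(0.02)  # static scale for the input activations
a2_scale = torch.tensor(0.05)  # static scale for the intermediate activations

# Assumption: alphas are the elementwise product of weight and activation scales.
g1_alphas = w1_scale * a1_scale
g2_alphas = w2_scale * a2_scale
# Inverse activation scales, matching a1_gscale / a2_gscale in the config above.
a1_gscale = 1.0 / a1_scale
a2_gscale = 1.0 / a2_scale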
def make_fp8_moe_kernel(
layer: torch.nn.Module,
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
) -> tuple[mk.FusedMoEModularKernel, bool]:
# Delayed import is required since the oracle is imported
# by CPU backends which cannot import all of these experts.
# TODO: update the experts to make this not happen.
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
# NOTE(rob): this is a WIP refactor. We are first migrating
# all of the kernels in the TP case to use mk. Once this is
done, we will initialize the TP case and DP/EP case
# via the same code path (i.e. via maybe_init_modular_kernel).
# NOTE(rob): in progress migrating all into this format.
use_inplace = True
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=moe_quant_config.is_block_quantized
),
FlashInferExperts(
out_dtype=layer.orig_dtype,
quant_config=moe_quant_config,
ep_rank=moe_config.ep_rank,
ep_size=moe_config.ep_size,
tp_rank=moe_config.tp_rank,
tp_size=moe_config.tp_size,
use_dp=(moe_config.dp_size > 1),
use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized,
),
)
use_inplace = False
elif fp8_backend == Fp8MoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
kernel = mk.FusedMoEModularKernel(
# TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
AiterExperts(quant_config=moe_quant_config),
)
elif fp8_backend == Fp8MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=moe_quant_config),
)
elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
TritonOrCutlassExperts,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrCutlassExperts(
out_dtype=moe_config.in_dtype,
e=layer.local_num_experts,
n=layer.intermediate_size_per_partition,
k=layer.hidden_size,
device=layer.w13_weight.device,
quant_config=moe_quant_config,
),
)
elif fp8_backend == Fp8MoeBackend.DEEPGEMM:
from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrDeepGemmExperts(quant_config=moe_quant_config),
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts,
)
assert fp8_backend == Fp8MoeBackend.TRITON
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config=moe_quant_config),
)
return kernel, use_inplace
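Putting the oracle pieces together, a hedged sketch (not part of this file) of the flow a quant method follows in process_weights_after_loading, mirroring the compressed-tensors changes later in this commit; the layer attributes and helper name are assumptions for illustration:

from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)

def setup_fp8_moe(layer, moe_config, block_quant, w13, w2, w13_scale, w2_scale,
                  w13_input_scale, w2_input_scale, block_shape=None):
    # 1. Pick the backend once, at load time.
    backend = select_fp8_moe_backend(
        block_quant=block_quant,
        tp_size=moe_config.tp_size,
        with_lora_support=False,
    )
    # 2. Shuffle/requantize weights into the backend's runtime format.
    w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
        fp8_backend=backend, layer=layer, w13=w13, w2=w2,
        w13_scale=w13_scale, w2_scale=w2_scale,
        w13_input_scale=w13_input_scale, w2_input_scale=w2_input_scale,
    )
    # 3. Build the quant config holding the scales used at runtime.
    quant_config = make_fp8_moe_quant_config(
        fp8_backend=backend, w1_scale=w13_scale, w2_scale=w2_scale,
        a1_scale=w13_input_scale, a2_scale=w2_input_scale,
        block_shape=block_shape,
    )
    # 4. Construct the modular kernel (TRTLLM returns no config and is
    #    handled outside the modular kernel abstraction for now).
    kernel, use_inplace = None, True
    if quant_config is not None:
        kernel, use_inplace = make_fp8_moe_kernel(
            layer=layer, moe_quant_config=quant_config,
            moe_config=moe_config, fp8_backend=backend,
        )
    return backend, quant_config, kernel, use_inplace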

View File

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.platforms import current_platform
class TritonOrCutlassExperts(FallbackExperts):
"""Cutlass with fallback to Triton for low latency shapes on SM100."""
def __init__(
self,
e: int,
n: int,
k: int,
out_dtype: torch.dtype | None,
quant_config: FusedMoEQuantConfig,
device: torch.device,
):
self.is_sm100 = current_platform.has_device_capability(100)
super().__init__(
experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device),
fallback_experts=TritonExperts(quant_config),
)
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Small batch fallback for sm100.
if self.is_sm100 and M <= 8:
return self.fallback_experts.workspace_shapes(
M,
N,
K,
topk,
global_num_experts,
local_num_experts,
expert_tokens_meta,
)
else:
return self.experts.workspace_shapes(
M,
N,
K,
topk,
global_num_experts,
local_num_experts,
expert_tokens_meta,
)
def _select_experts_impl(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
# Small batch fallback for sm100.
if self.is_sm100 and hidden_states.shape[0] <= 8:
return self.fallback_experts
else:
return self.experts
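The runtime decision above reduces to a single predicate; a hedged restatement (not part of this file), useful when reasoning about which path a given batch will take:

from vllm.platforms import current_platform

def uses_triton_fallback(num_tokens: int) -> bool:
    # Mirrors TritonOrCutlassExperts._select_experts_impl: batches of at most
    # 8 tokens on SM100 run through the Triton fallback instead of CUTLASS.
    return current_platform.has_device_capability(100) and num_tokens <= 8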

View File

@@ -10,78 +10,22 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm,
_valid_deep_gemm_shape,
)
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
)
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
quant_config: FusedMoEQuantConfig,
allow_deep_gemm: bool = False,
):
super().__init__(quant_config)
class TritonOrDeepGemmExperts(FallbackExperts):
"""DeepGemm with fallback to Triton for low latency shapes."""
self.triton_expert = TritonExperts(quant_config)
self.allow_deep_gemm = (
allow_deep_gemm
and self.quant_config.use_fp8_w8a8
and self.block_shape == get_mk_alignment_for_contiguous_layout()
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(
experts=DeepGemmExperts(quant_config),
fallback_experts=TritonExperts(quant_config),
)
self.deep_gemm_expert = (
DeepGemmExperts(self.quant_config) if self.allow_deep_gemm else None
)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
assert (
self.deep_gemm_expert is None
or self.triton_expert.activation_formats
== self.deep_gemm_expert.activation_formats
)
return self.triton_expert.activation_formats
def supports_chunking(self) -> bool:
dge = self.deep_gemm_expert
te = self.triton_expert
return (dge is None or dge.supports_chunking()) and (
te is None or te.supports_chunking()
)
def supports_expert_map(self) -> bool:
dge = self.deep_gemm_expert
te = self.triton_expert
return (dge is None or dge.supports_expert_map()) and (
te is None or te.supports_expert_map()
)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
dge = self.deep_gemm_expert
te = self.triton_expert
dge_war = dge.finalize_weight_and_reduce_impl() if dge else None
te_war = te.finalize_weight_and_reduce_impl() if te else None
is_dge_war = dge_war is not None
is_te_war = te_war is not None
if is_dge_war and is_te_war:
assert dge_war == te_war, (
"Both implementations should agree on WeightAndReduce impls. "
f"Got dge_war: {dge_war}, and te_war: {te_war}"
)
if dge_war is not None:
return dge_war
assert te_war is not None
return te_war
def workspace_shapes(
self,
M: int,
@@ -95,11 +39,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and (
is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)
):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
if is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K):
return self.experts.workspace_shapes(
M,
N,
K,
@@ -109,7 +50,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta,
)
else:
return self.triton_expert.workspace_shapes(
return self.fallback_experts.workspace_shapes(
M,
N,
K,
@@ -119,45 +60,13 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta,
)
def apply(
def _select_experts_impl(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
use_deep_gemm = self.allow_deep_gemm and (
is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)
)
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
assert experts is not None
experts.apply(
output,
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
activation,
global_num_experts,
expert_map,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_tokens_meta,
apply_router_weight_on_input,
)
) -> mk.FusedMoEPermuteExpertsUnpermute:
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
return self.experts
else:
return self.fallback_experts
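Likewise, a hedged restatement (not part of this file) of the DeepGEMM dispatch condition, using the helpers imported at the top of this file:

from vllm.model_executor.layers.fused_moe.deep_gemm_moe import _valid_deep_gemm_shape
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

def would_use_deep_gemm(M: int, N: int, K: int) -> bool:
    # Mirrors TritonOrDeepGemmExperts.workspace_shapes: plan for DeepGEMM when
    # the e8m0 path is active or the (M, N, K) problem shape is supported.
    return is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)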

View File

@@ -13,10 +13,8 @@ from compressed_tensors.quantization import (
)
from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
@@ -31,6 +29,7 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
int4_w4a16_moe_quant_config,
int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config,
@@ -46,11 +45,16 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
@@ -63,8 +67,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
expert_weight_is_col_major,
requant_weight_ue8m0_inplace,
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer,
@@ -76,29 +80,17 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
get_col_major_tma_aligned_tensor,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.import_utils import has_deep_gemm
logger = init_logger(__name__)
@@ -657,10 +649,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe: FusedMoEConfig,
layer_name: str | None = None,
):
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig,
)
super().__init__(moe)
self.weight_quant = weight_quant
self.input_quant = input_quant
@@ -687,42 +675,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
and not self.block_quant
self.fp8_backend = select_fp8_moe_backend(
block_quant=self.block_quant,
tp_size=moe.tp_size,
with_lora_support=moe.is_lora_enabled,
# TODO(rob): enable selecting this externally.
allow_vllm_cutlass=True,
)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# cutlass path
self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant
)
self.use_cutlass = not self.block_quant and (
CompressedTensorsConfig._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant
if self.fp8_backend != Fp8MoeBackend.MARLIN:
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
)
if per_act_token != per_channel_quant:
raise NotImplementedError(
"For FP8 Fused MoE layers, per-token and per-channel must be "
"used together."
)
# TODO(rob): hook this up in a follow up PR.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
raise NotImplementedError(
"FlashInfer TRTLLM backend not supported for compressed-tensors yet."
)
or self.is_fp8_w8a8_sm100
)
self.disable_expert_map = False
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
)
self.allow_deep_gemm = (
self.block_quant
and envs.VLLM_MOE_USE_DEEP_GEMM
and is_deep_gemm_supported()
and list(self.weight_block_size) == get_mk_alignment_for_contiguous_layout()
)
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights(
self,
@@ -880,163 +857,75 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
# Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight
w2 = layer.w2_weight
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
w13_input_scale = layer.w13_input_scale
w2_input_scale = layer.w2_input_scale
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if current_platform.is_fp8_fnuz():
w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w13, w13_scale, w13_input_scale
)
w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w2, w2_scale, w2_input_scale
)
# Per tensor kernels require single activation scale. Use the max.
if self.static_input_scales:
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
assert w13_input_scale is not None and w2_input_scale is not None
w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
w13_input_scale, w2_input_scale
)
replace_parameter(layer, "w13_input_scale", w13_input_scale)
replace_parameter(layer, "w2_input_scale", w2_input_scale)
if current_platform.is_fp8_fnuz():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
)
)
w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False
)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
# for w13 per expert. Use max then dequant and requant each expert.
# Per-tensor kernels use a single scale for W13, but on disk there
# is a separate scale for W1 and W3. Requantize with the max scale.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
process_fp8_weight_tensor_strategy_moe(
w13,
w13_scale,
shard_size=layer.intermediate_size_per_partition,
num_experts=layer.num_local_experts,
)
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
fp8_backend=self.fp8_backend,
layer=layer,
w13=w13,
w2=w2,
w13_scale=w13_scale,
w2_scale=w2_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale)
elif self.use_marlin:
(
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
) = prepare_moe_fp8_layer_for_marlin(
layer,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
input_dtype=self.marlin_input_dtype,
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
self.kernel, self.use_inplace = make_fp8_moe_kernel(
layer=layer,
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
)
layer.workspace = workspace
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
if self.use_cutlass:
assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
device = layer.w13_weight.device
# ab_strides1 and c_strides2 are the same
self.ab_strides1_c_strides2 = torch.full(
(layer.local_num_experts,),
layer.hidden_size,
device=device,
dtype=torch.int64,
)
self.ab_strides2 = torch.full(
(layer.local_num_experts,),
layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
self.c_strides1 = torch.full(
(layer.local_num_experts,),
2 * layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
if is_deep_gemm_e8m0_used() and self.block_quant:
assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.w13_weight.data,
layer.w13_weight_scale.data,
block_sz,
)
requant_weight_ue8m0_inplace(
layer.w2_weight.data,
layer.w2_weight_scale.data,
block_sz,
)
# Ensure column-major TMA alignment expected by DeepGEMM.
if expert_weight_is_col_major(layer.w13_weight_scale):
layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale
)
if expert_weight_is_col_major(layer.w2_weight_scale):
layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or self.rocm_aiter_moe_enabled:
if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
return None
else:
return super().maybe_make_prepare_finalize(routing_tables)
@@ -1048,7 +937,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
assert self.moe_quant_config is not None
if self.use_cutlass:
if self.fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
@@ -1064,26 +953,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
):
logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassBatchedExpertsFp8(
self.moe.num_local_experts,
num_dispatchers,
self.moe.in_dtype,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
max_experts_per_worker=self.moe.num_local_experts,
num_dispatchers=num_dispatchers,
out_dtype=self.moe.in_dtype,
e=layer.local_num_experts,
n=layer.intermediate_size_per_partition,
k=layer.hidden_size,
device=layer.w13_weight.device,
quant_config=self.moe_quant_config,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
self.moe.in_dtype,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
out_dtype=self.moe.in_dtype,
e=layer.local_num_experts,
n=layer.intermediate_size_per_partition,
k=layer.hidden_size,
device=layer.w13_weight.device,
quant_config=self.moe_quant_config,
)
# TODO(rob): investigate disable_expert_map
self.disable_expert_map = (
num_dispatchers > 1 or not experts.supports_expert_map()
)
@@ -1096,13 +986,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
assert self.fp8_backend not in [Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN]
if (
prepare_finalize.activation_format
@@ -1111,28 +1002,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
if use_deep_gemm and not has_deep_gemm():
raise RuntimeError(
"DeepGEMM requested for MoE layer but not installed."
)
compatible_with_deep_gemm = (
self.moe_quant_config.use_fp8_w8a8
and self.moe_quant_config.block_shape
== get_mk_alignment_for_contiguous_layout()
)
# If this MoE layer is compatible with DeepGEMM, the proper env
# vars are set and DeepGEMM is not installed, throw an error.
if use_deep_gemm and compatible_with_deep_gemm and not has_deep_gemm():
raise RuntimeError(
f"MoE layer incompatible with DeepGEMM, expected "
f"fp8==True, got {self.moe_quant_config.use_fp8_w8a8}"
f"or block_shape {self.moe_quant_config.block_shape}"
f"=={get_mk_alignment_for_contiguous_layout()}."
)
if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm():
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
return BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
@@ -1148,17 +1018,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
)
else:
logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
return TritonOrDeepGemmExperts(
self.moe_quant_config,
allow_deep_gemm=use_deep_gemm,
)
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
return TritonOrDeepGemmExperts(self.moe_quant_config)
else:
logger.debug("TritonExperts(%s)", self.__class__.__name__)
return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if self.use_marlin:
return None
if self.fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
block_shape=self.weight_block_size,
)
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
@@ -1184,118 +1059,23 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
router_logits=router_logits,
)
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
assert self.kernel is not None
result = self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
# TODO(rob): investigate the disable_expert_map introduced by:
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map=None if self.disable_expert_map else layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
if self.use_marlin:
assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
elif self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
rocm_aiter_fused_experts,
)
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
# cutlass path
elif self.use_cutlass:
assert self.moe_quant_config is not None
# small-batch fallback on SM100
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=None
if self.disable_expert_map
else layer.expert_map, # ???
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
)
else:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8,
)
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return cutlass_moe_fp8(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_config=self.moe_quant_config,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=None if self.disable_expert_map else layer.expert_map,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
)
return result
@property
def supports_eplb(self) -> bool:

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
import torch
@@ -27,13 +26,17 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -46,25 +49,20 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
apply_flashinfer_per_tensor_scale_fp8,
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend,
make_fp8_moe_alpha_scales_for_fi,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
deepgemm_post_process_fp8_weight_block,
maybe_post_process_fp8_weight_block,
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy,
process_fp8_weight_tensor_strategy_moe,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -73,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@@ -81,12 +78,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
all_close_1d,
cutlass_block_fp8_supported,
cutlass_fp8_supported,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
@@ -96,11 +91,8 @@ from vllm.model_executor.parameter import (
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -110,107 +102,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
class Fp8MoeBackend(Enum):
NONE = 0
FLASHINFER_TRTLLM = 1
FLASHINFER_CUTLASS = 2
DEEPGEMM = 3
MARLIN = 4
TRITON = 5
AITER = 6
def get_fp8_moe_backend(
block_quant: bool,
moe_parallel_config: FusedMoEParallelConfig,
with_lora_support: bool,
) -> Fp8MoeBackend | None:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
if current_platform.is_xpu():
return None
if with_lora_support:
return Fp8MoeBackend.TRITON
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
and (
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
):
backend = get_flashinfer_moe_backend()
if backend == FlashinferMoeBackend.TENSORRT_LLM:
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant and current_platform.is_device_capability_family(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization on SM100. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
return Fp8MoeBackend.FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
if current_platform.is_rocm():
use_marlin = False
if use_marlin:
logger.info_once("Using Marlin backend for FP8 MoE")
return Fp8MoeBackend.MARLIN
# Determine if we should use DeepGEMM with block-quantized weights:
# - If explicitly set by user, respect their choice
# - If not explicitly set (default), disable when TP size is >= 8
moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
moe_use_deep_gemm = False
logger.info_once(
"DeepGEMM MoE is disabled by default when TP size is >= 8. "
"Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
scope="local",
)
# Determine if we should use DeepGEMM (top-level enable switch)
# - If explicitly set by user, respect their choice
# - If the platform does not support DeepGEMM, disable it
# This helps avoid warning messages on unsupported platforms.
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
if not is_deep_gemm_supported():
use_deep_gemm = False
logger.info_once(
"DeepGEMM is disabled because the platform does not support it.",
scope="local",
)
if use_deep_gemm and moe_use_deep_gemm and block_quant:
if not has_deep_gemm():
logger.warning_once(
"DeepGEMM backend requested but not available.", scope="local"
)
elif is_deep_gemm_supported():
logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
return Fp8MoeBackend.DEEPGEMM
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
return Fp8MoeBackend.AITER
# default to Triton
logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
@@ -348,7 +239,6 @@ class Fp8Config(QuantizationConfig):
moe_quant_method = Fp8MoEMethod(self, layer)
else:
moe_quant_method = Fp8OnlineMoEMethod(self, layer)
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
@@ -736,40 +626,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(layer.moe_config)
self.layer = layer
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.weight_scale_name = (
"weight_scale_inv" if self.block_quant else "weight_scale"
)
self.fp8_backend = get_fp8_moe_backend(
self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
self.fp8_backend = select_fp8_moe_backend(
block_quant=self.block_quant,
tp_size=layer.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled,
)
self.marlin_input_dtype = None
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
if self.block_quant and self.weight_block_size != [128, 128]:
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports block "
"size [128, 128]."
)
if not self.block_quant:
if layer.renormalize or layer.custom_routing_function is not None:
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend does custom routing "
f"function or renormalization, but got {layer.renormalize} and "
f"{layer.custom_routing_function}."
)
if layer.scoring_func != "sigmoid":
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports "
f"'sigmoid' scoring function, but got {layer.scoring_func}."
)
if layer.activation != "silu":
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
@@ -778,12 +652,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
dynamic_per_token = (
not self.block_quant and self.quant_config.activation_scheme != "static"
)
if self.flashinfer_moe_backend is not None and dynamic_per_token:
if dynamic_per_token and self.fp8_backend in [
Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS,
]:
raise NotImplementedError(
"FlashInfer FP8 MoE backend does not support dynamic per token "
"activation quantization."
)
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights(
self,
layer: Module,
@@ -907,148 +786,43 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale = None
layer.w2_input_scale = None
def _convert_weights_to_kernel_format(
def _setup_kernel(
self,
layer: Module,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> None:
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
assert self.block_quant
w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=w13_weight,
ws=w13_weight_scale,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=w2_weight,
ws=w2_weight_scale,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
elif self.fp8_backend == Fp8MoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
)
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
(
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
) = prepare_moe_fp8_layer_for_marlin(
layer,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
input_dtype=self.marlin_input_dtype,
)
layer.workspace = workspace
elif self.fp8_backend in [
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
w13_weight = swap_w13_to_w31(w13_weight)
if self.block_quant:
w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
else:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer=layer,
w13_weight_scale=w13_weight,
w13_input_scale=w13_input_scale,
w2_weight_scale=w2_weight,
w2_input_scale=w2_input_scale,
)
elif self.fp8_backend == Fp8MoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
)
# Shuffle weights to runtime format.
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
fp8_backend=self.fp8_backend,
layer=layer,
w13=w13,
w2=w2,
w13_scale=w13_scale,
w2_scale=w2_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
def _setup_kernel(self, layer: Module) -> None:
"""Setup Modular Kernel for TP Case"""
        # NOTE(rob): this is a WIP refactor. We are first migrating
        # all of the kernels in the TP case to use mk. Once this is
        # done, we will initialize the TP case and DP/EP case
        # via the same code path (i.e. via maybe_init_modular_kernel).
# NOTE(rob): in progress migrating all into this format.
from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
# Flashinfer TRTLLM does not use the modular kernel abstraction.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
# Setup modular kernel for TP case.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.use_inplace = True
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.kernel = mk.FusedMoEModularKernel(
# TODO: make defer_input_quant an attr of the FlashInferExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=self.block_quant),
FlashInferExperts(
out_dtype=layer.orig_dtype,
quant_config=self.moe_quant_config,
ep_rank=self.moe.ep_rank,
ep_size=self.moe.ep_size,
tp_rank=self.moe.tp_rank,
tp_size=self.moe.tp_size,
use_dp=(self.moe.dp_size > 1),
use_deepseek_fp8_block_scale=self.block_quant,
),
)
self.use_inplace = False
elif self.fp8_backend == Fp8MoeBackend.AITER:
self.kernel = mk.FusedMoEModularKernel(
# TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
AiterExperts(quant_config=self.moe_quant_config),
)
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=self.moe_quant_config),
)
else:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
),
if self.moe_quant_config:
self.kernel, self.use_inplace = make_fp8_moe_kernel(
layer=layer,
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
)
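        # Descriptive note (a sketch, not new behavior): whichever backend is
        # selected, the composed kernel follows the standard modular-kernel call
        # contract used in apply() below, roughly:
        #
        #   output = self.kernel(
        #       x, layer.w13_weight, layer.w2_weight,
        #       topk_weights, topk_ids,
        #       inplace=self.use_inplace,
        #       activation=layer.activation,
        #       global_num_experts=layer.global_num_experts,
        #       expert_map=layer.expert_map,
        #       apply_router_weight_on_input=layer.apply_router_weight_on_input,
        #   )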
def process_weights_after_loading(self, layer: Module) -> None:
@@ -1056,78 +830,58 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return
# Allow for accessing weights and scales in standard way.
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
w13 = layer.w13_weight
w2 = layer.w2_weight
w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
w13_input_scale = layer.w13_input_scale
w2_input_scale = layer.w2_input_scale
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if current_platform.is_fp8_fnuz():
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
w13_weight, w13_weight_scale, w13_input_scale
)
w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w13,
w13_scale,
w13_input_scale,
)
w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w2_weight, w2_weight_scale, w2_input_scale
w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w2,
w2_scale,
w2_input_scale,
)
# Per tensor kernels require single activation scale. Use the max.
if self.quant_config.activation_scheme == "static":
assert not self.block_quant
assert w13_input_scale is not None and w2_input_scale is not None
if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
replace_parameter(layer, "w2_input_scale", w2_input_scale.max())
w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
w13_input_scale, w2_input_scale
)
replace_parameter(layer, "w13_input_scale", w13_input_scale)
replace_parameter(layer, "w2_input_scale", w2_input_scale)
# Per tensor kernels require single weight scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize.
if not self.block_quant:
shard_size = layer.intermediate_size_per_partition
max_w13_scales = w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
w13_weight[expert_id][start : start + shard_size, :],
w13_weight_scale[expert_id][shard_id],
)
w13_weight[expert_id][start : start + shard_size, :], _ = (
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
w13_weight_scale = max_w13_scales
w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
w13, w13_scale, shard_size, layer.local_num_experts
)
# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=w13_weight_scale,
w2_weight_scale=w2_weight_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
# Shuffle weights to runtime format and setup kernel.
self._setup_kernel(
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
# Setup modular kernel for TP case.
self._setup_kernel(layer)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if (
self.fp8_backend == Fp8MoeBackend.AITER
or self.fp8_backend == Fp8MoeBackend.MARLIN
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
if self.fp8_backend in [
Fp8MoeBackend.AITER,
Fp8MoeBackend.MARLIN,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
@@ -1184,7 +938,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
elif self.moe.is_lora_enabled:
return TritonExperts(quant_config=self.moe_quant_config)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# Select GEMM experts with block-scale when weights are block-quantized
experts = select_cutlass_fp8_gemm_impl(
self.moe,
@@ -1193,17 +947,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
else:
elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
logger.debug(
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__,
self.weight_block_size,
False,
)
return TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
return TritonOrDeepGemmExperts(self.moe_quant_config)
else:
assert self.fp8_backend == Fp8MoeBackend.TRITON
logger.debug(
"TritonExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__,
self.weight_block_size,
False,
)
return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -1212,42 +972,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
# MARLIN uses mixed precision W8A16 config.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config(
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
block_shape=self.weight_block_size,
)
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
# Flashinfer CUTLASS per-tensor uses single dq scale
# (alpha = w_scale * a_scale) and inverse a2 scale.
if (
self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
and not self.block_quant
):
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=(1.0 / a2_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
return make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
@@ -1269,7 +1000,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
@@ -1308,10 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routed_scaling=layer.routed_scaling_factor,
)
else:
assert (
not layer.renormalize and layer.custom_routing_function is not None
)
result = apply_flashinfer_per_tensor_scale_fp8(
result = apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
@@ -1327,6 +1055,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
hidden_states=x,
router_logits=router_logits,
)
assert self.kernel is not None
result = self.kernel(
x,
layer.w13_weight,
@@ -1358,7 +1088,6 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
assert not quant_config.is_checkpoint_fp8_serialized
assert quant_config.activation_scheme == "dynamic"
assert quant_config.weight_block_size is None
assert self.flashinfer_moe_backend is None
def create_weights(
self,
@@ -1447,6 +1176,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
layer.w13_input_scale = None
layer.w2_input_scale = None
@@ -1457,33 +1188,30 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight[expert, :, :])
w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
layer.w13_weight[expert, :, :]
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight[expert, :, :])
w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
layer.w2_weight[expert, :, :]
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=layer.w13_weight_scale,
w2_weight_scale=layer.w2_weight_scale,
w13_input_scale=None,
w2_input_scale=None,
# Shuffle weights to runtime format and setup kernel.
self._setup_kernel(
layer,
w13,
w2,
w13_scale,
w2_scale,
layer.w13_input_scale,
layer.w2_input_scale,
)
# Setup modular kernel for TP case.
self._setup_kernel(layer)
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""

View File

@@ -15,7 +15,6 @@ from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
@@ -24,6 +23,13 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -45,19 +51,16 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
apply_flashinfer_per_tensor_scale_fp8,
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf,
make_fp8_moe_alpha_scales_for_fi,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
@@ -85,13 +88,12 @@ from vllm.model_executor.parameter import (
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
has_flashinfer_moe,
)
from vllm.utils.math_utils import round_up
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -721,38 +723,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
) -> None:
super().__init__(layer.moe_config)
self.layer = layer
self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported,
assert self.quant_config.is_checkpoint_fp8_serialized
self.fp8_backend = select_fp8_moe_backend(
block_quant=False,
tp_size=layer.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled,
)
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
if (
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
and not self.moe.is_act_and_mul
):
logger.info_once(
"Non-gated MoE is not supported for min-latency mode,"
"falling back to high-throughput mode"
)
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
)
self.kernel: mk.FusedMoEModularKernel | None = None
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
# TRT LLM not supported with all2all yet.
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
@@ -787,6 +774,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.orig_dtype = params_dtype
layer.num_experts = num_experts
# Use FP8 dtype if checkpoint is serialized
weight_dtype = (
torch.float8_e4m3fn
@@ -826,217 +816,121 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
layer.register_parameter("w2_weight", w2_weight)
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALES - Per-tensor scaling for ModelOpts
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
if self.moe.is_act_and_mul:
w13_weight_scale_shape = (num_experts, 2)
else:
w13_weight_scale_shape = (num_experts, 1)
w13_weight_scale = PerTensorScaleParameter(
data=torch.full(
w13_weight_scale_shape,
1.0,
dtype=torch.float32,
),
weight_loader=weight_loader,
)
w2_weight_scale = PerTensorScaleParameter(
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# WEIGHT SCALES - Per-tensor scaling for ModelOpts
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
w13_weight_scale = PerTensorScaleParameter(
data=torch.full(
(num_experts, 2 if self.moe.is_act_and_mul else 1),
1.0,
dtype=torch.float32,
),
weight_loader=weight_loader,
)
w2_weight_scale = PerTensorScaleParameter(
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Set weight loader attributes for scales
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
w2_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
layer.register_parameter("w2_input_scale", w2_input_scale)
# INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
def _setup_kernel(
self,
layer: torch.nn.Module,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
):
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
fp8_backend=self.fp8_backend,
layer=layer,
w13=w13,
w2=w2,
w13_scale=w13_scale,
w2_scale=w2_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale)
# Setup modular kernel for TP case.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
self.kernel, self.use_inplace = make_fp8_moe_kernel(
layer=layer,
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
)
w2_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Process FP8 MoE weights after loading from serialized checkpoint.
Only supports pre-quantized checkpoints with FP8 weights and scales.
"""
w13 = layer.w13_weight
w2 = layer.w2_weight
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
w13_input_scale = layer.w13_input_scale
w2_input_scale = layer.w2_input_scale
if self.flashinfer_moe_backend is not None:
self._maybe_pad_intermediate_for_flashinfer(layer)
# Per tensor kernels require single activation scale. Use the max.
w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
w13_input_scale, w2_input_scale
)
replace_parameter(layer, "w13_input_scale", w13_input_scale)
replace_parameter(layer, "w2_input_scale", w2_input_scale)
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
from vllm._custom_ops import scaled_fp8_quant
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize,
# Per tensor kernels require single weight scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize.
shard_size = layer.intermediate_size_per_partition
w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
w13,
w13_scale,
shard_size,
num_experts=layer.w13_weight.shape[0],
is_act_and_mul=self.moe.is_act_and_mul,
)
# Handle scale parameters
if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max of the w1 and w3 scales
# then dequant and requant each expert.
if (
layer.w13_weight_scale.dim() == 2
and layer.w13_weight_scale.shape[1] == 2
):
assert self.moe.is_act_and_mul, (
"w13_weight_scale should have 2 elements per expert "
"only for gated MoE"
)
# Get the maximum scale across w1 and w3 for each expert
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
# Requantize each expert's weights using the combined scale
# w13_weight (num_experts, 2 * intermediate_size, hidden_size)
# where the first intermediate_size rows are w1, the next are w3
intermediate_size = layer.w13_weight.shape[1] // 2
for expert_id in range(layer.w13_weight.shape[0]):
start = 0
for shard_id in range(2): # w1 and w3
# Dequantize using the original scale for this shard
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][
start : start + intermediate_size, :
],
layer.w13_weight_scale[expert_id][shard_id],
)
# Requantize using the combined max scale
(
layer.w13_weight[expert_id][
start : start + intermediate_size, :
],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
start += intermediate_size
# Update the scale parameter to be per-expert
layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
else:
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
# Input scales must be equal for each expert in fp8 MoE layers.
if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
layer.w13_input_scale = Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
layer.w2_input_scale = Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
if self.flashinfer_moe_backend is not None:
if self.moe.is_act_and_mul:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
# NOTE: this adds some attributes used by the trtllm kernel,
# which does not conform to the modular kernels abstraction (yet).
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer=layer,
w13_weight_scale=layer.w13_weight_scale,
w13_input_scale=layer.w13_input_scale,
w2_weight_scale=layer.w2_weight_scale,
w2_input_scale=layer.w2_input_scale,
)
def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
used for GEMM to be divisible by a small alignment value. When this is
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
gate/up and down projection weights along the intermediate dim.
"""
if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
return
# Current local intermediate size (per partition) is the K dimension of
# the down projection.
num_experts, hidden_size, intermediate = layer.w2_weight.shape
min_alignment = 16
padded_intermediate = round_up(intermediate, min_alignment)
if padded_intermediate == intermediate:
return
logger.info(
"Padding intermediate size from %d to %d for up/down projection weights.",
intermediate,
padded_intermediate,
# Shuffle weights to runtime format and setup kernel.
self._setup_kernel(
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
up_mult = 2 if self.moe.is_act_and_mul else 1
padded_gate_up_dim = up_mult * padded_intermediate
        # Pad w13 and w2 along their intermediate dimensions.
w13 = layer.w13_weight.data
padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
padded_w13[:, : w13.shape[1], :] = w13
layer.w13_weight.data = padded_w13
w2 = layer.w2_weight.data
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
padded_w2[:, :, :intermediate] = w2
layer.w2_weight.data = padded_w2
if hasattr(layer, "intermediate_size_per_partition"):
layer.intermediate_size_per_partition = padded_intermediate
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
# TRTLLM does not use modular kernels
return None
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
layer.w13_weight_scale,
layer.w13_input_scale,
layer.w2_weight_scale,
layer.w2_input_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a1_gscale=(1.0 / layer.w13_input_scale),
a2_gscale=(1.0 / layer.w2_input_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
else:
assert self.flashinfer_moe_backend is None
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
return make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def apply(
self,
@@ -1044,17 +938,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet."
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
)
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
assert not layer.renormalize
return apply_flashinfer_per_tensor_scale_fp8(
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
@@ -1066,46 +961,34 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
# Expert selection
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
assert layer.activation in ("silu", "relu2_no_mul"), (
"Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {layer.activation}"
)
return flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
topk_ids,
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
assert self.moe_quant_config is not None
assert self.kernel is not None
result = self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=layer.activation,
quant_config=self.moe_quant_config,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
return result
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod

View File

@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin,
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
@@ -315,8 +315,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin:
(workspace, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) = (
prepare_moe_fp8_layer_for_marlin(
w13_weight, w2_weight, w13_weight_scale, w2_weight_scale = (
prepare_fp8_moe_layer_for_marlin(
layer,
layer.w13_weight,
layer.w2_weight,
@@ -324,7 +324,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer.w2_weight_scale,
)
)
layer.workspace = workspace
# TODO(rob): once we apply refactor to Quark, switch to using
# replace_parameter for compatibility with reloading in RL.
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)

View File

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
@@ -58,9 +59,10 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
)
def rotate_flashinfer_fp8_moe_weights(
def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor
):
"""Shuffle weights for for FI TRT-LLM Format"""
from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
epilogue_tile_m = 128
@@ -105,16 +107,16 @@ def rotate_flashinfer_fp8_moe_weights(
def register_scales_for_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
w13_weight_scale: torch.Tensor,
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> None:
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w13_scale=w13_weight_scale,
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_weight_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
layer.w2_input_scale_inv = 1.0 / w2_input_scale
@@ -123,7 +125,7 @@ def register_scales_for_trtllm_fp8_per_tensor_moe(
layer.output2_scales_scalar = g2_alphas
def apply_flashinfer_per_tensor_scale_fp8(
def apply_fi_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -139,16 +141,23 @@ def apply_flashinfer_per_tensor_scale_fp8(
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
from vllm.model_executor.models.llama4 import Llama4MoE
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert (
hasattr(layer, "output1_scales_scalar")
and hasattr(layer, "output1_scales_gate_scalar")
and hasattr(layer, "output2_scales_scalar")
)
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
"FusedMoE flashinfer kernels are only supported for Llama4"
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert (
hasattr(layer, "output1_scales_scalar")
and hasattr(layer, "output1_scales_gate_scalar")
and hasattr(layer, "output2_scales_scalar")
)
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
is_llama4 = layer.custom_routing_function == Llama4MoE.custom_routing_function
assert is_llama4, "FusedMoE flashinfer kernels are only supported for Llama4"
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
@@ -221,50 +230,6 @@ def select_cutlass_fp8_gemm_impl(
)
def flashinfer_cutlass_moe_fp8(
hidden_states: torch.Tensor,
layer: torch.nn.Module,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
use_deepseek_fp8_block_scale: bool = False,
moe: FusedMoEConfig | None = None,
) -> torch.Tensor:
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
assert quant_config is not None
# Construct modular kernel with block-scale support when requested.
fused_experts = mk.FusedMoEModularKernel(
build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
),
select_cutlass_fp8_gemm_impl(
moe=moe,
quant_config=quant_config,
out_dtype=hidden_states.dtype,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
),
moe_parallel_config=layer.moe_parallel_config,
)
return fused_experts(
hidden_states,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
@@ -301,3 +266,104 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
FlashinferMoeBackend.TENSORRT_LLM,
)
return backend in backends_supporting_global_sf
def align_fp8_moe_weights_for_fi(
w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
used for GEMM to be divisible by a small alignment value. When this is
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
gate/up and down projection weights along the intermediate dim.
"""
# Current local intermediate size (per partition) is the K dimension of
# the down projection.
num_experts, hidden_size, intermediate = w2.shape
min_alignment = 16
padded_intermediate = round_up(intermediate, min_alignment)
if padded_intermediate == intermediate:
return w13, w2, intermediate
logger.info_once(
"Padding intermediate size from %d to %d for up/down projection weights.",
intermediate,
padded_intermediate,
scope="local",
)
up_mult = 2 if is_act_and_mul else 1
padded_gate_up_dim = up_mult * padded_intermediate
    # Pad w13 and w2 along their intermediate dimensions.
padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
padded_w13[:, : w13.shape[1], :] = w13
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
padded_w2[:, :, :intermediate] = w2
return padded_w13, padded_w2, padded_intermediate
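# Minimal usage sketch of the padding above (illustration only; toy shapes and
# float32 stand-ins for the real FP8 tensors). With a per-partition intermediate
# size of 1000 and the 16-element alignment, the gated up-proj grows from 2000
# to 2016 rows and the down-proj from 1000 to 1008 columns; existing values are
# preserved and the padding is zero.
def _sketch_alignment_padding() -> None:
    num_experts, hidden, intermediate = 4, 256, 1000
    w13 = torch.randn(num_experts, 2 * intermediate, hidden)
    w2 = torch.randn(num_experts, hidden, intermediate)
    padded = round_up(intermediate, 16)  # 1008
    padded_w13 = w13.new_zeros((num_experts, 2 * padded, hidden))
    padded_w13[:, : w13.shape[1], :] = w13
    padded_w2 = w2.new_zeros((num_experts, hidden, padded))
    padded_w2[:, :, :intermediate] = w2
    assert padded_w13.shape == (4, 2016, 256)
    assert torch.equal(padded_w2[:, :, :intermediate], w2)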
def prepare_fp8_moe_layer_for_fi(
layer: torch.nn.Module,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor | None,
is_trtllm: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
    Convert FP8 MoE weights to the FlashInfer kernel format.
    Note that for TRT-LLM we update the model state dict
    with the scale format needed by these kernels.
    Note that for per-tensor quantization we update the layer's
    intermediate size if the weights needed padding.
"""
assert hasattr(layer.moe_config, "is_act_and_mul")
block_quant = (
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
)
# Some FI MoE kernels require internal alignment of 16
# for the gate-up proj. Pad the weights to respect this.
if not block_quant:
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
w13,
w2,
layer.moe_config.is_act_and_mul,
)
layer.intermediate_size_per_partition = new_intermediate
# FI kernels require W31 layout rather than W13.
if layer.moe_config.is_act_and_mul:
w13 = swap_w13_to_w31(w13)
if block_quant:
w13_scale = swap_w13_to_w31(w13_scale)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
    # and registration of alpha scales. Note that we do not register them
    # as nn.Parameters since they are not needed for weight reloading.
if is_trtllm and not block_quant:
assert w13_input_scale is not None
assert w2_input_scale is not None
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
return w13, w2, w13_scale
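# Conceptual sketch of the W13 -> W31 layout swap required by the FI kernels
# (an emulation for illustration, not the actual swap_w13_to_w31
# implementation): the gate (w1) and up (w3) halves of the fused projection
# trade places along the intermediate dimension.
def _sketch_w13_to_w31_swap() -> None:
    num_experts, intermediate, hidden = 2, 8, 16
    w1 = torch.randn(num_experts, intermediate, hidden)
    w3 = torch.randn(num_experts, intermediate, hidden)
    w13 = torch.cat([w1, w3], dim=1)  # [w1; w3] layout
    # Swap the two halves to obtain the [w3; w1] layout.
    w31 = torch.cat([w13[:, intermediate:], w13[:, :intermediate]], dim=1)
    assert torch.equal(w31[:, :intermediate], w3)
    assert torch.equal(w31[:, intermediate:], w1)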

View File

@@ -21,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
all_close_1d,
per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
@@ -1350,6 +1352,29 @@ def deepgemm_post_process_fp8_weight_block(
return wq, dg_ws
def prepare_fp8_moe_layer_for_deepgemm(
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
block_shape: tuple[int],
):
w13, w13_scale = deepgemm_post_process_fp8_weight_block(
wq=w13,
ws=w13_scale,
quant_block_shape=block_shape,
use_e8m0=is_deep_gemm_e8m0_used(),
)
w2, w2_scale = deepgemm_post_process_fp8_weight_block(
wq=w2,
ws=w2_scale,
quant_block_shape=block_shape,
use_e8m0=is_deep_gemm_e8m0_used(),
)
return w13, w2, w13_scale, w2_scale
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
@@ -1584,7 +1609,49 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
replace_parameter(layer, scale_attr, dg_weight_scale)
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
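# Quick sanity sketch for the predicate above (illustration only): transposing
# the last two dims of a contiguous (E, K, N) tensor yields shape (E, N, K)
# with strides (N * K, 1, N), i.e. a column-major layout per expert.
def _sketch_col_major_check() -> None:
    w = torch.empty(8, 128, 256).transpose(1, 2)  # shape (8, 256, 128)
    assert w.stride() == (256 * 128, 1, 256)
    assert expert_weight_is_col_major(w)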
def process_fp8_weight_tensor_strategy_moe(
weight: torch.Tensor,
weight_scales: torch.Tensor,
shard_size: int,
num_experts: int,
is_act_and_mul: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Process moe weights for tensor-wise quantization strategy."""
max_scales = weight_scales.max(dim=1).values
# For w1 case (i.e. not w13): just collapse the last dim since
# there is already just one scale per expert in this case.
if not is_act_and_mul:
assert weight_scales.shape[1] == 1
return weight, weight_scales.max()
# For w13 case (common): require single scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize.
for expert_id in range(num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
weight[expert_id][start : start + shard_size, :],
weight_scales[expert_id][shard_id],
)
weight[expert_id][start : start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_scales[expert_id]
)
start += shard_size
return weight, max_scales
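# Worked example of the max-scale requantization above, emulated in plain
# floating point (illustration only; the real code uses ops.scaled_fp8_quant
# and incurs FP8 rounding). Each shard is dequantized with its own scale and
# requantized with the per-expert max, so a single scale covers both w1 and w3.
def _sketch_requantize_with_max_scale() -> None:
    s1, s3 = 0.5, 2.0
    q1 = torch.tensor([2.0, 4.0])  # "quantized" w1 values (toy)
    q3 = torch.tensor([1.0, 3.0])  # "quantized" w3 values (toy)
    s_max = max(s1, s3)            # merged per-expert scale
    q1_new = q1 * s1 / s_max       # dequantize with s1, requantize with s_max
    q3_new = q3 * s3 / s_max
    # The represented real values are unchanged by the rescaling.
    assert torch.allclose(q1_new * s_max, q1 * s1)
    assert torch.allclose(q3_new * s_max, q3 * s3)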
def process_fp8_input_tensor_strategy_moe(
w13_input_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Process moe input scales for tensor-wise quantization strategy."""
if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
logger.info_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
return w13_input_scale.max(), w2_input_scale.max()
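# Small sketch of the input-scale collapse above (illustration only): reducing
# per-expert activation scales with .max() is conservative, since the largest
# scale covers every expert's activation range without overflow.
def _sketch_collapse_input_scales() -> None:
    w13_in = torch.tensor([0.75, 1.0, 1.25])
    w2_in = torch.tensor([0.5, 0.5, 0.625])
    assert float(w13_in.max()) == 1.25
    assert float(w2_in.max()) == 0.625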

View File

@@ -496,7 +496,7 @@ def get__quant_fp8_method() -> QuantFP8:
return _quant_fp8_method
def get_marlin_input_dtype(prefix):
def get_marlin_input_dtype(prefix: str | None = None):
if envs.VLLM_MARLIN_INPUT_DTYPE is None:
return
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8":

View File

@@ -8,6 +8,7 @@ import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT,
get_marlin_input_dtype,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
@@ -197,26 +198,28 @@ def prepare_fp8_layer_for_marlin(
replace_parameter(layer, "bias", bias)
def prepare_moe_fp8_layer_for_marlin(
def prepare_fp8_moe_layer_for_marlin(
layer: torch.nn.Module,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
input_dtype: torch.dtype | None = None,
) -> tuple[
torch.Tensor, # workspace
torch.Tensor, # w13_weight
torch.Tensor, # w2_weight
torch.Tensor, # w13_weight_scale
torch.Tensor, # w2_weight_scale
]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Shuffle weights and scales into marlin format.
Note that this function has the side effect of adding a `workspace`
attribute to the layer. This `workspace` does not need to be
registered as a Parameter as it is not used during weight reloading.
"""
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
input_dtype = get_marlin_input_dtype()
if input_dtype is not None and input_dtype.itemsize == 1:
raise NotImplementedError("Marlin W8A8 is not supported.")
@@ -227,7 +230,9 @@ def prepare_moe_fp8_layer_for_marlin(
# WORKSPACE
device = layer.w13_weight.device
workspace = marlin_make_workspace_new(device, 4)
# NOTE(rob): we do not need to register the workspace as a param
# because it is not used as part of the weight reloading process.
layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT
@@ -310,13 +315,7 @@ def prepare_moe_fp8_layer_for_marlin(
w13_weight_scale = permute_scales(w13_weight_scale, "w13")
w2_weight_scale = permute_scales(w2_weight_scale, "w2")
return (
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
)
return w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
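# Hedged usage sketch (assumes a FusedMoE layer with FP8 weights and scales
# already loaded; `layer` here is a placeholder, not real wiring):
#
#   w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
#       layer,
#       layer.w13_weight,
#       layer.w2_weight,
#       layer.w13_weight_scale,
#       layer.w2_weight_scale,
#   )
#   # The Marlin workspace is attached to the layer by the helper itself and
#   # does not need to be registered as a Parameter.
#   assert hasattr(layer, "workspace")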
def pack_fp8_to_int32(