[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)
@@ -73,7 +73,6 @@ if HAS_TRITON:
|
||||
CutlassExpertsFp8,
|
||||
CutlassExpertsW4A8Fp8,
|
||||
cutlass_moe_fp4,
|
||||
cutlass_moe_fp8,
|
||||
cutlass_moe_w4a8_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
@@ -96,7 +95,6 @@ if HAS_TRITON:
|
||||
"fused_experts",
|
||||
"get_config_file_name",
|
||||
"GroupedTopk",
|
||||
"cutlass_moe_fp8",
|
||||
"cutlass_moe_fp4",
|
||||
"cutlass_moe_w4a8_fp8",
|
||||
"CutlassExpertsFp8",
|
||||
|
||||
@@ -249,20 +249,28 @@ def run_cutlass_moe_fp8(
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def __init__(
|
||||
self,
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
out_dtype: torch.dtype | None,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
assert quant_config.use_fp8_w8a8
|
||||
super().__init__(quant_config)
|
||||
|
||||
# E: num_experts
|
||||
# N: intermediate size per partition
|
||||
# K: hidden dim
|
||||
ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
|
||||
c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
|
||||
|
||||
self.out_dtype = out_dtype
|
||||
self.ab_strides1 = ab_strides1
|
||||
self.ab_strides1 = ab_strides1_c_strides2
|
||||
self.ab_strides2 = ab_strides2
|
||||
self.c_strides1 = c_strides1
|
||||
self.c_strides2 = c_strides2
|
||||
self.c_strides2 = ab_strides1_c_strides2
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
@@ -329,24 +337,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
|
||||
class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
def __init__(
|
||||
self,
|
||||
out_dtype: torch.dtype | None,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(
|
||||
out_dtype,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
quant_config,
|
||||
)
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self,
|
||||
@@ -390,21 +380,10 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
self,
|
||||
max_experts_per_worker: int,
|
||||
num_dispatchers: int,
|
||||
out_dtype: torch.dtype | None,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
out_dtype,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
quant_config,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
assert max_experts_per_worker > 0
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.num_dispatchers = num_dispatchers
|
||||
@@ -445,113 +424,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
|
||||
def cutlass_moe_fp8(
|
||||
a: torch.Tensor,
|
||||
w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
activation: str = "silu",
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a w8a8-quantized Mixture of Experts (MoE) layer
|
||||
using two sets of quantized weights, w1_q and w2_q, and a top-k gating
|
||||
mechanism. The matrix multiplications are implemented with CUTLASS
|
||||
grouped gemm.
|
||||
|
||||
Parameters:
|
||||
- a (torch.Tensor): The input tensor to the MoE layer.
|
||||
Shape: [M, K]
|
||||
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
|
||||
Shape: [num_experts, K, 2N] (the weights are passed transposed)
|
||||
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
|
||||
Shape: [num_experts, N, K] (the weights are passed transposed)
|
||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
||||
- topk_ids (torch.Tensor): The token->expert mappings.
|
||||
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
|
||||
Shape: [num_experts] or [num_experts, 2N]
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts] or [num_experts, K]
|
||||
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
|
||||
Shape: [num_experts]
|
||||
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
|
||||
Shape: [num_experts]
|
||||
- c_strides1 (torch.Tensor): The output strides for the first gemm.
|
||||
Shape: [num_experts]
|
||||
- c_strides2 (torch.Tensor): The output strides for the second gemm.
|
||||
Shape: [num_experts]
|
||||
- per_act_token (Optional[bool]): Whether the scale is per-token or
|
||||
per-tensor.
|
||||
- activation (str): The activation function to use.
|
||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
||||
Shape: scalar or [M]
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
quantize the intermediate result between the gemms.
|
||||
Shape: scalar or [M]
|
||||
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
|
||||
every Rank is responsible for a subset of experts. expert_map is a
|
||||
mapping from global expert-id to local expert-id. When expert_map[i]
|
||||
is -1, it means that this Rank is not responsible for global
|
||||
expert-id i.
|
||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||
applied directly on the inputs. This is only applicable when topk is 1.
|
||||
- global_num_experts (int): The total number of experts.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
assert quant_config is not None
|
||||
|
||||
if quant_config.a1_scale is not None:
|
||||
assert quant_config.per_act_token_quant == (quant_config.a1_scale.numel() != 1)
|
||||
if quant_config.a2_scale is not None:
|
||||
assert quant_config.per_act_token_quant == (quant_config.a2_scale.numel() != 1)
|
||||
|
||||
if quant_config.w1_scale is not None:
|
||||
if quant_config.per_out_ch_quant:
|
||||
assert quant_config.w1_scale.dim() > 1 and quant_config.w1_scale.size(
|
||||
1
|
||||
) == w1_q.size(1)
|
||||
else:
|
||||
assert (
|
||||
quant_config.w1_scale.dim() == 1 or quant_config.w1_scale.size(1) == 1
|
||||
)
|
||||
|
||||
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp8(
|
||||
out_dtype=a.dtype,
|
||||
ab_strides1=ab_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides1=c_strides1,
|
||||
c_strides2=c_strides2,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
return fn(
|
||||
a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
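A minimal usage sketch of cutlass_moe_fp8, assuming toy shapes, dummy unit weight scales, and dynamic per-tensor activation quantization via fp8_w8a8_moe_quant_config; it is illustrative only (not part of this commit) and needs a CUDA GPU with CUTLASS grouped-GEMM support to actually run.

import torch
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8

E, K, N, M, topk = 8, 4096, 1024, 16, 2  # assumed toy sizes
dev = "cuda"

a = torch.randn(M, K, dtype=torch.half, device=dev)
w1_q = torch.empty(E, K, 2 * N, dtype=torch.float8_e4m3fn, device=dev)  # transposed layout
w2_q = torch.empty(E, N, K, dtype=torch.float8_e4m3fn, device=dev)      # transposed layout

# Per-expert GEMM strides, with the shapes documented in the docstring above.
ab_strides1 = torch.full((E,), K, dtype=torch.int64, device=dev)
ab_strides2 = torch.full((E,), N, dtype=torch.int64, device=dev)
c_strides1 = torch.full((E,), 2 * N, dtype=torch.int64, device=dev)
c_strides2 = torch.full((E,), K, dtype=torch.int64, device=dev)

# Dummy per-expert weight scales; activation scales left dynamic.
quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=torch.ones(E, dtype=torch.float32, device=dev),
    w2_scale=torch.ones(E, dtype=torch.float32, device=dev),
    a1_scale=None,
    a2_scale=None,
)

topk_weights = torch.rand(M, topk, device=dev)
topk_ids = torch.randint(0, E, (M, topk), dtype=torch.int32, device=dev)

out = cutlass_moe_fp8(
    a, w1_q, w2_q, topk_weights, topk_ids,
    ab_strides1, ab_strides2, c_strides1, c_strides2,
    quant_config=quant_config,
)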
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
|
||||
vllm/model_executor/layers/fused_moe/fallback.py (new file, 126 lines)
@@ -0,0 +1,126 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
|
||||
|
||||
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
"""Base class for runtime dispatching of expert implementations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experts: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
):
|
||||
super().__init__(experts.quant_config)
|
||||
self.fallback_experts = fallback_experts
|
||||
self.experts = experts
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self,
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
assert (
|
||||
self.fallback_experts.activation_formats == self.experts.activation_formats
|
||||
)
|
||||
return self.fallback_experts.activation_formats
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
assert (
|
||||
self.experts.supports_chunking()
|
||||
== self.fallback_experts.supports_chunking()
|
||||
)
|
||||
return (
|
||||
self.experts.supports_chunking()
|
||||
and self.fallback_experts.supports_chunking()
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
assert (
|
||||
self.experts.supports_expert_map()
|
||||
== self.fallback_experts.supports_expert_map()
|
||||
)
|
||||
return (
|
||||
self.experts.supports_expert_map()
|
||||
and self.fallback_experts.supports_expert_map()
|
||||
)
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
e_war = self.experts.finalize_weight_and_reduce_impl()
|
||||
fbe_war = self.fallback_experts.finalize_weight_and_reduce_impl()
|
||||
is_dge_war = e_war is not None
|
||||
is_fbe_war = fbe_war is not None
|
||||
|
||||
if is_dge_war and is_fbe_war:
|
||||
assert e_war == fbe_war, (
|
||||
"Both implementations should agree on WeightAndReduce impls. "
|
||||
f"Got e_war: {e_war}, and fbe_war: {fbe_war}"
|
||||
)
|
||||
|
||||
if e_war is not None:
|
||||
return e_war
|
||||
assert fbe_war is not None
|
||||
return fbe_war
|
||||
|
||||
@abstractmethod
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _select_experts_impl(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
raise NotImplementedError
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
experts = self._select_experts_impl(hidden_states, w1, w2)
|
||||
experts.apply(
|
||||
output,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
a2_scale,
|
||||
workspace13,
|
||||
workspace2,
|
||||
expert_tokens_meta,
|
||||
apply_router_weight_on_input,
|
||||
)
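A concrete FallbackExperts subclass only has to supply the per-call dispatch. A minimal sketch, assuming a batch-size threshold of 8; the class name and threshold are invented for illustration, and TritonOrCutlassExperts later in this diff follows the same pattern:

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts


class SmallBatchFallbackExperts(FallbackExperts):
    """Use the fallback implementation whenever the batch is tiny (illustrative)."""

    def workspace_shapes(self, M, N, K, topk, global_num_experts,
                         local_num_experts, expert_tokens_meta):
        impl = self.fallback_experts if M <= 8 else self.experts
        return impl.workspace_shapes(M, N, K, topk, global_num_experts,
                                     local_num_experts, expert_tokens_meta)

    def _select_experts_impl(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        # Dispatch on the number of tokens in this forward pass.
        return (self.fallback_experts if hidden_states.shape[0] <= 8
                else self.experts)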
|
||||
@@ -100,7 +100,7 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
def fi_trtllm_fp8_per_tensor_moe(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -158,7 +158,7 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
||||
def fi_trtllm_fp8_per_tensor_moe_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -184,9 +184,9 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
||||
|
||||
# TODO(bnell): Does this really need to be a torch.op?
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
|
||||
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
|
||||
op_name="fi_trtllm_fp8_per_tensor_moe",
|
||||
op_func=fi_trtllm_fp8_per_tensor_moe,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
|
||||
fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
vllm/model_executor/layers/fused_moe/oracle/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
vllm/model_executor/layers/fused_moe/oracle/fp8.py (new file, 358 lines)
@@ -0,0 +1,358 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
fp8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
get_flashinfer_moe_backend,
|
||||
make_fp8_moe_alpha_scales_for_fi,
|
||||
prepare_fp8_moe_layer_for_fi,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
prepare_fp8_moe_layer_for_deepgemm,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_fp8_moe_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_group_gemm_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
from vllm.utils.flashinfer import has_flashinfer_moe
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Fp8MoeBackend(Enum):
|
||||
NONE = 0
|
||||
FLASHINFER_TRTLLM = 1
|
||||
FLASHINFER_CUTLASS = 2
|
||||
DEEPGEMM = 3
|
||||
MARLIN = 4
|
||||
TRITON = 5
|
||||
AITER = 6
|
||||
VLLM_CUTLASS = 7
|
||||
|
||||
|
||||
def select_fp8_moe_backend(
|
||||
block_quant: bool,
|
||||
tp_size: int,
|
||||
with_lora_support: bool,
|
||||
is_act_and_mul: bool = True,
|
||||
allow_vllm_cutlass: bool = False,
|
||||
) -> Fp8MoeBackend:
|
||||
"""
|
||||
Select the primary FP8 MoE backend.
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
# TODO(rob): in a future PR, we will query each mk for
|
||||
# supported features and return the mk directly, just like
|
||||
# we do for the Attention Backend.
|
||||
|
||||
if with_lora_support:
|
||||
return Fp8MoeBackend.TRITON
|
||||
|
||||
def _make_log_backend(backend_name: str):
|
||||
return f"Using {backend_name} backend for FP8 MoE"
|
||||
|
||||
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and (
|
||||
current_platform.is_device_capability_family(100)
|
||||
or current_platform.is_device_capability(90)
|
||||
)
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP8
|
||||
and has_flashinfer_moe()
|
||||
):
|
||||
backend = get_flashinfer_moe_backend()
|
||||
if backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
logger.info_once(_make_log_backend("FlashInfer TRTLLM"))
|
||||
if not is_act_and_mul:
|
||||
raise ValueError(
|
||||
"FlashInfer TRTLLM FP8 MoE backend only supports "
|
||||
"act_and_mul gate_up_project fusion. Please set "
|
||||
"VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the "
|
||||
"FlashInfer CUTLASS backend instead."
|
||||
)
|
||||
return Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
else:
|
||||
if block_quant and current_platform.is_device_capability_family(100):
|
||||
raise ValueError(
|
||||
"FlashInfer FP8 MoE throughput backend does not "
|
||||
"support block quantization on SM100. Please use "
|
||||
"VLLM_FLASHINFER_MOE_BACKEND=latency to use the "
|
||||
"FlashInfer TRTLLM backend instead."
|
||||
)
|
||||
logger.info_once(_make_log_backend("FlashInfer CUTLASS"))
|
||||
return Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
|
||||
# weight-only path for older GPUs without native FP8
|
||||
if (
|
||||
current_platform.is_cuda() and not current_platform.has_device_capability(89)
|
||||
) or envs.VLLM_TEST_FORCE_FP8_MARLIN:
|
||||
logger.info_once(_make_log_backend("Marlin"), scope="local")
|
||||
return Fp8MoeBackend.MARLIN
|
||||
|
||||
# Determine if we should use DeepGEMM with block-quantized weights:
|
||||
# - If explicitly set by user, respect their choice
|
||||
# - If not explicitly set (default), disable when TP size is >= 8
|
||||
moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
|
||||
if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8:
|
||||
moe_use_deep_gemm = False
|
||||
logger.info_once(
|
||||
"DeepGEMM MoE is disabled by default when TP size is >= 8. "
|
||||
"Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
|
||||
scope="local",
|
||||
)
|
||||
|
||||
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
|
||||
if not is_deep_gemm_supported():
|
||||
use_deep_gemm = False
|
||||
logger.info_once(
|
||||
"DeepGEMM is disabled because the platform does not support it.",
|
||||
scope="local",
|
||||
)
|
||||
|
||||
if use_deep_gemm and moe_use_deep_gemm and block_quant:
|
||||
if not has_deep_gemm():
|
||||
logger.warning_once(
|
||||
"DeepGEMM backend requested but not available.", scope="local"
|
||||
)
|
||||
elif is_deep_gemm_supported():
|
||||
logger.info_once(_make_log_backend("DeepGEMM"), scope="local")
|
||||
return Fp8MoeBackend.DEEPGEMM
|
||||
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
|
||||
logger.info_once(_make_log_backend("ROCm AITER"), scope="local")
|
||||
return Fp8MoeBackend.AITER
|
||||
|
||||
if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported():
|
||||
logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local")
|
||||
return Fp8MoeBackend.VLLM_CUTLASS
|
||||
|
||||
# default to Triton
|
||||
logger.info_once(_make_log_backend("Triton"), scope="local")
|
||||
return Fp8MoeBackend.TRITON
|
||||
|
||||
|
||||
def convert_to_fp8_moe_kernel_format(
|
||||
fp8_backend: Fp8MoeBackend,
|
||||
layer: torch.nn.Module,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor | None,
|
||||
w2_input_scale: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
block_quant = hasattr(layer, "weight_block_size")
|
||||
if fp8_backend == Fp8MoeBackend.DEEPGEMM:
|
||||
assert block_quant
|
||||
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_deepgemm(
|
||||
w13,
|
||||
w2,
|
||||
w13_scale,
|
||||
w2_scale,
|
||||
tuple(layer.weight_block_size),
|
||||
)
|
||||
elif fp8_backend == Fp8MoeBackend.AITER:
|
||||
w13, w2 = rocm_aiter_ops.shuffle_weights(w13, w2)
|
||||
elif fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
|
||||
layer,
|
||||
w13,
|
||||
w2,
|
||||
w13_scale,
|
||||
w2_scale,
|
||||
)
|
||||
elif fp8_backend in [
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
]:
|
||||
w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
|
||||
layer=layer,
|
||||
w13=w13,
|
||||
w2=w2,
|
||||
w13_scale=w13_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
is_trtllm=(fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM),
|
||||
)
|
||||
|
||||
return w13, w2, w13_scale, w2_scale
|
||||
|
||||
|
||||
def make_fp8_moe_quant_config(
|
||||
fp8_backend: Fp8MoeBackend,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
"""
|
||||
Create FusedMoEQuantConfig for the specified FP8 backend.
|
||||
The FusedMoEQuantConfig holds the scales that are used
|
||||
at runtime by the Modular Kernel abstraction.
|
||||
|
||||
Note that certain kernels (e.g. Flashinfer CUTLASS) need
|
||||
special Quant configs to handle non-standard inputs to
|
||||
their kernel interfaces.
|
||||
|
||||
In a future PR, this function should become
|
||||
a method of the modular kernel itself.
|
||||
"""
|
||||
# TRTLLM does not use Modular Kernel abstraction yet.
|
||||
if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
|
||||
# MARLIN is mixed precision W8A16 config.
|
||||
if fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
return fp8_w8a16_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# Flashinfer CUTLASS per-tensor uses single dq scale
|
||||
# (alpha = w_scale * a_scale) and inverse a2 scale.
|
||||
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
|
||||
assert a1_scale is not None and a2_scale is not None
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w1_scale,
|
||||
a1_scale,
|
||||
w2_scale,
|
||||
a2_scale,
|
||||
)
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
a1_gscale=(1.0 / a1_scale),
|
||||
a2_gscale=(1.0 / a2_scale),
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
)
|
||||
# All other backends use normal config.
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
|
||||
def make_fp8_moe_kernel(
|
||||
layer: torch.nn.Module,
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
fp8_backend: Fp8MoeBackend,
|
||||
) -> tuple[mk.FusedMoEModularKernel, bool]:
|
||||
# Delayed import is required since the oracle is imported
|
||||
# by CPU backends which cannot import all of these experts.
|
||||
# TODO: update the experts to make this not happen.
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
|
||||
# NOTE(rob): this is a WIP refactor. We are first migrating
|
||||
# all of the kernels in the TP case to use mk. Once this is
|
||||
# done, then we will initialize the TP case and DP/EP case
|
||||
# via the same code path (i.e. via maybe_init_modular_kernel).
|
||||
# NOTE(rob): in progress migrating all into this format.
|
||||
use_inplace = True
|
||||
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(
|
||||
defer_input_quant=moe_quant_config.is_block_quantized
|
||||
),
|
||||
FlashInferExperts(
|
||||
out_dtype=layer.orig_dtype,
|
||||
quant_config=moe_quant_config,
|
||||
ep_rank=moe_config.ep_rank,
|
||||
ep_size=moe_config.ep_size,
|
||||
tp_rank=moe_config.tp_rank,
|
||||
tp_size=moe_config.tp_size,
|
||||
use_dp=(moe_config.dp_size > 1),
|
||||
use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized,
|
||||
),
|
||||
)
|
||||
use_inplace = False
|
||||
|
||||
elif fp8_backend == Fp8MoeBackend.AITER:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
AiterExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
# TODO: make defer_input_quant an attr of the AiterExperts
|
||||
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
|
||||
AiterExperts(quant_config=moe_quant_config),
|
||||
)
|
||||
elif fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
MarlinExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
MarlinExperts(quant_config=moe_quant_config),
|
||||
)
|
||||
elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
|
||||
TritonOrCutlassExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonOrCutlassExperts(
|
||||
out_dtype=moe_config.in_dtype,
|
||||
e=layer.local_num_experts,
|
||||
n=layer.intermediate_size_per_partition,
|
||||
k=layer.hidden_size,
|
||||
device=layer.w13_weight.device,
|
||||
quant_config=moe_quant_config,
|
||||
),
|
||||
)
|
||||
elif fp8_backend == Fp8MoeBackend.DEEPGEMM:
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonOrDeepGemmExperts(quant_config=moe_quant_config),
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
TritonExperts,
|
||||
)
|
||||
|
||||
assert fp8_backend == Fp8MoeBackend.TRITON
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(quant_config=moe_quant_config),
|
||||
)
|
||||
return kernel, use_inplace
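Taken together, the oracle functions compose into one load-time flow: select a backend, convert the loaded weights into that backend's layout, package the scales into a quant config, and build the modular kernel. A sketch, assuming the layer exposes w13_weight/w2_weight and their scales as in CompressedTensorsW8A8Fp8MoEMethod; the helper name build_fp8_moe is invented for illustration:

import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)


def build_fp8_moe(layer: torch.nn.Module, moe: FusedMoEConfig, block_quant: bool):
    """Sketch of the select -> convert -> quant-config -> kernel flow."""
    backend = select_fp8_moe_backend(
        block_quant=block_quant,
        tp_size=moe.tp_size,
        with_lora_support=moe.is_lora_enabled,
        allow_vllm_cutlass=True,
    )

    # Rewrite the already-loaded weights into the layout the backend expects.
    w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
        fp8_backend=backend,
        layer=layer,
        w13=layer.w13_weight,
        w2=layer.w2_weight,
        w13_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        w13_input_scale=layer.w13_input_scale,
        w2_input_scale=layer.w2_input_scale,
    )

    # Scales are packaged into a FusedMoEQuantConfig consumed at runtime.
    quant_config = make_fp8_moe_quant_config(
        fp8_backend=backend,
        w1_scale=w13_scale,
        w2_scale=w2_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        block_shape=getattr(layer, "weight_block_size", None),
    )

    if quant_config is None:  # e.g. FlashInfer TRTLLM bypasses modular kernels
        return None, None
    return make_fp8_moe_kernel(
        layer=layer,
        moe_quant_config=quant_config,
        moe_config=moe,
        fp8_backend=backend,
    )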
|
||||
vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py (new file, 75 lines)
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class TritonOrCutlassExperts(FallbackExperts):
|
||||
"""Cutlass with fallback to Triton for low latency shapes on SM100."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
out_dtype: torch.dtype | None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.is_sm100 = current_platform.has_device_capability(100)
|
||||
super().__init__(
|
||||
experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device),
|
||||
fallback_experts=TritonExperts(quant_config),
|
||||
)
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Small batch fallback for sm100.
|
||||
if self.is_sm100 and M <= 8:
|
||||
return self.fallback_experts.workspace_shapes(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
topk,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
)
|
||||
else:
|
||||
return self.experts.workspace_shapes(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
topk,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
)
|
||||
|
||||
def _select_experts_impl(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
# Small batch fallback for sm100.
|
||||
if self.is_sm100 and hidden_states.shape[0] <= 8:
|
||||
return self.fallback_experts
|
||||
else:
|
||||
return self.experts
|
||||
@@ -10,78 +10,22 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm,
|
||||
_valid_deep_gemm_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||
from vllm.utils.deep_gemm import (
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
)
|
||||
|
||||
|
||||
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
allow_deep_gemm: bool = False,
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
"""DeepGemm with fallback to Triton for low latency shapes."""
|
||||
|
||||
self.triton_expert = TritonExperts(quant_config)
|
||||
|
||||
self.allow_deep_gemm = (
|
||||
allow_deep_gemm
|
||||
and self.quant_config.use_fp8_w8a8
|
||||
and self.block_shape == get_mk_alignment_for_contiguous_layout()
|
||||
def __init__(self, quant_config: FusedMoEQuantConfig):
|
||||
super().__init__(
|
||||
experts=DeepGemmExperts(quant_config),
|
||||
fallback_experts=TritonExperts(quant_config),
|
||||
)
|
||||
|
||||
self.deep_gemm_expert = (
|
||||
DeepGemmExperts(self.quant_config) if self.allow_deep_gemm else None
|
||||
)
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self,
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
assert (
|
||||
self.deep_gemm_expert is None
|
||||
or self.triton_expert.activation_formats
|
||||
== self.deep_gemm_expert.activation_formats
|
||||
)
|
||||
return self.triton_expert.activation_formats
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
dge = self.deep_gemm_expert
|
||||
te = self.triton_expert
|
||||
return (dge is None or dge.supports_chunking()) and (
|
||||
te is None or te.supports_chunking()
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
dge = self.deep_gemm_expert
|
||||
te = self.triton_expert
|
||||
return (dge is None or dge.supports_expert_map()) and (
|
||||
te is None or te.supports_expert_map()
|
||||
)
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
dge = self.deep_gemm_expert
|
||||
te = self.triton_expert
|
||||
dge_war = dge.finalize_weight_and_reduce_impl() if dge else None
|
||||
te_war = te.finalize_weight_and_reduce_impl() if te else None
|
||||
is_dge_war = dge_war is not None
|
||||
is_te_war = te_war is not None
|
||||
|
||||
if is_dge_war and is_te_war:
|
||||
assert dge_war == te_war, (
|
||||
"Both implementations should agree on WeightAndReduce impls. "
|
||||
f"Got dge_war: {dge_war}, and te_war: {te_war}"
|
||||
)
|
||||
|
||||
if dge_war is not None:
|
||||
return dge_war
|
||||
|
||||
assert te_war is not None
|
||||
return te_war
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
@@ -95,11 +39,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||
if self.allow_deep_gemm and (
|
||||
is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)
|
||||
):
|
||||
assert self.deep_gemm_expert is not None
|
||||
return self.deep_gemm_expert.workspace_shapes(
|
||||
if is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K):
|
||||
return self.experts.workspace_shapes(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@@ -109,7 +50,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_tokens_meta,
|
||||
)
|
||||
else:
|
||||
return self.triton_expert.workspace_shapes(
|
||||
return self.fallback_experts.workspace_shapes(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@@ -119,45 +60,13 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_tokens_meta,
|
||||
)
|
||||
|
||||
def apply(
|
||||
def _select_experts_impl(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
use_deep_gemm = self.allow_deep_gemm and (
|
||||
is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)
|
||||
)
|
||||
|
||||
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
|
||||
assert experts is not None
|
||||
|
||||
experts.apply(
|
||||
output,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
a2_scale,
|
||||
workspace13,
|
||||
workspace2,
|
||||
expert_tokens_meta,
|
||||
apply_router_weight_on_input,
|
||||
)
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
|
||||
return self.experts
|
||||
else:
|
||||
return self.fallback_experts
|
||||
|
||||
@@ -13,10 +13,8 @@ from compressed_tensors.quantization import (
|
||||
)
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
@@ -31,6 +29,7 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
fp8_w8a16_moe_quant_config,
|
||||
int4_w4a16_moe_quant_config,
|
||||
int4_w4afp8_moe_quant_config,
|
||||
int8_w8a8_moe_quant_config,
|
||||
@@ -46,11 +45,16 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
MarlinExperts,
|
||||
fused_marlin_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
||||
WNA16_SUPPORTED_BITS,
|
||||
WNA16_SUPPORTED_TYPES_MAP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
|
||||
flashinfer_trtllm_fp4_moe,
|
||||
@@ -63,8 +67,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
get_flashinfer_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
expert_weight_is_col_major,
|
||||
requant_weight_ue8m0_inplace,
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
process_fp8_weight_tensor_strategy_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_moe_marlin_supports_layer,
|
||||
@@ -76,29 +80,17 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_moe_fp4_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_moe_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
convert_bf16_scales_to_fp8,
|
||||
convert_packed_uint4b8_to_signed_int4_inplace,
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
per_tensor_dequantize,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.deep_gemm import (
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -657,10 +649,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
|
||||
super().__init__(moe)
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
@@ -687,42 +675,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
"For FP8 Fused MoE layer, we require either per tensor or "
|
||||
"channelwise, dynamic per token quantization."
|
||||
)
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (
|
||||
not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
and not self.block_quant
|
||||
self.fp8_backend = select_fp8_moe_backend(
|
||||
block_quant=self.block_quant,
|
||||
tp_size=moe.tp_size,
|
||||
with_lora_support=moe.is_lora_enabled,
|
||||
# TODO(rob): enable selecting this externally.
|
||||
allow_vllm_cutlass=True,
|
||||
)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
# cutlass path
|
||||
self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
|
||||
self.weight_quant, self.input_quant
|
||||
)
|
||||
self.use_cutlass = not self.block_quant and (
|
||||
CompressedTensorsConfig._is_fp8_w8a8_sm90(
|
||||
self.weight_quant, self.input_quant
|
||||
if self.fp8_backend != Fp8MoeBackend.MARLIN:
|
||||
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
|
||||
per_channel_quant = (
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
||||
)
|
||||
if per_act_token != per_channel_quant:
|
||||
raise NotImplementedError(
|
||||
"For FP8 Fused MoE layers, per-token and per-channel must be "
|
||||
"used together."
|
||||
)
|
||||
# TODO(rob): hook this up in a follow up PR.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
raise NotImplementedError(
|
||||
"FlashInfer TRTLLM backend not supported for compressed-tensors yet."
|
||||
)
|
||||
or self.is_fp8_w8a8_sm100
|
||||
)
|
||||
self.disable_expert_map = False
|
||||
self.layer_name = layer_name
|
||||
self.marlin_input_dtype = (
|
||||
get_marlin_input_dtype(layer_name) if self.use_marlin else None
|
||||
)
|
||||
|
||||
self.allow_deep_gemm = (
|
||||
self.block_quant
|
||||
and envs.VLLM_MOE_USE_DEEP_GEMM
|
||||
and is_deep_gemm_supported()
|
||||
and list(self.weight_block_size) == get_mk_alignment_for_contiguous_layout()
|
||||
)
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -880,163 +857,75 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Fp8 moe kernels require a single activation scale.
|
||||
# We take the max of all the scales in case they differ.
|
||||
# Allow for accessing weights and scales in standard way.
|
||||
w13 = layer.w13_weight
|
||||
w2 = layer.w2_weight
|
||||
w13_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
w13_input_scale = layer.w13_input_scale
|
||||
w2_input_scale = layer.w2_input_scale
|
||||
|
||||
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
|
||||
if current_platform.is_fp8_fnuz():
|
||||
w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
w13, w13_scale, w13_input_scale
|
||||
)
|
||||
w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
w2, w2_scale, w2_input_scale
|
||||
)
|
||||
|
||||
# Per tensor kernels require single activation scale. Use the max.
|
||||
if self.static_input_scales:
|
||||
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
|
||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None."
|
||||
)
|
||||
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
|
||||
layer.w2_input_scale
|
||||
):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer."
|
||||
)
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False
|
||||
assert w13_input_scale is not None and w2_input_scale is not None
|
||||
w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
|
||||
w13_input_scale, w2_input_scale
|
||||
)
|
||||
replace_parameter(layer, "w13_input_scale", w13_input_scale)
|
||||
replace_parameter(layer, "w2_input_scale", w2_input_scale)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = (
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
|
||||
)
|
||||
)
|
||||
w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
|
||||
)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
w13_weight_scale, requires_grad=False
|
||||
)
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
w13_input_scale, requires_grad=False
|
||||
)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
w2_weight_scale, requires_grad=False
|
||||
)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
w2_input_scale, requires_grad=False
|
||||
)
|
||||
|
||||
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
|
||||
# for w13 per expert. Use max then dequant and requant each expert.
|
||||
# Per-tensor kernels use a single scale for W13, but on disk there
|
||||
# is a separate scale for W1 and W3. Requantize with the max scale.
|
||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id],
|
||||
)
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
|
||||
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
)
|
||||
start += shard_size
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
max_w13_scales, requires_grad=False
|
||||
process_fp8_weight_tensor_strategy_moe(
|
||||
w13,
|
||||
w13_scale,
|
||||
shard_size=layer.intermediate_size_per_partition,
|
||||
num_experts=layer.num_local_experts,
|
||||
)
|
||||
|
||||
# Property to determine if AITER is used
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
|
||||
fp8_backend=self.fp8_backend,
|
||||
layer=layer,
|
||||
w13=w13,
|
||||
w2=w2,
|
||||
w13_scale=w13_scale,
|
||||
w2_scale=w2_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
# Replace parameters with updated versions. Note that this helper
|
||||
# function ensures the replacement is compatible with RL weight reloads.
|
||||
replace_parameter(layer, "w13_weight", w13)
|
||||
replace_parameter(layer, "w2_weight", w2)
|
||||
replace_parameter(layer, "w13_weight_scale", w13_scale)
|
||||
replace_parameter(layer, "w2_weight_scale", w2_scale)
|
||||
|
||||
elif self.use_marlin:
|
||||
(
|
||||
workspace,
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
w13_weight_scale,
|
||||
w2_weight_scale,
|
||||
) = prepare_moe_fp8_layer_for_marlin(
|
||||
layer,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
self.kernel, self.use_inplace = make_fp8_moe_kernel(
|
||||
layer=layer,
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
)
|
||||
layer.workspace = workspace
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
|
||||
replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
|
||||
|
||||
if self.use_cutlass:
|
||||
assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
|
||||
device = layer.w13_weight.device
|
||||
# ab_strides1 and c_strides2 are the same
|
||||
self.ab_strides1_c_strides2 = torch.full(
|
||||
(layer.local_num_experts,),
|
||||
layer.hidden_size,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
self.ab_strides2 = torch.full(
|
||||
(layer.local_num_experts,),
|
||||
layer.intermediate_size_per_partition,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
self.c_strides1 = torch.full(
|
||||
(layer.local_num_experts,),
|
||||
2 * layer.intermediate_size_per_partition,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
if is_deep_gemm_e8m0_used() and self.block_quant:
|
||||
assert layer.weight_block_size is not None
|
||||
# Re-quantise the expert weights so their scales are UE8M0.
|
||||
block_sz = tuple(layer.weight_block_size)
|
||||
requant_weight_ue8m0_inplace(
|
||||
layer.w13_weight.data,
|
||||
layer.w13_weight_scale.data,
|
||||
block_sz,
|
||||
)
|
||||
requant_weight_ue8m0_inplace(
|
||||
layer.w2_weight.data,
|
||||
layer.w2_weight_scale.data,
|
||||
block_sz,
|
||||
)
|
||||
|
||||
# Ensure column-major TMA alignment expected by DeepGEMM.
|
||||
if expert_weight_is_col_major(layer.w13_weight_scale):
|
||||
layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
|
||||
layer.w13_weight_scale
|
||||
)
|
||||
if expert_weight_is_col_major(layer.w2_weight_scale):
|
||||
layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
|
||||
layer.w2_weight_scale
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if self.use_marlin or self.rocm_aiter_moe_enabled:
|
||||
if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
|
||||
return None
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
@@ -1048,7 +937,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# cutlass path
|
||||
assert self.moe_quant_config is not None
|
||||
if self.use_cutlass:
|
||||
if self.fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
CutlassBatchedExpertsFp8,
|
||||
CutlassExpertsFp8,
|
||||
@@ -1064,26 +953,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
):
|
||||
logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__)
|
||||
experts = CutlassBatchedExpertsFp8(
|
||||
self.moe.num_local_experts,
|
||||
num_dispatchers,
|
||||
self.moe.in_dtype,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
max_experts_per_worker=self.moe.num_local_experts,
|
||||
num_dispatchers=num_dispatchers,
|
||||
out_dtype=self.moe.in_dtype,
|
||||
e=layer.local_num_experts,
|
||||
n=layer.intermediate_size_per_partition,
|
||||
k=layer.hidden_size,
|
||||
device=layer.w13_weight.device,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
||||
experts = CutlassExpertsFp8(
|
||||
self.moe.in_dtype,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
out_dtype=self.moe.in_dtype,
|
||||
e=layer.local_num_experts,
|
||||
n=layer.intermediate_size_per_partition,
|
||||
k=layer.hidden_size,
|
||||
device=layer.w13_weight.device,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
# TODO(rob): investigate disable_expert_map
|
||||
self.disable_expert_map = (
|
||||
num_dispatchers > 1 or not experts.supports_expert_map()
|
||||
)
|
||||
@@ -1096,13 +986,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
TritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
|
||||
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
|
||||
|
||||
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
|
||||
assert self.fp8_backend not in [Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN]
|
||||
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
@@ -1111,28 +1002,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens_per_rank is not None
|
||||
|
||||
if use_deep_gemm and not has_deep_gemm():
|
||||
raise RuntimeError(
|
||||
"DeepGEMM requested for MoE layer but not installed."
|
||||
)
|
||||
|
||||
compatible_with_deep_gemm = (
|
||||
self.moe_quant_config.use_fp8_w8a8
|
||||
and self.moe_quant_config.block_shape
|
||||
== get_mk_alignment_for_contiguous_layout()
|
||||
)
|
||||
|
||||
# If this MoE layer is compatible with DeepGEMM, the proper env
|
||||
# vars are set and DeepGEMM is not installed, throw an error.
|
||||
if use_deep_gemm and compatible_with_deep_gemm and not has_deep_gemm():
|
||||
raise RuntimeError(
|
||||
f"MoE layer incompatible with DeepGEMM, expected "
|
||||
f"fp8==True, got {self.moe_quant_config.use_fp8_w8a8}"
|
||||
f"or block_shape {self.moe_quant_config.block_shape}"
|
||||
f"=={get_mk_alignment_for_contiguous_layout()}."
|
||||
)
|
||||
|
||||
if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm():
|
||||
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
|
||||
logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
|
||||
return BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
@@ -1148,17 +1018,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
)
|
||||
|
||||
else:
|
||||
logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
|
||||
return TritonOrDeepGemmExperts(
|
||||
self.moe_quant_config,
|
||||
allow_deep_gemm=use_deep_gemm,
|
||||
)
|
||||
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
|
||||
logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
|
||||
return TritonOrDeepGemmExperts(self.moe_quant_config)
|
||||
else:
|
||||
logger.debug("TritonExperts(%s)", self.__class__.__name__)
|
||||
return TritonExperts(self.moe_quant_config)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
if self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
return fp8_w8a16_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
block_shape=self.weight_block_size,
|
||||
)
|
||||
|
||||
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
|
||||
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
||||
@@ -1184,118 +1059,23 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
|
||||
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
||||
assert self.kernel is not None
|
||||
result = self.kernel(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
# TODO(rob): investigate the disable_expert_map introduced by:
|
||||
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
|
||||
expert_map=None if self.disable_expert_map else layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
if self.use_marlin:
|
||||
assert layer.activation == "silu", (
|
||||
f"{layer.activation} not supported for Marlin MoE."
|
||||
)
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
None,
|
||||
None,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
workspace=layer.workspace,
|
||||
)
|
||||
|
||||
elif self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
|
||||
assert per_act_token == per_channel_quant
|
||||
assert self.moe_quant_config is not None
|
||||
return rocm_aiter_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
# cutlass path
|
||||
elif self.use_cutlass:
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
# small-batch fallback on SM100
|
||||
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
assert per_act_token == per_channel_quant
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=None
|
||||
if self.disable_expert_map
|
||||
else layer.expert_map,  # TODO: same disable_expert_map question as the note above
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8,
|
||||
)
|
||||
|
||||
assert per_act_token == per_channel_quant
|
||||
assert self.moe_quant_config is not None
|
||||
return cutlass_moe_fp8(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else layer.expert_map,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
)
|
||||
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
assert per_act_token == per_channel_quant
|
||||
assert self.moe_quant_config is not None
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
return result
|
||||
|
||||
@property
|
||||
def supports_eplb(self) -> bool:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
@@ -27,13 +26,17 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
fp8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
@@ -46,25 +49,20 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
apply_flashinfer_per_tensor_scale_fp8,
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||
get_flashinfer_moe_backend,
|
||||
make_fp8_moe_alpha_scales_for_fi,
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe,
|
||||
rotate_flashinfer_fp8_moe_weights,
|
||||
select_cutlass_fp8_gemm_impl,
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
deepgemm_post_process_fp8_weight_block,
|
||||
maybe_post_process_fp8_weight_block,
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
process_fp8_weight_block_strategy,
|
||||
process_fp8_weight_tensor_strategy,
|
||||
process_fp8_weight_tensor_strategy_moe,
|
||||
validate_fp8_block_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
@@ -73,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
prepare_moe_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
@@ -81,12 +78,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
all_close_1d,
|
||||
cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
per_tensor_dequantize,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
@@ -96,11 +91,8 @@ from vllm.model_executor.parameter import (
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.flashinfer import has_flashinfer_moe
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
@@ -110,107 +102,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Fp8MoeBackend(Enum):
NONE = 0
FLASHINFER_TRTLLM = 1
FLASHINFER_CUTLASS = 2
DEEPGEMM = 3
MARLIN = 4
TRITON = 5
AITER = 6
|
||||
|
||||
|
||||
def get_fp8_moe_backend(
|
||||
block_quant: bool,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
with_lora_support: bool,
|
||||
) -> Fp8MoeBackend | None:
|
||||
"""
|
||||
Select the primary FP8 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
if current_platform.is_xpu():
|
||||
return None
|
||||
if with_lora_support:
|
||||
return Fp8MoeBackend.TRITON
|
||||
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and (
|
||||
current_platform.is_device_capability_family(100)
|
||||
or current_platform.is_device_capability(90)
|
||||
)
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP8
|
||||
and has_flashinfer_moe()
|
||||
):
|
||||
backend = get_flashinfer_moe_backend()
|
||||
if backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
|
||||
return Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
else:
|
||||
if block_quant and current_platform.is_device_capability_family(100):
|
||||
raise ValueError(
|
||||
"FlashInfer FP8 MoE throughput backend does not "
|
||||
"support block quantization on SM100. Please use "
|
||||
"VLLM_FLASHINFER_MOE_BACKEND=latency "
|
||||
"instead."
|
||||
)
|
||||
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
|
||||
return Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
|
||||
# weight-only path for older GPUs without native FP8
use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
if current_platform.is_rocm():
use_marlin = False
if use_marlin:
logger.info_once("Using Marlin backend for FP8 MoE")
return Fp8MoeBackend.MARLIN
|
||||
|
||||
# Determine if we should use DeepGEMM with block-quantized weights:
# - If explicitly set by user, respect their choice
# - If not explicitly set (default), disable when TP size is >= 8
moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
moe_use_deep_gemm = False
logger.info_once(
"DeepGEMM MoE is disabled by default when TP size is >= 8. "
"Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
scope="local",
)
|
||||
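# For example, the TP >= 8 default above can still be overridden explicitly
# before launch (illustrative invocation; the model name is a placeholder):
#   VLLM_MOE_USE_DEEP_GEMM=1 vllm serve <model> --tensor-parallel-size 8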
|
||||
# Determine if we should use DeepGEMM (top-level enable switch)
# - If explicitly set by user, respect their choice
# - If the platform does not support DeepGEMM, disable it
# This helps avoid warning messages on unsupported platforms.
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
if not is_deep_gemm_supported():
use_deep_gemm = False
logger.info_once(
"DeepGEMM is disabled because the platform does not support it.",
scope="local",
)
|
||||
|
||||
if use_deep_gemm and moe_use_deep_gemm and block_quant:
if not has_deep_gemm():
logger.warning_once(
"DeepGEMM backend requested but not available.", scope="local"
)
elif is_deep_gemm_supported():
logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
return Fp8MoeBackend.DEEPGEMM

if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
return Fp8MoeBackend.AITER

# default to Triton
logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON
|
||||
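# A minimal sketch of the priority order implemented above, for readers who
# want the decision tree at a glance. The boolean parameters are hypothetical
# stand-ins for the real platform / env checks (the XPU early-out and the
# TRTLLM-vs-CUTLASS split are omitted); only the ordering is meant to be
# informative, not the exact conditions.
def _sketch_fp8_moe_backend_priority(
    with_lora_support: bool,
    flashinfer_ok: bool,
    has_native_fp8: bool,
    block_quant: bool,
    deep_gemm_ok: bool,
    aiter_ok: bool,
) -> Fp8MoeBackend:
    if with_lora_support:
        return Fp8MoeBackend.TRITON  # LoRA path only runs on Triton
    if flashinfer_ok:
        return Fp8MoeBackend.FLASHINFER_CUTLASS  # or FLASHINFER_TRTLLM
    if not has_native_fp8:
        return Fp8MoeBackend.MARLIN  # weight-only W8A16 fallback
    if block_quant and deep_gemm_ok:
        return Fp8MoeBackend.DEEPGEMM
    if aiter_ok:
        return Fp8MoeBackend.AITER
    return Fp8MoeBackend.TRITON  # default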
|
||||
|
||||
class Fp8Config(QuantizationConfig):
|
||||
"""Config class for FP8."""
|
||||
|
||||
@@ -348,7 +239,6 @@ class Fp8Config(QuantizationConfig):
|
||||
moe_quant_method = Fp8MoEMethod(self, layer)
|
||||
else:
|
||||
moe_quant_method = Fp8OnlineMoEMethod(self, layer)
|
||||
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||
return moe_quant_method
|
||||
elif isinstance(layer, Attention):
|
||||
return Fp8KVCacheMethod(self)
|
||||
@@ -736,40 +626,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
|
||||
super().__init__(layer.moe_config)
|
||||
self.layer = layer
|
||||
self.quant_config = quant_config
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant: bool = self.weight_block_size is not None
|
||||
self.weight_scale_name = (
|
||||
"weight_scale_inv" if self.block_quant else "weight_scale"
|
||||
)
|
||||
self.fp8_backend = get_fp8_moe_backend(
|
||||
self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
|
||||
self.fp8_backend = select_fp8_moe_backend(
|
||||
block_quant=self.block_quant,
|
||||
tp_size=layer.moe_parallel_config.tp_size,
|
||||
with_lora_support=self.moe.is_lora_enabled,
|
||||
)
|
||||
|
||||
self.marlin_input_dtype = None
|
||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
if self.block_quant and self.weight_block_size != [128, 128]:
|
||||
raise NotImplementedError(
|
||||
"FlashInfer CUTLASS FP8 MoE backend only supports block "
|
||||
"size [128, 128]."
|
||||
)
|
||||
if not self.block_quant:
|
||||
if layer.renormalize or layer.custom_routing_function is not None:
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend does not support a custom "
f"routing function or renormalization, but got {layer.renormalize} and "
f"{layer.custom_routing_function}."
)
|
||||
if layer.scoring_func != "sigmoid":
|
||||
raise NotImplementedError(
|
||||
"FlashInfer CUTLASS FP8 MoE backend only supports "
|
||||
f"'sigmoid' scoring function, but got {layer.scoring_func}."
|
||||
)
|
||||
if layer.activation != "silu":
|
||||
raise NotImplementedError(
|
||||
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
|
||||
@@ -778,12 +652,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
dynamic_per_token = (
|
||||
not self.block_quant and self.quant_config.activation_scheme != "static"
|
||||
)
|
||||
if self.flashinfer_moe_backend is not None and dynamic_per_token:
|
||||
if dynamic_per_token and self.fp8_backend in [
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
]:
|
||||
raise NotImplementedError(
|
||||
"FlashInfer FP8 MoE backend does not support dynamic per token "
|
||||
"activation quantization."
|
||||
)
|
||||
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: Module,
|
||||
@@ -907,148 +786,43 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def _convert_weights_to_kernel_format(
|
||||
def _setup_kernel(
|
||||
self,
|
||||
layer: Module,
|
||||
w13_weight: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w13_weight_scale: torch.Tensor,
|
||||
w2_weight_scale: torch.Tensor,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor | None,
|
||||
w2_input_scale: torch.Tensor | None,
|
||||
) -> None:
|
||||
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
|
||||
assert self.block_quant
|
||||
w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block(
|
||||
wq=w13_weight,
|
||||
ws=w13_weight_scale,
|
||||
quant_block_shape=tuple(layer.weight_block_size),
|
||||
use_e8m0=is_deep_gemm_e8m0_used(),
|
||||
)
|
||||
w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block(
|
||||
wq=w2_weight,
|
||||
ws=w2_weight_scale,
|
||||
quant_block_shape=tuple(layer.weight_block_size),
|
||||
use_e8m0=is_deep_gemm_e8m0_used(),
|
||||
)
|
||||
elif self.fp8_backend == Fp8MoeBackend.AITER:
|
||||
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
|
||||
w13_weight, w2_weight
|
||||
)
|
||||
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
(
|
||||
workspace,
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
w13_weight_scale,
|
||||
w2_weight_scale,
|
||||
) = prepare_moe_fp8_layer_for_marlin(
|
||||
layer,
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
w13_weight_scale,
|
||||
w2_weight_scale,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
)
|
||||
layer.workspace = workspace
|
||||
|
||||
elif self.fp8_backend in [
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
]:
|
||||
w13_weight = swap_w13_to_w31(w13_weight)
|
||||
if self.block_quant:
|
||||
w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
|
||||
else:
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
w13_weight_scale=w13_weight,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_weight_scale=w2_weight,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
elif self.fp8_backend == Fp8MoeBackend.AITER:
|
||||
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
|
||||
w13_weight, w2_weight
|
||||
)
|
||||
# Shuffle weights to runtime format.
|
||||
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
|
||||
fp8_backend=self.fp8_backend,
|
||||
layer=layer,
|
||||
w13=w13,
|
||||
w2=w2,
|
||||
w13_scale=w13_scale,
|
||||
w2_scale=w2_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
# Replace parameters with updated versions. Note that this helper
|
||||
# function ensures the replacement is compatible with RL weight reloads.
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
|
||||
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
|
||||
|
||||
def _setup_kernel(self, layer: Module) -> None:
|
||||
"""Setup Modular Kernel for TP Case"""
|
||||
# NOTE(rob): this is a WIP refactor. We are first migrating
# all of the kernels in the TP case to use mk. Once this is
# done, then we will initialize the TP case and DP/EP case
# via the same code path (i.e. via maybe_init_modular_kernel).
|
||||
# NOTE(rob): in progress migrating all into this format.
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
MarlinExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
AiterExperts,
|
||||
)
|
||||
|
||||
# Flashinfer TRTLLM does not use the modular kernel abstraction.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return
|
||||
replace_parameter(layer, "w13_weight", w13)
|
||||
replace_parameter(layer, "w2_weight", w2)
|
||||
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
|
||||
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
|
||||
|
||||
# Setup modular kernel for TP case.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
assert self.moe_quant_config is not None
|
||||
self.use_inplace = True
|
||||
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
# TODO: make defer_input_quant an attr of the FlashInferExperts
|
||||
MoEPrepareAndFinalizeNoEP(defer_input_quant=self.block_quant),
|
||||
FlashInferExperts(
|
||||
out_dtype=layer.orig_dtype,
|
||||
quant_config=self.moe_quant_config,
|
||||
ep_rank=self.moe.ep_rank,
|
||||
ep_size=self.moe.ep_size,
|
||||
tp_rank=self.moe.tp_rank,
|
||||
tp_size=self.moe.tp_size,
|
||||
use_dp=(self.moe.dp_size > 1),
|
||||
use_deepseek_fp8_block_scale=self.block_quant,
|
||||
),
|
||||
)
|
||||
self.use_inplace = False
|
||||
|
||||
elif self.fp8_backend == Fp8MoeBackend.AITER:
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
# TODO: make defer_input_quant an attr of the AiterExperts
|
||||
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
|
||||
AiterExperts(quant_config=self.moe_quant_config),
|
||||
)
|
||||
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
MarlinExperts(quant_config=self.moe_quant_config),
|
||||
)
|
||||
else:
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonOrDeepGemmExperts(
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
|
||||
),
|
||||
if self.moe_quant_config:
|
||||
self.kernel, self.use_inplace = make_fp8_moe_kernel(
|
||||
layer=layer,
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
)
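# Note: the kernel built by make_fp8_moe_kernel here is the same object that
# apply() (in the hunks below) invokes as self.kernel(x, layer.w13_weight,
# layer.w2_weight, topk_weights, topk_ids, inplace=self.use_inplace, ...),
# so backend dispatch is resolved once at weight-processing time.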
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
@@ -1056,78 +830,58 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
return
|
||||
|
||||
# Allow for accessing weights and scales in standard way.
|
||||
w13_weight = layer.w13_weight
|
||||
w2_weight = layer.w2_weight
|
||||
w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
|
||||
w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
|
||||
w13 = layer.w13_weight
|
||||
w2 = layer.w2_weight
|
||||
w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
|
||||
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
|
||||
w13_input_scale = layer.w13_input_scale
|
||||
w2_input_scale = layer.w2_input_scale
|
||||
|
||||
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
|
||||
if current_platform.is_fp8_fnuz():
|
||||
w13_weight, w13_weight_scale, w13_input_scale = (
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
w13_weight, w13_weight_scale, w13_input_scale
|
||||
)
|
||||
w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
w13,
|
||||
w13_scale,
|
||||
w13_input_scale,
|
||||
)
|
||||
w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
w2_weight, w2_weight_scale, w2_input_scale
|
||||
w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
w2,
|
||||
w2_scale,
|
||||
w2_input_scale,
|
||||
)
|
||||
|
||||
# Per tensor kernels require single activation scale. Use the max.
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
assert not self.block_quant
|
||||
assert w13_input_scale is not None and w2_input_scale is not None
|
||||
if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer."
|
||||
)
|
||||
replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
|
||||
replace_parameter(layer, "w2_input_scale", w2_input_scale.max())
|
||||
w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
|
||||
w13_input_scale, w2_input_scale
|
||||
)
|
||||
replace_parameter(layer, "w13_input_scale", w13_input_scale)
|
||||
replace_parameter(layer, "w2_input_scale", w2_input_scale)
|
||||
|
||||
# Per tensor kernels require single weight scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize.
if not self.block_quant:
shard_size = layer.intermediate_size_per_partition
max_w13_scales = w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
w13_weight[expert_id][start : start + shard_size, :],
w13_weight_scale[expert_id][shard_id],
)
w13_weight[expert_id][start : start + shard_size, :], _ = (
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
w13_weight_scale = max_w13_scales
w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
w13, w13_scale, shard_size, layer.local_num_experts
)
|
||||
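# A minimal pure-torch sketch of the "max scale" requantization above, for a
# single expert. The real path goes through ops.scaled_fp8_quant /
# process_fp8_weight_tensor_strategy_moe; this only illustrates the idea of
# dequantizing w1 and w3 with their own scales and requantizing the fused w13
# with max(s1, s3).
import torch

def _sketch_requant_w13_with_max_scale(
    w13_fp8: torch.Tensor,       # [2 * N, K], torch.float8_e4m3fn
    shard_scales: torch.Tensor,  # [2] float32 scales for w1 and w3
) -> tuple[torch.Tensor, torch.Tensor]:
    n = w13_fp8.shape[0] // 2
    max_scale = shard_scales.max()
    finfo = torch.finfo(torch.float8_e4m3fn)
    out = torch.empty_like(w13_fp8)
    for shard_id in range(2):
        rows = slice(shard_id * n, (shard_id + 1) * n)
        dq = w13_fp8[rows].to(torch.float32) * shard_scales[shard_id]  # dequantize
        q = (dq / max_scale).clamp(finfo.min, finfo.max)               # requantize
        out[rows] = q.to(torch.float8_e4m3fn)
    return out, max_scale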
|
||||
# Shuffle weights into the runtime format.
|
||||
self._convert_weights_to_kernel_format(
|
||||
layer=layer,
|
||||
w13_weight=w13_weight,
|
||||
w2_weight=w2_weight,
|
||||
w13_weight_scale=w13_weight_scale,
|
||||
w2_weight_scale=w2_weight_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
# Shuffle weights to runtime format and setup kernel.
|
||||
self._setup_kernel(
|
||||
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
|
||||
)
|
||||
|
||||
# Setup modular kernel for TP case.
|
||||
self._setup_kernel(layer)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if (
|
||||
self.fp8_backend == Fp8MoeBackend.AITER
|
||||
or self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||
):
|
||||
if self.fp8_backend in [
|
||||
Fp8MoeBackend.AITER,
|
||||
Fp8MoeBackend.MARLIN,
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
]:
|
||||
return None
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
@@ -1184,7 +938,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
elif self.moe.is_lora_enabled:
|
||||
return TritonExperts(quant_config=self.moe_quant_config)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
# Select GEMM experts with block-scale when weights are block-quantized
|
||||
experts = select_cutlass_fp8_gemm_impl(
|
||||
self.moe,
|
||||
@@ -1193,17 +947,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
else:
|
||||
elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
|
||||
logger.debug(
|
||||
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
|
||||
self.__class__.__name__,
|
||||
self.weight_block_size,
|
||||
False,
|
||||
)
|
||||
return TritonOrDeepGemmExperts(
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
|
||||
return TritonOrDeepGemmExperts(self.moe_quant_config)
|
||||
else:
|
||||
assert self.fp8_backend == Fp8MoeBackend.TRITON
|
||||
logger.debug(
|
||||
"TritonExperts(%s): block_size=%s, per_act_token=%s",
|
||||
self.__class__.__name__,
|
||||
self.weight_block_size,
|
||||
False,
|
||||
)
|
||||
return TritonExperts(self.moe_quant_config)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
@@ -1212,42 +972,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
|
||||
# MARLIN uses mixed precision W8A16 config.
|
||||
if self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
return fp8_w8a16_moe_quant_config(
|
||||
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
|
||||
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
|
||||
block_shape=self.weight_block_size,
|
||||
)
|
||||
|
||||
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
|
||||
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
|
||||
a1_scale = layer.w13_input_scale
|
||||
a2_scale = layer.w2_input_scale
|
||||
|
||||
# Flashinfer CUTLASS per-tensor uses single dq scale
|
||||
# (alpha = w_scale * a_scale) and inverse a2 scale.
|
||||
if (
|
||||
self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
and not self.block_quant
|
||||
):
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w1_scale,
|
||||
a1_scale,
|
||||
w2_scale,
|
||||
a2_scale,
|
||||
)
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=(1.0 / a2_scale),
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
)
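# A minimal sketch of the alpha computation described above (the real helper
# is make_fp8_moe_alpha_scales_for_fi; shapes are assumptions for illustration):
import torch

def _sketch_fi_alpha_scales(
    w1_scale: torch.Tensor,  # [E] per-expert weight scales for w13
    a1_scale: torch.Tensor,  # scalar activation scale into the first GEMM
    w2_scale: torch.Tensor,  # [E] per-expert weight scales for w2
    a2_scale: torch.Tensor,  # scalar activation scale into the second GEMM
) -> tuple[torch.Tensor, torch.Tensor]:
    # alpha = w_scale * a_scale folds both dequant factors into one multiplier
    g1_alphas = (w1_scale * a1_scale).float()
    g2_alphas = (w2_scale * a2_scale).float()
    return g1_alphas, g2_alphas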
|
||||
|
||||
# All other backends use normal config.
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
return make_fp8_moe_quant_config(
|
||||
fp8_backend=self.fp8_backend,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
@@ -1269,7 +1000,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
# TODO(rob): convert this to MK.
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
|
||||
@@ -1308,10 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
routed_scaling=layer.routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
not layer.renormalize and layer.custom_routing_function is not None
|
||||
)
|
||||
result = apply_flashinfer_per_tensor_scale_fp8(
|
||||
result = apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@@ -1327,6 +1055,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
assert self.kernel is not None
|
||||
result = self.kernel(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@@ -1358,7 +1088,6 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
assert not quant_config.is_checkpoint_fp8_serialized
|
||||
assert quant_config.activation_scheme == "dynamic"
|
||||
assert quant_config.weight_block_size is None
|
||||
assert self.flashinfer_moe_backend is None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -1447,6 +1176,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
@@ -1457,33 +1188,30 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
|
||||
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
|
||||
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
|
||||
w13_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
|
||||
for expert in range(layer.local_num_experts):
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w13_weight[expert, :, :])
|
||||
w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
|
||||
layer.w13_weight[expert, :, :]
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w2_weight[expert, :, :])
|
||||
w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
layer.w2_weight[expert, :, :]
|
||||
)
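# For reference, a pure-torch sketch of what ops.scaled_fp8_quant computes in
# the dynamic per-tensor case used here (the real op is a fused CUDA kernel):
import torch

def _sketch_scaled_fp8_quant(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = t.abs().max().clamp(min=1e-12) / finfo.max  # dynamic per-tensor scale
    q = (t / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale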
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
|
||||
# Shuffle weights into the runtime format.
|
||||
self._convert_weights_to_kernel_format(
|
||||
layer=layer,
|
||||
w13_weight=w13_weight,
|
||||
w2_weight=w2_weight,
|
||||
w13_weight_scale=layer.w13_weight_scale,
|
||||
w2_weight_scale=layer.w2_weight_scale,
|
||||
w13_input_scale=None,
|
||||
w2_input_scale=None,
|
||||
# Shuffle weights to runtime format and setup kernel.
|
||||
self._setup_kernel(
|
||||
layer,
|
||||
w13,
|
||||
w2,
|
||||
w13_scale,
|
||||
w2_scale,
|
||||
layer.w13_input_scale,
|
||||
layer.w2_input_scale,
|
||||
)
|
||||
|
||||
# Setup modular kernel for TP case.
|
||||
self._setup_kernel(layer)
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
|
||||
@@ -15,7 +15,6 @@ from vllm.attention.layer import Attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
nvfp4_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
@@ -24,6 +23,13 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
@@ -45,19 +51,16 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
apply_flashinfer_per_tensor_scale_fp8,
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
get_flashinfer_moe_backend,
|
||||
is_flashinfer_supporting_global_sf,
|
||||
make_fp8_moe_alpha_scales_for_fi,
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe,
|
||||
rotate_flashinfer_fp8_moe_weights,
|
||||
select_cutlass_fp8_gemm_impl,
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
process_fp8_weight_tensor_strategy_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
@@ -85,13 +88,12 @@ from vllm.model_executor.parameter import (
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_scaled_fp4_mm,
|
||||
has_flashinfer,
|
||||
has_flashinfer_moe,
|
||||
)
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
@@ -721,38 +723,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
) -> None:
|
||||
super().__init__(layer.moe_config)
|
||||
self.layer = layer
|
||||
self.quant_config = quant_config
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_fp8_supported,
|
||||
assert self.quant_config.is_checkpoint_fp8_serialized
|
||||
self.fp8_backend = select_fp8_moe_backend(
|
||||
block_quant=False,
|
||||
tp_size=layer.moe_parallel_config.tp_size,
|
||||
with_lora_support=self.moe.is_lora_enabled,
|
||||
)
|
||||
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
if (
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||
and not self.moe.is_act_and_mul
|
||||
):
|
||||
logger.info_once(
|
||||
"Non-gated MoE is not supported for min-latency mode,"
|
||||
"falling back to high-throughput mode"
|
||||
)
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
|
||||
logger.info_once(
|
||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||
)
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
# TRT LLM not supported with all2all yet.
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
# TP case: avoid converting to ModularKernelMethod - to be refactored.
|
||||
if self.moe.dp_size == 1:
|
||||
return None
|
||||
@@ -787,6 +774,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.num_experts = num_experts
|
||||
|
||||
# Use FP8 dtype if checkpoint is serialized
|
||||
weight_dtype = (
|
||||
torch.float8_e4m3fn
|
||||
@@ -826,217 +816,121 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALES - Per-tensor scaling for ModelOpts
|
||||
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
# For non-gated MoE, allocate 1 scale for w13.
|
||||
if self.moe.is_act_and_mul:
|
||||
w13_weight_scale_shape = (num_experts, 2)
|
||||
else:
|
||||
w13_weight_scale_shape = (num_experts, 1)
|
||||
w13_weight_scale = PerTensorScaleParameter(
|
||||
data=torch.full(
|
||||
w13_weight_scale_shape,
|
||||
1.0,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
w2_weight_scale = PerTensorScaleParameter(
|
||||
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
# WEIGHT SCALES - Per-tensor scaling for ModelOpts
|
||||
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
# For non-gated MoE, allocate 1 scale for w13.
|
||||
w13_weight_scale = PerTensorScaleParameter(
|
||||
data=torch.full(
|
||||
(num_experts, 2 if self.moe.is_act_and_mul else 1),
|
||||
1.0,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
w2_weight_scale = PerTensorScaleParameter(
|
||||
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# Set weight loader attributes for scales
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||
)
|
||||
# INPUT SCALES - Per-tensor scaling for ModelOpt
|
||||
w13_input_scale = PerTensorScaleParameter(
|
||||
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
w2_input_scale = PerTensorScaleParameter(
|
||||
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
# INPUT SCALES - Per-tensor scaling for ModelOpt
|
||||
w13_input_scale = PerTensorScaleParameter(
|
||||
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
def _setup_kernel(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
):
|
||||
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
|
||||
fp8_backend=self.fp8_backend,
|
||||
layer=layer,
|
||||
w13=w13,
|
||||
w2=w2,
|
||||
w13_scale=w13_scale,
|
||||
w2_scale=w2_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
# Replace parameters with updated versions. Note that this helper
|
||||
# function ensures the replacement is compatible with RL weight reloads.
|
||||
replace_parameter(layer, "w13_weight", w13)
|
||||
replace_parameter(layer, "w2_weight", w2)
|
||||
replace_parameter(layer, "w13_weight_scale", w13_scale)
|
||||
replace_parameter(layer, "w2_weight_scale", w2_scale)
|
||||
|
||||
# Setup modular kernel for TP case.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
self.kernel, self.use_inplace = make_fp8_moe_kernel(
|
||||
layer=layer,
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
)
|
||||
w2_input_scale = PerTensorScaleParameter(
|
||||
data=torch.full((num_experts,), 1.0, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""Process FP8 MoE weights after loading from serialized checkpoint.
|
||||
Only supports pre-quantized checkpoints with FP8 weights and scales.
|
||||
"""
|
||||
w13 = layer.w13_weight
|
||||
w2 = layer.w2_weight
|
||||
w13_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
w13_input_scale = layer.w13_input_scale
|
||||
w2_input_scale = layer.w2_input_scale
|
||||
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
self._maybe_pad_intermediate_for_flashinfer(layer)
|
||||
# Per tensor kernels require single activation scale. Use the max.
|
||||
w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
|
||||
w13_input_scale, w2_input_scale
|
||||
)
|
||||
replace_parameter(layer, "w13_input_scale", w13_input_scale)
|
||||
replace_parameter(layer, "w2_input_scale", w2_input_scale)
|
||||
|
||||
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
per_tensor_dequantize,
|
||||
# Per tensor kernels require single weight scale for w13 per expert, but
|
||||
# on disk there is a scale for w1 and w3. Use the max to requantize.
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
|
||||
w13,
|
||||
w13_scale,
|
||||
shard_size,
|
||||
num_experts=layer.w13_weight.shape[0],
|
||||
is_act_and_mul=self.moe.is_act_and_mul,
|
||||
)
|
||||
|
||||
# Handle scale parameters
|
||||
if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max of the w1 and w3 scales
|
||||
# then dequant and requant each expert.
|
||||
if (
|
||||
layer.w13_weight_scale.dim() == 2
|
||||
and layer.w13_weight_scale.shape[1] == 2
|
||||
):
|
||||
assert self.moe.is_act_and_mul, (
|
||||
"w13_weight_scale should have 2 elements per expert "
|
||||
"only for gated MoE"
|
||||
)
|
||||
# Get the maximum scale across w1 and w3 for each expert
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
|
||||
# Requantize each expert's weights using the combined scale
|
||||
# w13_weight (num_experts, 2 * intermediate_size, hidden_size)
|
||||
# where the first intermediate_size rows are w1, the next are w3
|
||||
intermediate_size = layer.w13_weight.shape[1] // 2
|
||||
for expert_id in range(layer.w13_weight.shape[0]):
|
||||
start = 0
|
||||
for shard_id in range(2): # w1 and w3
|
||||
# Dequantize using the original scale for this shard
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][
|
||||
start : start + intermediate_size, :
|
||||
],
|
||||
layer.w13_weight_scale[expert_id][shard_id],
|
||||
)
|
||||
# Requantize using the combined max scale
|
||||
|
||||
(
|
||||
layer.w13_weight[expert_id][
|
||||
start : start + intermediate_size, :
|
||||
],
|
||||
_,
|
||||
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
|
||||
start += intermediate_size
|
||||
|
||||
# Update the scale parameter to be per-expert
|
||||
layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
|
||||
else:
|
||||
layer.w13_weight_scale = Parameter(
|
||||
layer.w13_weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
|
||||
layer.w2_weight_scale = Parameter(
|
||||
layer.w2_weight_scale.data, requires_grad=False
|
||||
)
|
||||
# Input scales must be equal for each expert in fp8 MoE layers.
|
||||
if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
|
||||
layer.w13_input_scale = Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False
|
||||
)
|
||||
if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
|
||||
layer.w2_input_scale = Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
if self.moe.is_act_and_mul:
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
|
||||
# NOTE: this adds some attributes used by the trtllm kernel,
|
||||
# which does not conform to the modular kernels abstraction (yet).
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
w13_weight_scale=layer.w13_weight_scale,
|
||||
w13_input_scale=layer.w13_input_scale,
|
||||
w2_weight_scale=layer.w2_weight_scale,
|
||||
w2_input_scale=layer.w2_input_scale,
|
||||
)
|
||||
|
||||
def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
|
||||
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
|
||||
|
||||
Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
|
||||
used for GEMM to be divisible by a small alignment value. When this is
|
||||
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
|
||||
gate/up and down projection weights along the intermediate dim.
|
||||
"""
|
||||
if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
|
||||
return
|
||||
|
||||
# Current local intermediate size (per partition) is the K dimension of
|
||||
# the down projection.
|
||||
num_experts, hidden_size, intermediate = layer.w2_weight.shape
|
||||
|
||||
min_alignment = 16
|
||||
padded_intermediate = round_up(intermediate, min_alignment)
|
||||
|
||||
if padded_intermediate == intermediate:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Padding intermediate size from %d to %d for up/down projection weights.",
|
||||
intermediate,
|
||||
padded_intermediate,
|
||||
# Shuffle weights to runtime format and setup kernel.
|
||||
self._setup_kernel(
|
||||
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
|
||||
)
|
||||
|
||||
up_mult = 2 if self.moe.is_act_and_mul else 1
|
||||
padded_gate_up_dim = up_mult * padded_intermediate
|
||||
|
||||
# Pad w13 and w2 along their intermediate dimension.
|
||||
w13 = layer.w13_weight.data
|
||||
padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
|
||||
padded_w13[:, : w13.shape[1], :] = w13
|
||||
layer.w13_weight.data = padded_w13
|
||||
|
||||
w2 = layer.w2_weight.data
|
||||
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
|
||||
padded_w2[:, :, :intermediate] = w2
|
||||
layer.w2_weight.data = padded_w2
|
||||
|
||||
if hasattr(layer, "intermediate_size_per_partition"):
|
||||
layer.intermediate_size_per_partition = padded_intermediate
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
# TRTLLM does not use modular kernels
|
||||
return None
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
a1_scale = layer.w13_input_scale
|
||||
a2_scale = layer.w2_input_scale
|
||||
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
layer.w13_weight_scale,
|
||||
layer.w13_input_scale,
|
||||
layer.w2_weight_scale,
|
||||
layer.w2_input_scale,
|
||||
)
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
a1_gscale=(1.0 / layer.w13_input_scale),
|
||||
a2_gscale=(1.0 / layer.w2_input_scale),
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
)
|
||||
else:
|
||||
assert self.flashinfer_moe_backend is None
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
return make_fp8_moe_quant_config(
|
||||
fp8_backend=self.fp8_backend,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -1044,17 +938,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ModelOptFp8MoEMethod` yet."
|
||||
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
|
||||
)
|
||||
# TODO(rob): this validation should happen at kernel selection
|
||||
# time in the oracle rather than here.
|
||||
assert layer.activation == "silu", (
|
||||
f"Expected 'silu' activation but got {layer.activation}"
|
||||
)
|
||||
|
||||
assert not layer.renormalize
|
||||
return apply_flashinfer_per_tensor_scale_fp8(
|
||||
return apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@@ -1066,46 +961,34 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
# Expert selection
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
# TODO(rob): this validation should happen at kernel selection
|
||||
# time in the oracle rather than here.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
assert layer.activation in ("silu", "relu2_no_mul"), (
|
||||
"Expected activation to be in ('silu', 'relu2_no_mul'),"
|
||||
f"but got {layer.activation}"
|
||||
)
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
|
||||
assert self.moe_quant_config is not None
|
||||
assert self.kernel is not None
|
||||
result = self.kernel(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
|
||||
|
||||
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_moe_fp8_layer_for_marlin,
|
||||
prepare_fp8_moe_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
@@ -315,8 +315,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
|
||||
elif self.use_marlin:
|
||||
(workspace, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) = (
|
||||
prepare_moe_fp8_layer_for_marlin(
|
||||
w13_weight, w2_weight, w13_weight_scale, w2_weight_scale = (
|
||||
prepare_fp8_moe_layer_for_marlin(
|
||||
layer,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -324,7 +324,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
layer.w2_weight_scale,
|
||||
)
|
||||
)
|
||||
layer.workspace = workspace
|
||||
# TODO(rob): once we apply refactor to Quark, switch to using
|
||||
# replace_parameter for compatibility with reloading in RL.
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
|
||||
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -58,9 +59,10 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
|
||||
)
|
||||
|
||||
|
||||
def rotate_flashinfer_fp8_moe_weights(
|
||||
def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
|
||||
gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor
|
||||
):
|
||||
"""Shuffle weights for for FI TRT-LLM Format"""
|
||||
from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
|
||||
|
||||
epilogue_tile_m = 128
|
||||
@@ -105,16 +107,16 @@ def rotate_flashinfer_fp8_moe_weights(
|
||||
|
||||
def register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer: torch.nn.Module,
|
||||
w13_weight_scale: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor,
|
||||
w2_weight_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
) -> None:
|
||||
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w13_scale=w13_weight_scale,
|
||||
w13_scale=w13_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_weight_scale,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
layer.w2_input_scale_inv = 1.0 / w2_input_scale
|
||||
@@ -123,7 +125,7 @@ def register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer.output2_scales_scalar = g2_alphas
|
||||
|
||||
|
||||
def apply_flashinfer_per_tensor_scale_fp8(
|
||||
def apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -139,16 +141,23 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
|
||||
assert (
|
||||
hasattr(layer, "output1_scales_scalar")
|
||||
and hasattr(layer, "output1_scales_gate_scalar")
|
||||
and hasattr(layer, "output2_scales_scalar")
|
||||
)
|
||||
|
||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
|
||||
"FusedMoE flashinfer kernels are only supported for Llama4"
|
||||
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
|
||||
assert (
|
||||
hasattr(layer, "output1_scales_scalar")
|
||||
and hasattr(layer, "output1_scales_gate_scalar")
|
||||
and hasattr(layer, "output2_scales_scalar")
|
||||
)
|
||||
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
|
||||
is_llama4 = layer.custom_routing_function == Llama4MoE.custom_routing_function
|
||||
assert is_llama4, "FusedMoE flashinfer kernels are only supported for Llama4"
|
||||
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=hidden_states,
|
||||
@@ -221,50 +230,6 @@ def select_cutlass_fp8_gemm_impl(
)


def flashinfer_cutlass_moe_fp8(
hidden_states: torch.Tensor,
layer: torch.nn.Module,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
use_deepseek_fp8_block_scale: bool = False,
moe: FusedMoEConfig | None = None,
) -> torch.Tensor:
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
assert quant_config is not None

# Construct modular kernel with block-scale support when requested.
fused_experts = mk.FusedMoEModularKernel(
build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
),
select_cutlass_fp8_gemm_impl(
moe=moe,
quant_config=quant_config,
out_dtype=hidden_states.dtype,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
),
moe_parallel_config=layer.moe_parallel_config,
)

return fused_experts(
hidden_states,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)


def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
@@ -301,3 +266,104 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
FlashinferMoeBackend.TENSORRT_LLM,
)
return backend in backends_supporting_global_sf


def align_fp8_moe_weights_for_fi(
w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.

Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
used for GEMM to be divisible by a small alignment value. When this is
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
gate/up and down projection weights along the intermediate dim.
"""

# Current local intermediate size (per partition) is the K dimension of
# the down projection.
num_experts, hidden_size, intermediate = w2.shape

min_alignment = 16
padded_intermediate = round_up(intermediate, min_alignment)

if padded_intermediate == intermediate:
return w13, w2, intermediate

logger.info_once(
"Padding intermediate size from %d to %d for up/down projection weights.",
intermediate,
padded_intermediate,
scope="local",
)

up_mult = 2 if is_act_and_mul else 1
padded_gate_up_dim = up_mult * padded_intermediate

# Pad w13 and w2 along their intermediate dimension.
padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
padded_w13[:, : w13.shape[1], :] = w13

padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
padded_w2[:, :, :intermediate] = w2

return padded_w13, padded_w2, padded_intermediate
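
A minimal sketch of the padding this helper performs, using hypothetical
shapes (num_experts=4, hidden=64, intermediate=40; the sizes are illustrative
only, and only round_up from vllm.utils.math_utils, imported above, is assumed):

# Illustrative example, not part of the diff.
import torch
from vllm.utils.math_utils import round_up

num_experts, hidden, intermediate = 4, 64, 40
w13 = torch.randn(num_experts, 2 * intermediate, hidden)   # gate/up projection
w2 = torch.randn(num_experts, hidden, intermediate)        # down projection

padded = round_up(intermediate, 16)                         # 40 -> 48
padded_w13 = w13.new_zeros(num_experts, 2 * padded, hidden)
padded_w13[:, : w13.shape[1], :] = w13
padded_w2 = w2.new_zeros(num_experts, hidden, padded)
padded_w2[:, :, :intermediate] = w2
assert padded_w2.shape[-1] % 16 == 0
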


def prepare_fp8_moe_layer_for_fi(
layer: torch.nn.Module,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor | None,
is_trtllm: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Convert FP8 MoE weights to the FlashInfer kernel format.

Note that for trtllm we update the model state dict
with the scale format needed by these kernels.

Note that for per-tensor, we update the layer's
intermediate size if the weights needed padding.
"""

assert hasattr(layer.moe_config, "is_act_and_mul")
block_quant = (
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
)

# Some FI MoE kernels require internal alignment of 16
# for the gate-up proj. Pad the weights to respect this.
if not block_quant:
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
w13,
w2,
layer.moe_config.is_act_and_mul,
)
layer.intermediate_size_per_partition = new_intermediate

# FI kernels require W31 layout rather than W13.
if layer.moe_config.is_act_and_mul:
w13 = swap_w13_to_w31(w13)
if block_quant:
w13_scale = swap_w13_to_w31(w13_scale)

# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
# and registration of alpha scales. Note that we do not register
# them as nn.Parameters since they are not needed for weight reloading.
if is_trtllm and not block_quant:
assert w13_input_scale is not None
assert w2_input_scale is not None

rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)

return w13, w2, w13_scale
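
A sketch of one plausible call site for this helper inside a quantization
method's process_weights_after_loading; the method body and the layer
attribute names here are illustrative and not lifted verbatim from this PR:

# Illustrative only.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
        layer,
        w13=layer.w13_weight.data,
        w2=layer.w2_weight.data,
        w13_scale=layer.w13_weight_scale.data,
        w13_input_scale=layer.w13_input_scale.data,
        w2_scale=layer.w2_weight_scale.data,
        w2_input_scale=layer.w2_input_scale.data,
        is_trtllm=True,  # False for the CUTLASS backend
    )
    layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
    layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
    layer.w13_weight_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
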
@@ -21,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
all_close_1d,
per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
@@ -1350,6 +1352,29 @@ def deepgemm_post_process_fp8_weight_block(
return wq, dg_ws


def prepare_fp8_moe_layer_for_deepgemm(
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
block_shape: tuple[int],
):
w13, w13_scale = deepgemm_post_process_fp8_weight_block(
wq=w13,
ws=w13_scale,
quant_block_shape=block_shape,
use_e8m0=is_deep_gemm_e8m0_used(),
)
w2, w2_scale = deepgemm_post_process_fp8_weight_block(
wq=w2,
ws=w2_scale,
quant_block_shape=block_shape,
use_e8m0=is_deep_gemm_e8m0_used(),
)

return w13, w2, w13_scale, w2_scale


def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on the ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
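
One common way to realize this kind of padding (the pad size below is
illustrative; vLLM gates the real padding behind a ROCm-specific condition)
is to over-allocate and then narrow back to the logical shape:

# Illustrative only: allocate 256 extra columns, then view the original shape.
import torch
import torch.nn.functional as F

weight = torch.randn(4096, 4096, dtype=torch.float16)
padded_view = F.pad(weight, (0, 256), "constant", 0)[..., :-256]
assert padded_view.shape == weight.shape                 # same logical shape
assert padded_view.untyped_storage().size() > weight.untyped_storage().size()
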
@@ -1584,7 +1609,49 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
replace_parameter(layer, scale_attr, dg_weight_scale)


def expert_weight_is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
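
For reference, the layout this predicate accepts is the one produced by
transposing the last two dims of a contiguous 3-D tensor (shapes arbitrary):

# Illustrative only.
w = torch.randn(4, 32, 16).transpose(1, 2)   # shape (4, 16, 32), per-expert column-major
assert expert_weight_is_col_major(w)
assert not expert_weight_is_col_major(w.contiguous())
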
def process_fp8_weight_tensor_strategy_moe(
weight: torch.Tensor,
weight_scales: torch.Tensor,
shard_size: int,
num_experts: int,
is_act_and_mul: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Process MoE weights for the tensor-wise quantization strategy."""
max_scales = weight_scales.max(dim=1).values

# For the w1 case (i.e. not w13): just collapse the last dim since
# there is already just one scale per expert in this case.
if not is_act_and_mul:
assert weight_scales.shape[1] == 1
return weight, weight_scales.max()

# For the w13 case (common): we require a single scale for w13 per expert,
# but on disk there is a scale for each of w1 and w3. Use the max to requantize.
for expert_id in range(num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
weight[expert_id][start : start + shard_size, :],
weight_scales[expert_id][shard_id],
)
weight[expert_id][start : start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_scales[expert_id]
)
start += shard_size
return weight, max_scales
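
The per-shard loop above amounts to the following round trip, shown here in
plain torch on a single toy shard (ops.scaled_fp8_quant additionally clamps
to the representable fp8 range before casting):

# Illustrative only.
import torch

shard_q = torch.randn(4, 8).to(torch.float8_e4m3fn)   # one shard of w13, already fp8
own_scale = torch.tensor(0.02)                         # this shard's original scale
max_scale = torch.tensor(0.05)                         # max of the expert's w1/w3 scales

dequant = shard_q.to(torch.float16) * own_scale            # per_tensor_dequantize
requant = (dequant / max_scale).to(torch.float8_e4m3fn)    # requantize with the shared scale
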


def process_fp8_input_tensor_strategy_moe(
w13_input_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Process MoE input scales for the tensor-wise quantization strategy."""

if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
logger.info_once(
"Found input_scales that are not equal for "
"this fp8 MoE layer. Using the maximum across experts "
"for each layer."
)

return w13_input_scale.max(), w2_input_scale.max()
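
A tiny usage illustration with made-up per-expert scales:

# Illustrative only: per-expert input scales collapse to a single scalar each.
w13_in = torch.tensor([0.25, 0.5, 0.375])
w2_in = torch.tensor([0.125, 0.0625, 0.125])
a13, a2 = process_fp8_input_tensor_strategy_moe(w13_in, w2_in)
assert a13.item() == 0.5 and a2.item() == 0.125
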
@@ -496,7 +496,7 @@ def get__quant_fp8_method() -> QuantFP8:
return _quant_fp8_method


def get_marlin_input_dtype(prefix):
def get_marlin_input_dtype(prefix: str | None = None):
if envs.VLLM_MARLIN_INPUT_DTYPE is None:
return
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8":

@@ -8,6 +8,7 @@ import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT,
get_marlin_input_dtype,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
@@ -197,26 +198,28 @@ def prepare_fp8_layer_for_marlin(
replace_parameter(layer, "bias", bias)


def prepare_moe_fp8_layer_for_marlin(
def prepare_fp8_moe_layer_for_marlin(
layer: torch.nn.Module,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
input_dtype: torch.dtype | None = None,
) -> tuple[
torch.Tensor, # workspace
torch.Tensor, # w13_weight
torch.Tensor, # w2_weight
torch.Tensor, # w13_weight_scale
torch.Tensor, # w2_weight_scale
]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Shuffle weights and scales into marlin format.

Note that this function has the side effect of adding a `workspace`
attribute to the layer. This `workspace` does not need to be
registered as a Parameter as it is not used during weight reloading.
"""

logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
input_dtype = get_marlin_input_dtype()
if input_dtype is not None and input_dtype.itemsize == 1:
raise NotImplementedError("Marlin W8A8 is not supported.")

@@ -227,7 +230,9 @@ def prepare_moe_fp8_layer_for_marlin(

# WORKSPACE
device = layer.w13_weight.device
workspace = marlin_make_workspace_new(device, 4)
# NOTE(rob): we do not need to register the workspace as a param
# because it is not used as part of the weight reloading process.
layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)

# WEIGHT
@@ -310,13 +315,7 @@ def prepare_moe_fp8_layer_for_marlin(
w13_weight_scale = permute_scales(w13_weight_scale, "w13")
w2_weight_scale = permute_scales(w2_weight_scale, "w2")

return (
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
)
return w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
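
A sketch of the intended call pattern after this change: the caller swaps the
layer's parameters for the returned marlin-format tensors, while the workspace
is attached to the layer inside the helper (the surrounding code is
illustrative, not lifted from this PR):

# Illustrative only.
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
    layer,
    w13_weight=layer.w13_weight,
    w2_weight=layer.w2_weight,
    w13_weight_scale=layer.w13_weight_scale,
    w2_weight_scale=layer.w2_weight_scale,
)
layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
# layer.workspace was set inside the helper and needs no registration.
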


def pack_fp8_to_int32(