|
|
|
|
@@ -1,6 +1,7 @@
|
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
@@ -36,6 +37,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
|
|
|
from vllm.model_executor.parameter import (ModelWeightParameter,
|
|
|
|
|
PerTensorScaleParameter)
|
|
|
|
|
from vllm.scalar_type import scalar_types
|
|
|
|
|
from vllm.utils import next_power_of_2
|
|
|
|
|
from vllm.utils.flashinfer import has_flashinfer_moe
|
|
|
|
|
|
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
@@ -44,6 +46,11 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
|
|
|
|
|
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FlashinferMoeBackend(Enum):
    """Which FlashInfer fused-MoE kernel family to dispatch to.

    Selected from ``VLLM_FLASHINFER_MOE_BACKEND``: ``"latency"`` maps to
    the TensorRT-LLM kernels, ``"throughput"`` to the CUTLASS kernels.
    """

    # TensorRT-LLM block-scale MoE kernels (latency-optimized path).
    TENSORRT_LLM = "TensorRT-LLM"
    # CUTLASS grouped-GEMM MoE kernels (throughput-optimized path).
    CUTLASS = "CUTLASS"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelOptFp8Config(QuantizationConfig):
|
|
|
|
|
"""Config class for ModelOpt FP8."""
|
|
|
|
|
|
|
|
|
|
@@ -185,7 +192,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
|
|
|
|
Args: quant_config: The ModelOpt quantization config.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, quant_config: ModelOptFp8Config):
|
|
|
|
|
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
|
|
|
|
self.quant_config = quant_config
|
|
|
|
|
self.fp8_linear = Fp8LinearOp(
|
|
|
|
|
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
|
|
|
|
|
@@ -265,7 +272,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|
|
|
|
quant_config: The ModelOpt quantization config.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, quant_config: ModelOptFp8Config):
|
|
|
|
|
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
|
|
|
|
self.quant_config = quant_config
|
|
|
|
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
|
|
|
cutlass_fp8_supported)
|
|
|
|
|
@@ -670,7 +677,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
|
|
|
|
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
|
|
|
|
|
exclude_modules, group_size)
|
|
|
|
|
|
|
|
|
|
def is_layer_excluded(self, prefix: str, exclude_modules: list):
|
|
|
|
|
def is_layer_excluded(self, prefix: str,
|
|
|
|
|
exclude_modules: list[str]) -> bool:
|
|
|
|
|
import regex as re
|
|
|
|
|
for pattern in exclude_modules:
|
|
|
|
|
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
|
|
|
|
|
@@ -714,7 +722,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|
|
|
|
Args: quant_config: The ModelOpt quantization config.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, quant_config: ModelOptNvFp4Config):
|
|
|
|
|
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
|
|
|
|
self.quant_config = quant_config
|
|
|
|
|
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
|
|
|
|
|
self.use_marlin = False
|
|
|
|
|
@@ -859,6 +867,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|
|
|
|
return out.view(*output_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    """Pick the per-CTA tile size (in tokens) for the TRT-LLM MoE kernel.

    Estimates how many tokens each expert would receive under perfectly
    balanced routing, rounds that estimate up to a power of two, and
    clamps it to the [8, 64] range supported by the kernel.
    """
    # Expected tokens routed to each expert if routing were uniform.
    expected_per_expert = (num_tokens * top_k) // num_experts
    # Round up to a power of two, then clamp into the kernel's range.
    padded = next_power_of_2(expected_per_expert)
    return max(8, min(padded, 64))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|
|
|
|
"""
|
|
|
|
|
MoE Method for FP4 Quantization.
|
|
|
|
|
@@ -866,22 +884,40 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|
|
|
|
quant_config: NVFP4 Quant Config
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, quant_config: ModelOptNvFp4Config):
|
|
|
|
|
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
|
|
|
|
self.quant_config = quant_config
|
|
|
|
|
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
|
|
|
|
detect_nvfp4_moe_support)
|
|
|
|
|
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
|
|
|
|
|
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
|
|
|
|
|
self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
|
|
|
|
|
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
|
|
|
|
self.use_marlin = _nvfp4.use_marlin
|
|
|
|
|
self.flashinfer_moe_backend = None
|
|
|
|
|
|
|
|
|
|
self.fused_experts = None # type: ignore
|
|
|
|
|
if self.allow_flashinfer:
|
|
|
|
|
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
|
|
|
|
if flashinfer_moe_backend == "throughput":
|
|
|
|
|
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
|
|
|
|
logger.info_once("Using FlashInfer CUTLASS kernels for "
|
|
|
|
|
"ModelOptNvFp4FusedMoE.")
|
|
|
|
|
elif flashinfer_moe_backend == "latency":
|
|
|
|
|
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
|
|
|
|
logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
|
|
|
|
|
"ModelOptNvFp4FusedMoE.")
|
|
|
|
|
else:
|
|
|
|
|
allowed_backends = ["throughput", "latency"]
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
|
|
|
|
|
f" expected one of {allowed_backends}")
|
|
|
|
|
|
|
|
|
|
self.fused_experts: Optional[
|
|
|
|
|
mk.FusedMoEModularKernel] = None # type: ignore[assignment]
|
|
|
|
|
|
|
|
|
|
def maybe_swap_experts_impl(
|
|
|
|
|
self,
|
|
|
|
|
moe_parallel_config: FusedMoEParallelConfig,
|
|
|
|
|
):
|
|
|
|
|
if not self.allow_flashinfer_cutlass:
|
|
|
|
|
if not self.allow_flashinfer:
|
|
|
|
|
return
|
|
|
|
|
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
|
|
|
|
|
moe_parallel_config)
|
|
|
|
|
@@ -897,8 +933,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|
|
|
|
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
|
|
|
|
|
select_nvfp4_gemm_impl)
|
|
|
|
|
|
|
|
|
|
return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
|
|
|
|
|
logger)
|
|
|
|
|
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
|
|
|
|
|
|
|
|
|
|
def uses_weight_scale_2_pattern(self) -> bool:
|
|
|
|
|
"""
|
|
|
|
|
@@ -996,14 +1031,101 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|
|
|
|
weight_loader=weight_loader)
|
|
|
|
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
|
|
|
|
|
|
|
|
|
def prepare_static_weight_layouts_for_trtllm_moe(
    self,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    gemm1_scales_linear_fp4_bytes: torch.Tensor,
    gemm2_scales_linear_fp4_bytes: torch.Tensor,
    hidden_size: int,
    intermediate_size: int,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Prepare quantized weights for kernel (done offline with weights).

    Reorders and shuffles the packed-FP4 expert weights and their FP8
    block scales into the static layout expected by the FlashInfer
    TensorRT-LLM fused-MoE kernel.

    Args:
        gemm1_weights: w13 expert weights (presumably the fused
            gate/up projection — confirm against caller); reinterpreted
            as packed fp4 with shape
            (num_experts, 2 * intermediate_size, hidden_size // 2).
        gemm2_weights: w2 expert weights; reinterpreted as packed fp4
            with shape (num_experts, hidden_size, intermediate_size // 2).
        gemm1_scales_linear_fp4_bytes: fp8 block scales for gemm1,
            reshaped to (num_experts, 2 * intermediate_size,
            hidden_size // 16).
        gemm2_scales_linear_fp4_bytes: fp8 block scales for gemm2,
            reshaped to (num_experts, hidden_size,
            intermediate_size // 16).
        hidden_size: model hidden dimension.
        intermediate_size: per-expert intermediate dimension.
        num_experts: number of experts covered by the weight tensors.

    Returns:
        Tuple of (gemm1 weights, gemm1 scales, gemm2 weights,
        gemm2 scales), each stacked over experts in the shuffled,
        kernel-ready layout.
    """
    # Local import: flashinfer is an optional dependency, only needed on
    # the TRT-LLM path.
    from flashinfer import (reorder_rows_for_gated_act_gemm,
                            shuffle_matrix_a, shuffle_matrix_sf_a)
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

    # Convert quantized weights to proper formats
    gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 2)  # packed fp4
    gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
        torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size,
                                     hidden_size //
                                     16)  # fp8 scaling factors

    gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 2)  # packed fp4
    gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
        torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                     intermediate_size //
                                     16)  # fp8 scaling factors

    # Reorder rows of W1 and scales for fused gated activation
    gemm1_weights_fp4_interleaved = []
    gemm1_scales_fp4_interleaved = []
    for i in range(num_experts):
        # .clone() gives each per-expert slice contiguous storage before
        # the row reorder.
        gemm1_weights_fp4_interleaved.append(
            reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
        gemm1_scales_fp4_interleaved.append(
            reorder_rows_for_gated_act_gemm(
                gemm1_scales_linear_fp4[i].clone()))

    # Stack weights and scales for all experts
    gemm1_weights_fp4_interleaved = torch.stack(
        gemm1_weights_fp4_interleaved).reshape(num_experts,
                                               2 * intermediate_size,
                                               hidden_size // 2)
    gemm1_scales_fp4_interleaved = torch.stack(
        gemm1_scales_fp4_interleaved).reshape(num_experts,
                                              2 * intermediate_size,
                                              hidden_size // 16)

    # Shuffle weights and scaling factors for transposed mma output
    gemm1_weights_fp4_shuffled = []
    gemm1_scales_fp4_shuffled = []
    gemm2_weights_fp4_shuffled = []
    gemm2_scales_fp4_shuffled = []
    for i in range(num_experts):
        # Weights are shuffled as raw uint8 bytes; scales use the
        # dedicated scale-factor shuffle (shuffle_matrix_sf_a).
        gemm1_weights_fp4_shuffled.append(
            shuffle_matrix_a(
                gemm1_weights_fp4_interleaved[i].view(torch.uint8),
                epilogue_tile_m))
        gemm1_scales_fp4_shuffled.append(
            shuffle_matrix_sf_a(
                gemm1_scales_fp4_interleaved[i].view(torch.uint8),
                epilogue_tile_m))

        gemm2_weights_fp4_shuffled.append(
            shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
                             epilogue_tile_m))
        gemm2_scales_fp4_shuffled.append(
            shuffle_matrix_sf_a(
                gemm2_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m))

    # Stack weights for all experts
    gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
    # Scales were shuffled as uint8; reinterpret back to fp8 block scales.
    gemm1_scales_fp4_shuffled = (
        torch.stack(gemm1_scales_fp4_shuffled).view(
            torch.float8_e4m3fn).reshape(num_experts,
                                         2 * intermediate_size,
                                         hidden_size // 16))

    gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
    gemm2_scales_fp4_shuffled = (
        torch.stack(gemm2_scales_fp4_shuffled).view(
            torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                         intermediate_size // 16))
    return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
|
|
|
|
|
|
|
|
|
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
|
|
|
# GEMM 1
|
|
|
|
|
# The FlashInfer Cutlass fused MoE kernel expects the combined weights
|
|
|
|
|
# to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
|
|
|
|
|
# GEMM 1 processing
|
|
|
|
|
gemm1_weight = layer.w13_weight.data
|
|
|
|
|
gemm1_weight_scale = layer.w13_weight_scale.data
|
|
|
|
|
|
|
|
|
|
if self.allow_flashinfer_cutlass:
|
|
|
|
|
if self.allow_flashinfer:
|
|
|
|
|
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
|
|
|
|
|
gemm1_weight, gemm1_weight_scale, dim=-2)
|
|
|
|
|
|
|
|
|
|
@@ -1011,6 +1133,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|
|
|
|
layer.w13_weight_scale = Parameter(gemm1_weight_scale,
|
|
|
|
|
requires_grad=False)
|
|
|
|
|
|
|
|
|
|
# Common processing for w13_weight_scale_2
|
|
|
|
|
if not torch.allclose(layer.w13_weight_scale_2[:, 0],
|
|
|
|
|
layer.w13_weight_scale_2[:, 1]):
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
@@ -1021,26 +1144,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|
|
|
|
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
|
|
|
|
|
requires_grad=False)
|
|
|
|
|
|
|
|
|
|
# Common processing for input scales and alphas
|
|
|
|
|
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
|
|
|
|
|
torch.float32)
|
|
|
|
|
layer.g1_alphas = Parameter(
|
|
|
|
|
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
|
|
|
|
|
requires_grad=False)
|
|
|
|
|
|
|
|
|
|
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
|
|
|
|
|
"Expected weight_scale.dim(1) to be divisible by 16")
|
|
|
|
|
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
|
|
|
|
|
"Weight Blockscale must be represented as FP8-E4M3")
|
|
|
|
|
w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
|
|
|
|
|
|
|
|
|
|
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
|
|
|
|
|
requires_grad=False)
|
|
|
|
|
|
|
|
|
|
# This is for quantization, so we need to invert it.
|
|
|
|
|
layer.w13_input_scale_quant = Parameter(
|
|
|
|
|
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
|
|
|
|
|
|
|
|
|
|
# GEMM 2
|
|
|
|
|
# GEMM 2 processing
|
|
|
|
|
layer.g2_alphas = Parameter(
|
|
|
|
|
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
|
|
|
|
|
requires_grad=False)
|
|
|
|
|
@@ -1049,15 +1164,63 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|
|
|
|
layer.w2_input_scale_quant = Parameter(
|
|
|
|
|
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)
|
|
|
|
|
|
|
|
|
|
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
|
|
|
|
|
"Expected weight_scale.dim(1) to be divisible by 16")
|
|
|
|
|
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
|
|
|
|
"Weight Blockscale must be represented as FP8-E4M3")
|
|
|
|
|
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
|
|
|
|
# TensorRT-LLM specific processing
|
|
|
|
|
if self.allow_flashinfer and \
|
|
|
|
|
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
|
|
|
|
# Prepare static weights for TRT-LLM kernel
|
|
|
|
|
(gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
|
|
|
|
|
gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
|
|
|
|
|
) = self.prepare_static_weight_layouts_for_trtllm_moe(
|
|
|
|
|
layer.w13_weight,
|
|
|
|
|
layer.w2_weight,
|
|
|
|
|
layer.w13_weight_scale,
|
|
|
|
|
layer.w2_weight_scale,
|
|
|
|
|
layer.w2_weight.size(-2), # hidden_size
|
|
|
|
|
layer.w13_weight.size(-2) // 2, # intermediate_size
|
|
|
|
|
layer.w13_weight.size(0), # num_experts
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
|
|
|
|
requires_grad=False)
|
|
|
|
|
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
|
|
|
|
layer.gemm1_weights_fp4_shuffled = Parameter(
|
|
|
|
|
gemm1_weights_fp4_shuffled, requires_grad=False)
|
|
|
|
|
layer.gemm2_weights_fp4_shuffled = Parameter(
|
|
|
|
|
gemm2_weights_fp4_shuffled, requires_grad=False)
|
|
|
|
|
layer.gemm1_scales_fp4_shuffled = Parameter(
|
|
|
|
|
gemm1_scales_fp4_shuffled, requires_grad=False)
|
|
|
|
|
layer.gemm2_scales_fp4_shuffled = Parameter(
|
|
|
|
|
gemm2_scales_fp4_shuffled, requires_grad=False)
|
|
|
|
|
|
|
|
|
|
# Additional parameter needed for TRT-LLM
|
|
|
|
|
layer.g1_scale_c = Parameter(
|
|
|
|
|
(layer.w2_input_scale_quant * layer.g1_alphas).to(
|
|
|
|
|
torch.float32),
|
|
|
|
|
requires_grad=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Clean up weights that won't be used by TRT-LLM
|
|
|
|
|
del layer.w2_weight
|
|
|
|
|
del layer.w2_weight_scale
|
|
|
|
|
del layer.w13_weight
|
|
|
|
|
del layer.w13_weight_scale
|
|
|
|
|
else:
|
|
|
|
|
# Non-TRT-LLM processing (Cutlass or non-flashinfer)
|
|
|
|
|
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
|
|
|
|
|
"Expected weight_scale.dim(1) to be divisible by 16")
|
|
|
|
|
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
|
|
|
|
|
"Weight Blockscale must be represented as FP8-E4M3")
|
|
|
|
|
w13_blockscale_swizzled = swizzle_blockscale(
|
|
|
|
|
layer.w13_weight_scale)
|
|
|
|
|
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
|
|
|
|
|
requires_grad=False)
|
|
|
|
|
|
|
|
|
|
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
|
|
|
|
|
"Expected weight_scale.dim(1) to be divisible by 16")
|
|
|
|
|
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
|
|
|
|
"Weight Blockscale must be represented as FP8-E4M3")
|
|
|
|
|
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
|
|
|
|
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
|
|
|
|
requires_grad=False)
|
|
|
|
|
layer.w2_weight = Parameter(layer.w2_weight.data,
|
|
|
|
|
requires_grad=False)
|
|
|
|
|
|
|
|
|
|
if self.use_marlin:
|
|
|
|
|
prepare_moe_fp4_layer_for_marlin(layer)
|
|
|
|
|
@@ -1095,6 +1258,60 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|
|
|
|
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
|
|
|
|
|
assert activation == "silu", "Only SiLU activation is supported."
|
|
|
|
|
|
|
|
|
|
if self.allow_flashinfer and \
|
|
|
|
|
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
|
|
|
|
import flashinfer
|
|
|
|
|
|
|
|
|
|
from vllm.model_executor.models.llama4 import Llama4MoE
|
|
|
|
|
|
|
|
|
|
a1_gscale = layer.w13_input_scale_quant
|
|
|
|
|
(hidden_states_fp4,
|
|
|
|
|
hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
|
|
|
|
x,
|
|
|
|
|
a1_gscale,
|
|
|
|
|
is_sf_swizzled_layout=False,
|
|
|
|
|
)
|
|
|
|
|
use_llama4_routing = \
|
|
|
|
|
custom_routing_function is Llama4MoE.custom_routing_function
|
|
|
|
|
routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
|
|
|
|
|
if use_llama4_routing:
|
|
|
|
|
routing_method_type = flashinfer.RoutingMethodType.Llama4
|
|
|
|
|
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
|
|
|
|
|
routing_logits=router_logits
|
|
|
|
|
if use_llama4_routing else router_logits.to(torch.float32),
|
|
|
|
|
routing_bias=e_score_correction_bias,
|
|
|
|
|
hidden_states=hidden_states_fp4,
|
|
|
|
|
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
|
|
|
|
torch.float8_e4m3fn).flatten(),
|
|
|
|
|
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
|
|
|
|
|
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
|
|
|
|
|
torch.float8_e4m3fn),
|
|
|
|
|
gemm1_bias=None,
|
|
|
|
|
gemm1_alpha=None,
|
|
|
|
|
gemm1_beta=None,
|
|
|
|
|
gemm1_clamp_limit=None,
|
|
|
|
|
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
|
|
|
|
|
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
|
|
|
|
|
torch.float8_e4m3fn),
|
|
|
|
|
gemm2_bias=None,
|
|
|
|
|
output1_scale_scalar=layer.g1_scale_c.data,
|
|
|
|
|
output1_scale_gate_scalar=layer.g1_alphas.data,
|
|
|
|
|
output2_scale_scalar=layer.g2_alphas.data,
|
|
|
|
|
num_experts=global_num_experts,
|
|
|
|
|
top_k=top_k,
|
|
|
|
|
n_group=num_expert_group,
|
|
|
|
|
topk_group=topk_group,
|
|
|
|
|
intermediate_size=layer.intermediate_size_per_partition,
|
|
|
|
|
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
|
|
|
|
local_num_experts=layer.local_num_experts,
|
|
|
|
|
routed_scaling_factor=None,
|
|
|
|
|
tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
|
|
|
|
|
layer.local_num_experts),
|
|
|
|
|
routing_method_type=routing_method_type,
|
|
|
|
|
do_finalize=True,
|
|
|
|
|
)[0]
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
|
|
|
|
hidden_states=x,
|
|
|
|
|
router_logits=router_logits,
|
|
|
|
|
@@ -1149,6 +1366,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|
|
|
|
expert_map=expert_map,
|
|
|
|
|
apply_router_weight_on_input=apply_router_weight_on_input)
|
|
|
|
|
else:
|
|
|
|
|
assert self.allow_flashinfer and \
|
|
|
|
|
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
|
|
|
|
out = flashinfer_fp4_cutlass_moe_forward(
|
|
|
|
|
self.fused_experts,
|
|
|
|
|
layer,
|
|
|
|
|
@@ -1160,4 +1379,5 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|
|
|
|
expert_map=expert_map,
|
|
|
|
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|