[Hardware][SM100] Add TRTLLM Kernel for INT4 W4A16. (#32437)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py (new file, 272 lines)
@@ -0,0 +1,272 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test comparing Marlin INT4 MoE vs FlashInfer TRT-LLM MXINT4 MoE."""

import pytest
import torch

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
    grouped_topk,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
    prepare_static_weights_for_trtllm_mxint4_moe,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types


def mxint4_quantize(
    x: torch.Tensor, sf_vec_size: int = 32
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize BF16 tensor to MXINT4 with block scaling (group_size=sf_vec_size).

    Returns:
        - uint8 packed (2 INT4/byte): [..., k//2] - stores SIGNED INT4 [-8, 7]
        - scales in BF16: [..., k//sf_vec_size]
    """
    x_reshaped = x.reshape(-1, sf_vec_size)
    x_max = x_reshaped.max(dim=-1, keepdim=True)[0].to(torch.float32)
    x_min = x_reshaped.min(dim=-1, keepdim=True)[0].to(torch.float32)
    x_max = x_max * 8.0 / 7.0
    amax = torch.where(x_max > -x_min, x_max, -x_min)
    scales = amax / 8.0
    x_scaled = x_reshaped * scales.reciprocal()
    x_int8 = (
        x_scaled.round().clamp(-8, 7).to(torch.int8).reshape(-1, sf_vec_size // 2, 2)
    )
    x_int4 = (x_int8[..., 0] & 0x0F) | ((x_int8[..., 1] & 0x0F) << 4)
    return (
        x_int4.to(torch.uint8).reshape(*x.shape[:-1], x.shape[-1] // 2),
        scales.to(x.dtype).reshape(*x.shape[:-1], x.shape[-1] // sf_vec_size),
    )
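

# A minimal round-trip sketch (illustrative, not part of the kernel path) that
# unpacks the bytes produced by mxint4_quantize() above: the low nibble holds
# the even element and the high nibble the odd element, each sign-extended from
# 4 bits and rescaled by the per-group BF16 scale. The helper name
# `mxint4_dequantize` is hypothetical, not a vLLM API.
def mxint4_dequantize(
    packed: torch.Tensor, scales: torch.Tensor, sf_vec_size: int = 32
) -> torch.Tensor:
    lo = (packed & 0x0F).to(torch.int8)
    hi = (packed >> 4).to(torch.int8)
    # Sign-extend the 4-bit values stored in each nibble.
    lo = torch.where(lo >= 8, lo - 16, lo)
    hi = torch.where(hi >= 8, hi - 16, hi)
    vals = torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], -1)
    groups = vals.reshape(*packed.shape[:-1], -1, sf_vec_size).to(scales.dtype)
    return (groups * scales.unsqueeze(-1)).reshape(*packed.shape[:-1], -1)

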
def mxint4_quantize_moe_weights(
    weights_bf16: torch.Tensor, group_size: int = 32
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize MoE weights [e, n, k] to MxInt4 format.

    Args:
        weights_bf16: BF16 weights of shape [num_experts, out_features, in_features]
        group_size: Quantization group size (default: 32)

    Returns:
        - weights_mxint4: Quantized weights [e, n, k//2] uint8
        - scales_mxint4: Quantization scales [e, n, k//group_size] bf16
    """
    e = weights_bf16.shape[0]
    weight_list = []
    scale_list = []

    for i in range(e):
        w_q, w_s = mxint4_quantize(weights_bf16[i], sf_vec_size=group_size)
        weight_list.append(w_q)
        scale_list.append(w_s)

    return torch.stack(weight_list), torch.stack(scale_list)


__all__ = [
    "mxint4_quantize",
    "mxint4_quantize_moe_weights",
    "marlin_quantize_moe_weights",
]

def marlin_quantize_moe_weights(
    weights_bf16: torch.Tensor, group_size: int = 32
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize MoE weights [e, n, k] to Marlin INT4 format.

    Args:
        weights_bf16: BF16 weights of shape [num_experts, out_features, in_features]
        group_size: Quantization group size (default: 32)

    Returns:
        - weights_marlin: Marlin quantized weights [e, k//8, n] int32
        - scales_marlin: Marlin quantization scales [e, k//group_size, n] bf16
    """
    from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
        marlin_quantize,
    )

    e, n, k = weights_bf16.shape
    weight_list = []
    scale_list = []

    for i in range(e):
        # Transpose for Marlin: [n, k] → [k, n]
        w_t = weights_bf16[i].T.contiguous()
        _, w_q, w_s, _, _, _ = marlin_quantize(
            w_t, scalar_types.uint4b8, group_size, act_order=False
        )
        weight_list.append(w_q)
        scale_list.append(w_s)

    # Stack to get [e, ...] shape
    weights_marlin = torch.stack(weight_list)  # [e, k // 8, n]
    scales_marlin = torch.stack(scale_list)  # [e, k // group_size, n]

    return weights_marlin, scales_marlin

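
# Illustrative note on the two INT4 encodings this test compares: Marlin's
# uint4b8 stores value + 8 as an unsigned nibble in [0, 15], while the TRT-LLM
# MXINT4 kernel expects two's-complement signed nibbles in [-8, 7]. A sketch of
# the per-nibble relationship (helper names are hypothetical, not vLLM APIs):
def uint4b8_nibble_to_signed(nibble: int) -> int:
    # uint4b8 encodes q + 8, so decoding subtracts the bias.
    return nibble - 8


def signed_nibble_from_int4(q: int) -> int:
    # Two's-complement: negative values occupy raw nibble values [8, 15].
    return q & 0x0F

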
TRTLLM_GEN_AVAILABLE = (
    current_platform.is_cuda() and current_platform.is_device_capability_family(100)
)

@pytest.mark.skipif(not TRTLLM_GEN_AVAILABLE, reason="Skip for non SM100")
@pytest.mark.parametrize("m", [1, 33])
@pytest.mark.parametrize("n", [7168])
@pytest.mark.parametrize("k", [512])
@pytest.mark.parametrize("e", [384])
@pytest.mark.parametrize("topk", [8])
@pytest.mark.parametrize("group_size", [32])
def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group_size):
    """Compare Marlin INT4 MoE vs FlashInfer TRT-LLM MXINT4 MoE.

    Uses mxint4_quantize() to generate common INT4 weights + BF16 scales,
    then runs both Marlin and TRT-LLM kernels and compares outputs.
    """
    pytest.importorskip("flashinfer")
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_INT4", "1")

    torch.cuda.manual_seed(0)

    dtype = torch.bfloat16

    # DeepSeekV3 routing config (from Kimi-K2-Thinking config.json)
    n_group = 1  # n_group from model config
    topk_group = 1  # topk_group from model config
    routed_scaling = 2.827  # routed_scaling_factor from model config

    # Input - realistic activation range for LLM (after LayerNorm: mean~0, std~1)
    a = torch.randn((m, k), device="cuda", dtype=dtype) * 0.5

    # Generate routing logits and bias (DeepSeekV3 expects float logits)
    # Realistic ranges: logits typically [-3, 3], bias [-2, 2]
    routing_logits = torch.randn((m, e), device="cuda", dtype=torch.float32) * 1.5
    routing_bias = torch.randn(e, device="cuda", dtype=torch.float32) * 0.8

    # 1. Generate BF16 weights (SHARED between both paths)
    # Realistic weight initialization: Xavier/Glorot uniform scaling
    # std = sqrt(2 / (fan_in + fan_out))
    std_w1 = (2.0 / (k + 2 * n)) ** 0.5
    std_w2 = (2.0 / (n + k)) ** 0.5
    w1_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) * std_w1
    w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) * std_w2

    # === Path 1: TRT-LLM FlashInfer MXINT4 MoE ===
    # Similar to: if self.use_flashinfer_mxint4_moe
    # Quantize using MXINT4 method (signed INT4)
    w1_int4, w1_scales = mxint4_quantize_moe_weights(w1_bf16, group_size)
    w2_int4, w2_scales = mxint4_quantize_moe_weights(w2_bf16, group_size)

    trtllm_weights = prepare_static_weights_for_trtllm_mxint4_moe(
        gemm1_weights=w1_int4,
        gemm1_scales=w1_scales,
        gemm2_weights=w2_int4,
        gemm2_scales=w2_scales,
    )

    from flashinfer import RoutingMethodType
    from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe

    # Routing handled internally by trtllm_mxint4_block_scale_moe
    trtllm_output = trtllm_mxint4_block_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias.to(torch.bfloat16),
        hidden_states=a,
        gemm1_weights=trtllm_weights["gemm1_weights"],
        gemm1_weights_scale=trtllm_weights["gemm1_scales"],
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=trtllm_weights["gemm2_weights"],
        gemm2_weights_scale=trtllm_weights["gemm2_scales"],
        num_experts=e,
        top_k=topk,
        n_group=n_group,
        topk_group=topk_group,
        intermediate_size=n,
        local_expert_offset=0,
        local_num_experts=e,
        routed_scaling_factor=routed_scaling,
        routing_method_type=RoutingMethodType.DeepSeekV3,
        enable_pdl=None,
        output=None,
        tune_max_num_tokens=8192,
    ).to(dtype)

    # === Path 2: Marlin INT4 MoE ===
    # Similar to: else (non-flashinfer path)
    # Quantize using Marlin's method (UINT4b8)
    w1_marlin, w1_scales_marlin = marlin_quantize_moe_weights(w1_bf16, group_size)
    w2_marlin, w2_scales_marlin = marlin_quantize_moe_weights(w2_bf16, group_size)

    # Use production routing kernel (same as router.select_experts internally uses)
    topk_weights, topk_ids = grouped_topk(
        hidden_states=a,
        gating_output=routing_logits,
        topk=topk,
        renormalize=False,  # DeepSeekV3 doesn't renormalize
        num_expert_group=n_group,
        topk_group=topk_group,
        scoring_func="sigmoid",  # DeepSeekV3 uses sigmoid
        routed_scaling_factor=routed_scaling,
        e_score_correction_bias=routing_bias,
    )

    marlin_output = fused_marlin_moe(
        a,
        w1_marlin,
        w2_marlin,
        None,
        None,
        w1_scales_marlin,
        w2_scales_marlin,
        None,  # gating_output not needed when topk_weights/ids provided
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=None,
        global_scale1=None,
        global_scale2=None,
        g_idx1=None,
        g_idx2=None,
        input_global_scale1=None,
        input_global_scale2=None,
        sort_indices1=None,
        sort_indices2=None,
        w1_zeros=None,
        w2_zeros=None,
        input_dtype=dtype,
        quant_type_id=scalar_types.uint4b8.id,
        is_k_full=True,
    )

    # Sanity check: manually compute BF16 reference for comparison
    # Use same routing as Marlin path for consistency
    bf16_output = torch.zeros((m, k), device="cuda", dtype=dtype)
    for token_idx in range(m):
        for expert_rank in range(topk):
            expert_id = topk_ids[token_idx, expert_rank].item()
            weight = topk_weights[token_idx, expert_rank].item()
            # w1: [2*n, k] @ [k] -> [2*n]
            up_gate = a[token_idx] @ w1_bf16[expert_id].T  # [2*n]
            gate, up = up_gate.chunk(2, dim=0)
            intermediate = torch.nn.functional.silu(gate) * up  # [n]
            # w2: [k, n] @ [n] -> [k]
            expert_out = intermediate @ w2_bf16[expert_id].T  # [k]
            bf16_output[token_idx] += weight * expert_out

    # Compare against BF16 reference.
    torch.testing.assert_close(marlin_output, bf16_output, atol=0.3, rtol=1.0)
    torch.testing.assert_close(trtllm_output, bf16_output, atol=0.3, rtol=1.0)

    # Compare against each other for sanity.
    # Note: Different quantization schemes (UINT4b8 vs signed MXINT4) cause
    # some differences
    torch.testing.assert_close(marlin_output, trtllm_output, atol=0.3, rtol=6.0)
vllm/envs.py (11 lines changed)
@@ -174,6 +174,7 @@ if TYPE_CHECKING:
    VLLM_USE_FLASHINFER_MOE_FP16: bool = False
    VLLM_USE_FLASHINFER_MOE_FP8: bool = False
    VLLM_USE_FLASHINFER_MOE_FP4: bool = False
    VLLM_USE_FLASHINFER_MOE_INT4: bool = False
    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
        "latency"
    )
@@ -1240,18 +1241,22 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER": lambda: bool(
        int(os.getenv("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER", "0"))
    ),
    # Allow use of FlashInfer MoE kernels for fused moe ops.
    # Allow use of FlashInfer BF16 MoE kernels for fused moe ops.
    "VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool(
        int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))
    ),
    # Allow use of FlashInfer MoE kernels for fused moe ops.
    # Allow use of FlashInfer FP8 MoE kernels for fused moe ops.
    "VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool(
        int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))
    ),
    # Allow use of FlashInfer CUTLASS kernels for fused moe ops.
    # Allow use of FlashInfer NVFP4 MoE kernels for fused moe ops.
    "VLLM_USE_FLASHINFER_MOE_FP4": lambda: bool(
        int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))
    ),
    # Allow use of FlashInfer MxInt4 MoE kernels for fused moe ops.
    "VLLM_USE_FLASHINFER_MOE_INT4": lambda: bool(
        int(os.getenv("VLLM_USE_FLASHINFER_MOE_INT4", "0"))
    ),
    # If set to 1, use the FlashInfer
    # MXFP8 (activation) x MXFP4 (weight) MoE backend.
    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool(
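
A minimal usage sketch (illustrative, not part of the diff): the new flag is read lazily through vllm.envs, so exporting it before the MoE layers are constructed enables the FlashInfer MxInt4 path on supported hardware.

    import os

    os.environ["VLLM_USE_FLASHINFER_MOE_INT4"] = "1"  # evaluated lazily by vllm.envs

    import vllm.envs as envs

    assert envs.VLLM_USE_FLASHINFER_MOE_INT4  # True once the variable is set
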
@@ -1138,6 +1138,11 @@ class FusedMoE(CustomOp):
            return False if return_success else None
        # Hereafter, `expert_id` is local physical id

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size_per_partition is
        is_transposed = getattr(param, "is_transposed", False)

        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
@@ -1145,7 +1150,10 @@ class FusedMoE(CustomOp):
            "CompressedTensorsWNA16MarlinMoEMethod",
            "CompressedTensorsWNA16MoEMethod",
        ):
            loaded_weight = loaded_weight.t().contiguous()
            if is_transposed:
                loaded_weight = loaded_weight.t().contiguous()
            else:
                loaded_weight = loaded_weight

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but got {shard_id}.")
@@ -1183,10 +1191,6 @@ class FusedMoE(CustomOp):
            )
            return True if return_success else None

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size_per_partition is
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            shard_dim = int(not shard_dim)
@@ -63,6 +63,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    flashinfer_trtllm_fp4_moe,
    flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
    flashinfer_trtllm_mxint4_moe,
    is_flashinfer_mxint4_moe_available,
    prepare_static_weights_for_trtllm_mxint4_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_fi_trtllm_fp8_per_tensor_moe,
)
@@ -1247,8 +1252,89 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
        self.actorder = weight_quant.actorder

        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
        self.use_marlin = True

        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
        self.use_flashinfer_mxint4_moe = (
            is_flashinfer_mxint4_moe_available()
            and self.group_size == 32
            and weight_quant.num_bits == 4
        )
        self.kernel_backend = (
            "Flashinfer" if self.use_flashinfer_mxint4_moe else "Marlin"
        )
        logger.info_once(
            f"Using {self.kernel_backend} backend for WNA16 MoE "
            f"(group_size={self.group_size}, num_bits={self.num_bits})",
            scope="local",
        )

    def get_weight_shape(
        self,
        weight_name: str,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        num_groups_w2: int | None = None,
        num_groups_w13: int | None = None,
    ) -> tuple[int, int, int]:
        """
        Get the shape of the weight based on the weight name, number of experts,
        hidden size, intermediate size per partition, number of groups for w2,
        and number of groups for w13. Pass in num_groups_w2 and num_groups_w13
        for weight scales.
        """
        if weight_name == "w13_scale":
            assert num_groups_w13 is not None, (
                "num_groups_w13 must be provided for weight scales"
            )
        if weight_name == "w2_scale":
            assert num_groups_w2 is not None, (
                "num_groups_w2 must be provided for weight scales"
            )
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1
        shape_map = {
            "w13_weight": {
                "Flashinfer": (
                    num_experts,
                    w13_num_shards * intermediate_size_per_partition,
                    hidden_size // self.packed_factor,
                ),
                "Marlin": (
                    num_experts,
                    hidden_size // self.packed_factor,
                    w13_num_shards * intermediate_size_per_partition,
                ),
            },
            "w13_scale": {
                "Flashinfer": (
                    num_experts,
                    w13_num_shards * intermediate_size_per_partition,
                    num_groups_w13,
                ),
                "Marlin": (
                    num_experts,
                    num_groups_w13,
                    w13_num_shards * intermediate_size_per_partition,
                ),
            },
            "w2_weight": {
                "Flashinfer": (
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition // self.packed_factor,
                ),
                "Marlin": (
                    num_experts,
                    intermediate_size_per_partition // self.packed_factor,
                    hidden_size,
                ),
            },
            "w2_scale": {
                "Flashinfer": (num_experts, hidden_size, num_groups_w2),
                "Marlin": (num_experts, num_groups_w2, hidden_size),
            },
        }
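        # Worked example (hypothetical sizes, assuming is_act_and_mul=True,
        # packed_factor=8, group_size=32, hidden_size=4096,
        # intermediate_size_per_partition=2048, so num_groups_w13=128 and
        # num_groups_w2=64):
        #   w13_weight: Flashinfer (E, 4096, 512)  vs. Marlin (E, 512, 4096)
        #   w13_scale:  Flashinfer (E, 4096, 128)  vs. Marlin (E, 128, 4096)
        #   w2_weight:  Flashinfer (E, 4096, 256)  vs. Marlin (E, 256, 4096)
        #   w2_scale:   Flashinfer (E, 4096, 64)   vs. Marlin (E, 64, 4096)
        # i.e. the Flashinfer layout keeps the checkpoint's [n, k]-style ordering,
        # while Marlin expects the transposed [k, n]-style ordering.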
        return shape_map[weight_name][self.kernel_backend]

    def create_weights(
        self,
@@ -1260,19 +1346,23 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
        **extra_weight_attrs,
    ):
        intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
        # shard for TP along the transposed dims
        is_transposed = self.kernel_backend != "Flashinfer"
        extra_weight_attrs.update(
            {"is_transposed": True, "quant_method": self.strategy}
            {"is_transposed": is_transposed, "quant_method": self.strategy}
        )

        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size // self.packed_factor,
                w13_num_shards * intermediate_size_per_partition,
                *self.get_weight_shape(
                    "w13_weight",
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                ),
                dtype=torch.int32,
            ),
            requires_grad=False,
@@ -1282,9 +1372,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition // self.packed_factor,
                hidden_size,
                *self.get_weight_shape(
                    "w2_weight",
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                ),
                dtype=torch.int32,
            ),
            requires_grad=False,
@@ -1315,9 +1408,13 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):

        w13_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                num_groups_w13,
                w13_num_shards * intermediate_size_per_partition,
                *self.get_weight_shape(
                    "w13_scale",
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                    num_groups_w13=num_groups_w13,
                ),
                dtype=params_dtype,
            ),
            requires_grad=False,
@@ -1326,7 +1423,16 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
        set_weight_attrs(w13_scale, extra_weight_attrs)

        w2_scale = torch.nn.Parameter(
            torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
            torch.ones(
                *self.get_weight_shape(
                    "w2_scale",
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                    num_groups_w2=num_groups_w2,
                ),
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_scale)
@@ -1396,6 +1502,27 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_experts = layer.w13_weight_g_idx.shape[0]
        device = layer.w13_weight_g_idx.device
        if self.kernel_backend == "Flashinfer":
            dict_weights_mxint4 = prepare_static_weights_for_trtllm_mxint4_moe(
                layer.w13_weight_packed,
                layer.w13_weight_scale,
                layer.w2_weight_packed,
                layer.w2_weight_scale,
            )
            replace_parameter(
                layer, "w13_weight_packed", dict_weights_mxint4["gemm1_weights"]
            )
            replace_parameter(
                layer, "w13_weight_scale", dict_weights_mxint4["gemm1_scales"]
            )
            replace_parameter(
                layer, "w2_weight_packed", dict_weights_mxint4["gemm2_weights"]
            )
            replace_parameter(
                layer, "w2_weight_scale", dict_weights_mxint4["gemm2_scales"]
            )
            return None

        is_a_8bit = (
            self.marlin_input_dtype is not None
            and self.marlin_input_dtype.itemsize == 1
@@ -1560,6 +1687,35 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
            is_k_full=self.is_k_full,
        )

    @property
    def is_monolithic(self) -> bool:
        return self.kernel_backend == "Flashinfer"

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.kernel_backend == "Flashinfer"
        return flashinfer_trtllm_mxint4_moe(
            x=x,
            router_logits=router_logits,
            w13_weight_packed=layer.w13_weight_packed,
            w13_weight_scale=layer.w13_weight_scale,
            w2_weight_packed=layer.w2_weight_packed,
            w2_weight_scale=layer.w2_weight_scale,
            global_num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            intermediate_size_per_partition=layer.intermediate_size_per_partition,
            local_num_experts=layer.local_num_experts,
            ep_rank=layer.ep_rank,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routing_method_type=layer.routing_method_type,
        )

    def apply(
        self,
        layer: FusedMoE,
@@ -1567,6 +1723,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.kernel_backend == "Marlin"
        return fused_marlin_moe(
            x,
            layer.w13_weight_packed,
@@ -0,0 +1,266 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility helpers for MxInt4 + FlashInfer fused-MoE path"""

import functools

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe

__all__ = [
    "prepare_static_weights_for_trtllm_mxint4_moe",
    "flashinfer_trtllm_mxint4_moe",
    "is_flashinfer_mxint4_moe_available",
]

logger = init_logger(__name__)


@functools.cache
def is_flashinfer_mxint4_moe_available() -> bool:
    """Return `True` when FlashInfer MxInt4 kernels can be used."""
    return (
        envs.VLLM_USE_FLASHINFER_MOE_INT4
        and has_flashinfer_trtllm_fused_moe()
        and current_platform.is_cuda()
        and current_platform.is_device_capability_family(100)
    )

def prepare_static_weights_for_trtllm_mxint4_moe(
    gemm1_weights: torch.Tensor,
    gemm1_scales: torch.Tensor,
    gemm2_weights: torch.Tensor,
    gemm2_scales: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """
    Prepare MxInt4 weights for the TRT-LLM kernel.

    Input:
        gemm1_weights: [num_experts, 2*intermediate_size, hidden_size//8] int32
            (checkpoint uint4b8 packed) or uint8 (already packed signed int4)
        gemm1_scales: [num_experts, 2*intermediate_size, hidden_size//32] bf16
        gemm2_weights: [num_experts, hidden_size, intermediate_size//8] int32
            (checkpoint uint4b8 packed) or uint8 (already packed signed int4)
        gemm2_scales: [num_experts, hidden_size, intermediate_size//32] bf16

    Returns:
        Dict with keys 'gemm1_weights', 'gemm1_scales', 'gemm2_weights',
        'gemm2_scales' containing shuffled/packed tensors ready for the kernel
    """
    from flashinfer import block_scale_interleave
    from flashinfer.fused_moe import (
        convert_to_block_layout,
    )
    from flashinfer.fused_moe.core import (
        _maybe_get_cached_w3_w1_permute_indices,
        get_w2_permute_indices_with_cache,
    )

    from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
        reorder_w1w3_to_w3w1,
    )
    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        convert_packed_uint4b8_to_signed_int4_inplace,
    )

    device = gemm1_weights.device
    assert gemm1_weights.ndim == 3, (
        f"Expected a 3D gemm1_weights tensor, got {gemm1_weights.shape}"
    )
    assert gemm1_scales.ndim == 3, (
        f"Expected a 3D gemm1_scales tensor, got {gemm1_scales.shape}"
    )
    assert gemm2_weights.ndim == 3, (
        f"Expected a 3D gemm2_weights tensor, got {gemm2_weights.shape}"
    )
    assert gemm2_scales.ndim == 3, (
        f"Expected a 3D gemm2_scales tensor, got {gemm2_scales.shape}"
    )

    # Convert checkpoint format (uint4b8 in int32) to signed int4
    # Checkpoint stores INT4 as unsigned [0, 15], kernel expects signed [-8, 7]
    if gemm1_weights.dtype == torch.int32 and gemm2_weights.dtype == torch.int32:
        convert_packed_uint4b8_to_signed_int4_inplace(gemm1_weights)
        convert_packed_uint4b8_to_signed_int4_inplace(gemm2_weights)

    gemm1_weights, gemm1_scales = reorder_w1w3_to_w3w1(
        gemm1_weights, gemm1_scales, dim=-2
    )

    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    num_experts = gemm1_weights.shape[0]

    # Convert quantized weights to proper formats
    gemm1_weights_mxint4 = gemm1_weights.view(torch.uint8)
    assert gemm1_scales.dtype == torch.bfloat16
    gemm2_weights_mxint4 = gemm2_weights.view(torch.uint8)
    assert gemm2_scales.dtype == torch.bfloat16

    epilogue_tile_m = 128
    gemm1_weights_mxint4_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_mxint4_shuffled = []
    gemm2_scales_shuffled = []

    for i in range(num_experts):
        # Calculate the permute indices for the following:
        # 1. Reorder rows of W1 and scales for fused gated activation
        # 2. Shuffle weights and scaling factors for transposed mma output
        #    for both w3_w1 and w2 weights and scale factors
        permute_indices = _maybe_get_cached_w3_w1_permute_indices(
            _cache_permute_indices,
            gemm1_weights_mxint4[i],
            epilogue_tile_m,
        )
        gemm1_weights_shuffled = gemm1_weights_mxint4[i][
            permute_indices.to(gemm1_weights.device)
        ].contiguous()
        permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
            _cache_permute_indices,
            gemm1_scales[i],
            epilogue_tile_m,
            num_elts_per_sf=32,
        ).to(device)
        gemm1_scales_shuffled.append(
            block_scale_interleave(gemm1_scales[i][permute_sf_indices].contiguous())
        )

        permute_indices = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            gemm2_weights_mxint4[i],
            epilogue_tile_m,
        )
        gemm2_weights_shuffled = gemm2_weights_mxint4[i][
            permute_indices.to(gemm2_weights.device)
        ].contiguous()

        permute_sf_indices = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            gemm2_scales[i],
            epilogue_tile_m,
            num_elts_per_sf=16,
        )
        gemm2_scales_shuffled.append(
            block_scale_interleave(
                gemm2_scales[i][permute_sf_indices.to(gemm2_scales.device)].contiguous()
            )
        )

        block_k = 128
        gemm1_weights_shuffled = convert_to_block_layout(
            gemm1_weights_shuffled.view(torch.uint8), block_k
        )
        gemm2_weights_shuffled = convert_to_block_layout(
            gemm2_weights_shuffled.view(torch.uint8), block_k
        )

        gemm1_weights_mxint4_shuffled.append(gemm1_weights_shuffled)
        gemm2_weights_mxint4_shuffled.append(gemm2_weights_shuffled)

    gemm1_weights_mxint4_shuffled = torch.stack(gemm1_weights_mxint4_shuffled)
    gemm2_weights_mxint4_shuffled = torch.stack(gemm2_weights_mxint4_shuffled)
    gemm1_scales_shuffled = torch.stack(gemm1_scales_shuffled).view(torch.bfloat16)
    gemm2_scales_shuffled = torch.stack(gemm2_scales_shuffled).view(torch.bfloat16)
    return {
        "gemm1_weights": gemm1_weights_mxint4_shuffled,
        "gemm1_scales": gemm1_scales_shuffled,
        "gemm2_weights": gemm2_weights_mxint4_shuffled,
        "gemm2_scales": gemm2_scales_shuffled,
    }

def flashinfer_trtllm_mxint4_moe(
    x: torch.Tensor,
    router_logits: torch.Tensor,
    w13_weight_packed: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_packed: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    global_num_experts: int,
    top_k: int,
    intermediate_size_per_partition: int,
    local_num_experts: int,
    ep_rank: int = 0,
    num_expert_group: int | None = None,
    topk_group: int | None = None,
    e_score_correction_bias: torch.Tensor | None = None,
    routing_method_type: int | None = None,
) -> torch.Tensor:
    """
    Apply the FlashInfer TensorRT-LLM MxInt4 MoE kernel.

    Args:
        x: Input hidden states. dtype: bfloat16
        router_logits: Router logits for expert selection. dtype: bfloat16/float32
        w13_weight_packed: Packed gate+up weights. dtype: uint8
        w13_weight_scale: Scales for gate+up weights. dtype: bfloat16
        w2_weight_packed: Packed down weights. dtype: uint8
        w2_weight_scale: Scales for down weights. dtype: bfloat16
        global_num_experts: Total number of experts across all ranks
        top_k: Number of experts to select per token
        intermediate_size_per_partition: Intermediate size per partition
        local_num_experts: Number of experts on this rank
        ep_rank: Expert parallelism rank (default: 0)
        num_expert_group: Number of expert groups (default: None -> 0)
        topk_group: Top-k within groups (default: None -> 0)
        e_score_correction_bias: Optional routing bias. dtype: bfloat16
        routing_method_type: FlashInfer RoutingMethodType enum value

    Returns:
        Output tensor from the MoE layer. dtype: same as x (bfloat16)
    """
    from flashinfer import RoutingMethodType
    from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe

    assert x.dtype == torch.bfloat16, f"x dtype must be bfloat16, got {x.dtype}"
    assert w13_weight_packed.dtype == torch.uint8, (
        f"w13_weight_packed dtype must be uint8, got {w13_weight_packed.dtype}"
    )
    assert w13_weight_scale.dtype == torch.bfloat16, (
        f"w13_weight_scale dtype must be bfloat16, got {w13_weight_scale.dtype}"
    )
    assert w2_weight_packed.dtype == torch.uint8, (
        f"w2_weight_packed dtype must be uint8, got {w2_weight_packed.dtype}"
    )
    assert w2_weight_scale.dtype == torch.bfloat16, (
        f"w2_weight_scale dtype must be bfloat16, got {w2_weight_scale.dtype}"
    )

    routing_bias = None
    if e_score_correction_bias is not None:
        routing_bias = e_score_correction_bias.to(torch.bfloat16)

    if routing_method_type == RoutingMethodType.DeepSeekV3:
        router_logits = router_logits.to(torch.float32)

    out = trtllm_mxint4_block_scale_moe(
        routing_logits=router_logits,
        routing_bias=routing_bias,
        hidden_states=x,
        gemm1_weights=w13_weight_packed.data,
        gemm1_weights_scale=w13_weight_scale.data,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=w2_weight_packed.data,
        gemm2_weights_scale=w2_weight_scale.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group if num_expert_group is not None else 0,
        topk_group=topk_group if topk_group is not None else 0,
        intermediate_size=intermediate_size_per_partition,
        local_expert_offset=ep_rank * local_num_experts,
        local_num_experts=local_num_experts,
        routed_scaling_factor=None,
        routing_method_type=routing_method_type,
        enable_pdl=None,
        output=None,
        tune_max_num_tokens=8192,
    ).to(x.dtype)

    return out
@@ -129,12 +129,11 @@ scaled_fp4_grouped_quantize = _lazy_import_wrapper(
    "flashinfer", "scaled_fp4_grouped_quantize"
)
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave"
    "flashinfer.fp4_quantization", "block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)

# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
@@ -196,6 +195,7 @@ def has_flashinfer_trtllm_fused_moe() -> bool:
        ("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"),
        ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
        ("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"),
    ]
    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)