[Hardware][SM100] Add TRTLLM kernel for INT4 W4A16 MoE (#32437)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
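Summary: on SM100-family (Blackwell) GPUs, this change dispatches WNA16 MoE layers with INT4 weights and group_size=32 to FlashInfer's TRT-LLM MxInt4 block-scale kernel instead of the Marlin fallback. The path is opt-in via the VLLM_USE_FLASHINFER_MOE_INT4 environment variable checked in is_flashinfer_mxint4_moe_available(). A minimal sketch of turning it on (the checkpoint name is a placeholder, not from this commit):

import os

# Set before importing vllm so envs.VLLM_USE_FLASHINFER_MOE_INT4 picks it up.
os.environ["VLLM_USE_FLASHINFER_MOE_INT4"] = "1"

from vllm import LLM

# Placeholder model: any compressed-tensors W4A16 MoE checkpoint with
# group_size=32; on non-SM100 hardware vLLM falls back to the Marlin path.
llm = LLM(model="org/w4a16-moe-checkpoint")
print(llm.generate("Hello, world"))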
@@ -1138,6 +1138,11 @@ class FusedMoE(CustomOp):
         return False if return_success else None
         # Hereafter, `expert_id` is local physical id
+
+        # is_transposed: if the dim to shard the weight
+        # should be flipped. Required by GPTQ, compressed-tensors
+        # should be whatever dimension intermediate_size_per_partition is
+        is_transposed = getattr(param, "is_transposed", False)

         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -1145,7 +1150,10 @@ class FusedMoE(CustomOp):
             "CompressedTensorsWNA16MarlinMoEMethod",
             "CompressedTensorsWNA16MoEMethod",
         ):
-            loaded_weight = loaded_weight.t().contiguous()
+            if is_transposed:
+                loaded_weight = loaded_weight.t().contiguous()
+            else:
+                loaded_weight = loaded_weight

         if shard_id not in ("w1", "w2", "w3"):
             raise ValueError(f"shard_id must be ['w1','w2','w3'] but got {shard_id}.")
@@ -1183,10 +1191,6 @@ class FusedMoE(CustomOp):
         )
         return True if return_success else None

-        # is_transposed: if the dim to shard the weight
-        # should be flipped. Required by GPTQ, compressed-tensors
-        # should be whatever dimension intermediate_size_per_partition is
-        is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
             shard_dim = int(not shard_dim)

@@ -63,6 +63,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     flashinfer_trtllm_fp4_moe,
     flashinfer_trtllm_fp4_routed_moe,
 )
+from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
+    flashinfer_trtllm_mxint4_moe,
+    is_flashinfer_mxint4_moe_available,
+    prepare_static_weights_for_trtllm_mxint4_moe,
+)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     apply_fi_trtllm_fp8_per_tensor_moe,
 )
@@ -1247,8 +1252,89 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         self.actorder = weight_quant.actorder

         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
         self.use_marlin = True

         self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
+        self.use_flashinfer_mxint4_moe = (
+            is_flashinfer_mxint4_moe_available()
+            and self.group_size == 32
+            and weight_quant.num_bits == 4
+        )
+        self.kernel_backend = (
+            "Flashinfer" if self.use_flashinfer_mxint4_moe else "Marlin"
+        )
+        logger.info_once(
+            f"Using {self.kernel_backend} backend for WNA16 MoE "
+            f"(group_size={self.group_size}, num_bits={self.num_bits})",
+            scope="local",
+        )
+
+    def get_weight_shape(
+        self,
+        weight_name: str,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        num_groups_w2: int | None = None,
+        num_groups_w13: int | None = None,
+    ) -> tuple[int, int, int]:
+        """
+        Get the shape of the weight based on the weight name, number of experts,
+        hidden size, intermediate size per partition, number of groups for w2,
+        and number of groups for w13. Pass in num_groups_w2 and num_groups_w13
+        for weight scales.
+        """
+        if weight_name == "w13_scale":
+            assert num_groups_w13 is not None, (
+                "num_groups_w13 must be provided for weight scales"
+            )
+        if weight_name == "w2_scale":
+            assert num_groups_w2 is not None, (
+                "num_groups_w2 must be provided for weight scales"
+            )
+        w13_num_shards = 2 if self.moe.is_act_and_mul else 1
+        shape_map = {
+            "w13_weight": {
+                "Flashinfer": (
+                    num_experts,
+                    w13_num_shards * intermediate_size_per_partition,
+                    hidden_size // self.packed_factor,
+                ),
+                "Marlin": (
+                    num_experts,
+                    hidden_size // self.packed_factor,
+                    w13_num_shards * intermediate_size_per_partition,
+                ),
+            },
+            "w13_scale": {
+                "Flashinfer": (
+                    num_experts,
+                    w13_num_shards * intermediate_size_per_partition,
+                    num_groups_w13,
+                ),
+                "Marlin": (
+                    num_experts,
+                    num_groups_w13,
+                    w13_num_shards * intermediate_size_per_partition,
+                ),
+            },
+            "w2_weight": {
+                "Flashinfer": (
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition // self.packed_factor,
+                ),
+                "Marlin": (
+                    num_experts,
+                    intermediate_size_per_partition // self.packed_factor,
+                    hidden_size,
+                ),
+            },
+            "w2_scale": {
+                "Flashinfer": (num_experts, hidden_size, num_groups_w2),
+                "Marlin": (num_experts, num_groups_w2, hidden_size),
+            },
+        }
+        return shape_map[weight_name][self.kernel_backend]

     def create_weights(
         self,
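To make the backend-dependent shapes concrete, here is a worked example with illustrative sizes (not taken from the commit): num_experts=8, hidden_size=4096, intermediate_size_per_partition=1024, group_size=32 (so packed_factor=8 int4 values per int32, num_groups_w13 = 4096/32 = 128, num_groups_w2 = 1024/32 = 32), and a gated activation (w13_num_shards=2):

# Illustrative shapes returned by get_weight_shape under the sizes above:
# w13_weight  Flashinfer: (8, 2*1024, 4096//8) == (8, 2048, 512)
#             Marlin:     (8, 4096//8, 2*1024) == (8, 512, 2048)
# w13_scale   Flashinfer: (8, 2048, 128)    Marlin: (8, 128, 2048)
# w2_weight   Flashinfer: (8, 4096, 1024//8) == (8, 4096, 128)
#             Marlin:     (8, 1024//8, 4096) == (8, 128, 4096)
# w2_scale    Flashinfer: (8, 4096, 32)     Marlin: (8, 32, 4096)

That is, the FlashInfer layout keeps [experts, rows, packed columns] order while Marlin stores the transposed orientation, which is why create_weights below disables is_transposed for the FlashInfer backend.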
@@ -1260,19 +1346,23 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         **extra_weight_attrs,
     ):
         intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
         w13_num_shards = 2 if self.moe.is_act_and_mul else 1

         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
+        is_transposed = self.kernel_backend != "Flashinfer"
         extra_weight_attrs.update(
-            {"is_transposed": True, "quant_method": self.strategy}
+            {"is_transposed": is_transposed, "quant_method": self.strategy}
         )

         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
-                hidden_size // self.packed_factor,
-                w13_num_shards * intermediate_size_per_partition,
+                *self.get_weight_shape(
+                    "w13_weight",
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition,
+                ),
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -1282,9 +1372,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):

         w2_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
-                intermediate_size_per_partition // self.packed_factor,
-                hidden_size,
+                *self.get_weight_shape(
+                    "w2_weight",
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition,
+                ),
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -1315,9 +1408,13 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):

         w13_scale = torch.nn.Parameter(
             torch.ones(
-                num_experts,
-                num_groups_w13,
-                w13_num_shards * intermediate_size_per_partition,
+                *self.get_weight_shape(
+                    "w13_scale",
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition,
+                    num_groups_w13=num_groups_w13,
+                ),
                 dtype=params_dtype,
             ),
             requires_grad=False,
@@ -1326,7 +1423,16 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w13_scale, extra_weight_attrs)

         w2_scale = torch.nn.Parameter(
-            torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
+            torch.ones(
+                *self.get_weight_shape(
+                    "w2_scale",
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition,
+                    num_groups_w2=num_groups_w2,
+                ),
+                dtype=params_dtype,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight_scale", w2_scale)
@@ -1396,6 +1502,27 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         num_experts = layer.w13_weight_g_idx.shape[0]
         device = layer.w13_weight_g_idx.device
+        if self.kernel_backend == "Flashinfer":
+            dict_weights_mxint4 = prepare_static_weights_for_trtllm_mxint4_moe(
+                layer.w13_weight_packed,
+                layer.w13_weight_scale,
+                layer.w2_weight_packed,
+                layer.w2_weight_scale,
+            )
+            replace_parameter(
+                layer, "w13_weight_packed", dict_weights_mxint4["gemm1_weights"]
+            )
+            replace_parameter(
+                layer, "w13_weight_scale", dict_weights_mxint4["gemm1_scales"]
+            )
+            replace_parameter(
+                layer, "w2_weight_packed", dict_weights_mxint4["gemm2_weights"]
+            )
+            replace_parameter(
+                layer, "w2_weight_scale", dict_weights_mxint4["gemm2_scales"]
+            )
+            return None
+
         is_a_8bit = (
             self.marlin_input_dtype is not None
             and self.marlin_input_dtype.itemsize == 1
@@ -1560,6 +1687,35 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
             is_k_full=self.is_k_full,
         )

+    @property
+    def is_monolithic(self) -> bool:
+        return self.kernel_backend == "Flashinfer"
+
+    def apply_monolithic(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        assert self.kernel_backend == "Flashinfer"
+        return flashinfer_trtllm_mxint4_moe(
+            x=x,
+            router_logits=router_logits,
+            w13_weight_packed=layer.w13_weight_packed,
+            w13_weight_scale=layer.w13_weight_scale,
+            w2_weight_packed=layer.w2_weight_packed,
+            w2_weight_scale=layer.w2_weight_scale,
+            global_num_experts=layer.global_num_experts,
+            top_k=layer.top_k,
+            intermediate_size_per_partition=layer.intermediate_size_per_partition,
+            local_num_experts=layer.local_num_experts,
+            ep_rank=layer.ep_rank,
+            num_expert_group=layer.num_expert_group,
+            topk_group=layer.topk_group,
+            e_score_correction_bias=layer.e_score_correction_bias,
+            routing_method_type=layer.routing_method_type,
+        )
+
     def apply(
         self,
         layer: FusedMoE,
@@ -1567,6 +1723,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        assert self.kernel_backend == "Marlin"
         return fused_marlin_moe(
             x,
             layer.w13_weight_packed,

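The is_monolithic property added above tells the FusedMoE layer to skip its separate routing step and hand raw router logits to apply_monolithic, since TRT-LLM's fused kernel performs routing internally; the Marlin path keeps the conventional apply signature with precomputed topk_weights/topk_ids. A hypothetical sketch of the dispatch this division implies (not the actual FusedMoE.forward; select_experts stands in for vLLM's routing step):

def moe_forward(layer, x, router_logits):
    # Hypothetical dispatch sketch based on the methods added above.
    if layer.quant_method.is_monolithic:
        # FlashInfer backend: routing runs inside the fused TRT-LLM kernel.
        return layer.quant_method.apply_monolithic(layer, x, router_logits)
    # Marlin backend: route first, then run the expert GEMMs.
    topk_weights, topk_ids = layer.select_experts(x, router_logits)
    return layer.quant_method.apply(
        layer, x, topk_weights=topk_weights, topk_ids=topk_ids
    )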
@@ -0,0 +1,266 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility helpers for the MxInt4 + FlashInfer fused-MoE path."""

import functools

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe

__all__ = [
    "prepare_static_weights_for_trtllm_mxint4_moe",
    "flashinfer_trtllm_mxint4_moe",
    "is_flashinfer_mxint4_moe_available",
]

logger = init_logger(__name__)


@functools.cache
def is_flashinfer_mxint4_moe_available() -> bool:
    """Return `True` when the FlashInfer MxInt4 kernels can be used."""
    return (
        envs.VLLM_USE_FLASHINFER_MOE_INT4
        and has_flashinfer_trtllm_fused_moe()
        and current_platform.is_cuda()
        and current_platform.is_device_capability_family(100)
    )


def prepare_static_weights_for_trtllm_mxint4_moe(
    gemm1_weights: torch.Tensor,
    gemm1_scales: torch.Tensor,
    gemm2_weights: torch.Tensor,
    gemm2_scales: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """
    Prepare MxInt4 weights for the TRT-LLM kernel.

    Input:
        gemm1_weights: [num_experts, 2*intermediate_size, hidden_size//8] int32
            (checkpoint uint4b8 packed) or uint8 (already packed signed int4)
        gemm1_scales: [num_experts, 2*intermediate_size, hidden_size//32] bf16
        gemm2_weights: [num_experts, hidden_size, intermediate_size//8] int32
            (checkpoint uint4b8 packed) or uint8 (already packed signed int4)
        gemm2_scales: [num_experts, hidden_size, intermediate_size//32] bf16

    Returns:
        Dict with keys 'gemm1_weights', 'gemm1_scales', 'gemm2_weights',
        'gemm2_scales' containing shuffled/packed tensors ready for the kernel.
    """
    from flashinfer import block_scale_interleave
    from flashinfer.fused_moe import (
        convert_to_block_layout,
    )
    from flashinfer.fused_moe.core import (
        _maybe_get_cached_w3_w1_permute_indices,
        get_w2_permute_indices_with_cache,
    )

    from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
        reorder_w1w3_to_w3w1,
    )
    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        convert_packed_uint4b8_to_signed_int4_inplace,
    )

    device = gemm1_weights.device
    assert gemm1_weights.ndim == 3, (
        f"Expected a 3D gemm1_weights tensor, got {gemm1_weights.shape}"
    )
    assert gemm1_scales.ndim == 3, (
        f"Expected a 3D gemm1_scales tensor, got {gemm1_scales.shape}"
    )
    assert gemm2_weights.ndim == 3, (
        f"Expected a 3D gemm2_weights tensor, got {gemm2_weights.shape}"
    )
    assert gemm2_scales.ndim == 3, (
        f"Expected a 3D gemm2_scales tensor, got {gemm2_scales.shape}"
    )

    # Convert checkpoint format (uint4b8 in int32) to signed int4.
    # Checkpoint stores INT4 as unsigned [0, 15]; kernel expects signed [-8, 7].
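    # Illustrative example (assumed uint4b8 semantics, value = nibble - 8):
    # a stored nibble 0x0 decodes to -8, 0x8 to 0, and 0xF to 7, so the
    # in-place conversion rebiases every packed nibble.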
    if gemm1_weights.dtype == torch.int32 and gemm2_weights.dtype == torch.int32:
        convert_packed_uint4b8_to_signed_int4_inplace(gemm1_weights)
        convert_packed_uint4b8_to_signed_int4_inplace(gemm2_weights)

    gemm1_weights, gemm1_scales = reorder_w1w3_to_w3w1(
        gemm1_weights, gemm1_scales, dim=-2
    )

    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    num_experts = gemm1_weights.shape[0]

    # Convert quantized weights to the proper formats.
    gemm1_weights_mxint4 = gemm1_weights.view(torch.uint8)
    assert gemm1_scales.dtype == torch.bfloat16
    gemm2_weights_mxint4 = gemm2_weights.view(torch.uint8)
    assert gemm2_scales.dtype == torch.bfloat16

    epilogue_tile_m = 128
    gemm1_weights_mxint4_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_mxint4_shuffled = []
    gemm2_scales_shuffled = []

    for i in range(num_experts):
        # Calculate the permute indices for the following:
        # 1. Reorder rows of W1 and scales for fused gated activation
        # 2. Shuffle weights and scaling factors for transposed mma output
        #    for both w3_w1 and w2 weights and scale factors
        permute_indices = _maybe_get_cached_w3_w1_permute_indices(
            _cache_permute_indices,
            gemm1_weights_mxint4[i],
            epilogue_tile_m,
        )
        gemm1_weights_shuffled = gemm1_weights_mxint4[i][
            permute_indices.to(gemm1_weights.device)
        ].contiguous()
        permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
            _cache_permute_indices,
            gemm1_scales[i],
            epilogue_tile_m,
            num_elts_per_sf=32,
        ).to(device)
        gemm1_scales_shuffled.append(
            block_scale_interleave(gemm1_scales[i][permute_sf_indices].contiguous())
        )

        permute_indices = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            gemm2_weights_mxint4[i],
            epilogue_tile_m,
        )
        gemm2_weights_shuffled = gemm2_weights_mxint4[i][
            permute_indices.to(gemm2_weights.device)
        ].contiguous()

        permute_sf_indices = get_w2_permute_indices_with_cache(
            _cache_permute_indices,
            gemm2_scales[i],
            epilogue_tile_m,
            num_elts_per_sf=16,
        )
        gemm2_scales_shuffled.append(
            block_scale_interleave(
                gemm2_scales[i][permute_sf_indices.to(gemm2_scales.device)].contiguous()
            )
        )

        block_k = 128
        gemm1_weights_shuffled = convert_to_block_layout(
            gemm1_weights_shuffled.view(torch.uint8), block_k
        )
        gemm2_weights_shuffled = convert_to_block_layout(
            gemm2_weights_shuffled.view(torch.uint8), block_k
        )

        gemm1_weights_mxint4_shuffled.append(gemm1_weights_shuffled)
        gemm2_weights_mxint4_shuffled.append(gemm2_weights_shuffled)

    gemm1_weights_mxint4_shuffled = torch.stack(gemm1_weights_mxint4_shuffled)
    gemm2_weights_mxint4_shuffled = torch.stack(gemm2_weights_mxint4_shuffled)
    gemm1_scales_shuffled = torch.stack(gemm1_scales_shuffled).view(torch.bfloat16)
    gemm2_scales_shuffled = torch.stack(gemm2_scales_shuffled).view(torch.bfloat16)
    return {
        "gemm1_weights": gemm1_weights_mxint4_shuffled,
        "gemm1_scales": gemm1_scales_shuffled,
        "gemm2_weights": gemm2_weights_mxint4_shuffled,
        "gemm2_scales": gemm2_scales_shuffled,
    }

def flashinfer_trtllm_mxint4_moe(
    x: torch.Tensor,
    router_logits: torch.Tensor,
    w13_weight_packed: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_packed: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    global_num_experts: int,
    top_k: int,
    intermediate_size_per_partition: int,
    local_num_experts: int,
    ep_rank: int = 0,
    num_expert_group: int | None = None,
    topk_group: int | None = None,
    e_score_correction_bias: torch.Tensor | None = None,
    routing_method_type: int | None = None,
) -> torch.Tensor:
    """
    Apply the FlashInfer TensorRT-LLM MxInt4 MoE kernel.

    Args:
        x: Input hidden states. dtype: bfloat16
        router_logits: Router logits for expert selection. dtype: bfloat16/float32
        w13_weight_packed: Packed gate+up weights. dtype: uint8
        w13_weight_scale: Scales for gate+up weights. dtype: bfloat16
        w2_weight_packed: Packed down weights. dtype: uint8
        w2_weight_scale: Scales for down weights. dtype: bfloat16
        global_num_experts: Total number of experts across all ranks
        top_k: Number of experts to select per token
        intermediate_size_per_partition: Intermediate size per partition
        local_num_experts: Number of experts on this rank
        ep_rank: Expert parallelism rank (default: 0)
        num_expert_group: Number of expert groups (default: None -> 0)
        topk_group: Top-k within groups (default: None -> 0)
        e_score_correction_bias: Optional routing bias. dtype: bfloat16
        routing_method_type: FlashInfer RoutingMethodType enum value

    Returns:
        Output tensor from the MoE layer. dtype: same as x (bfloat16)
    """
    from flashinfer import RoutingMethodType
    from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe

    assert x.dtype == torch.bfloat16, f"x dtype must be bfloat16, got {x.dtype}"
    assert w13_weight_packed.dtype == torch.uint8, (
        f"w13_weight_packed dtype must be uint8, got {w13_weight_packed.dtype}"
    )
    assert w13_weight_scale.dtype == torch.bfloat16, (
        f"w13_weight_scale dtype must be bfloat16, got {w13_weight_scale.dtype}"
    )
    assert w2_weight_packed.dtype == torch.uint8, (
        f"w2_weight_packed dtype must be uint8, got {w2_weight_packed.dtype}"
    )
    assert w2_weight_scale.dtype == torch.bfloat16, (
        f"w2_weight_scale dtype must be bfloat16, got {w2_weight_scale.dtype}"
    )

    routing_bias = None
    if e_score_correction_bias is not None:
        routing_bias = e_score_correction_bias.to(torch.bfloat16)

    if routing_method_type == RoutingMethodType.DeepSeekV3:
        router_logits = router_logits.to(torch.float32)

    out = trtllm_mxint4_block_scale_moe(
        routing_logits=router_logits,
        routing_bias=routing_bias,
        hidden_states=x,
        gemm1_weights=w13_weight_packed.data,
        gemm1_weights_scale=w13_weight_scale.data,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=w2_weight_packed.data,
        gemm2_weights_scale=w2_weight_scale.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group if num_expert_group is not None else 0,
        topk_group=topk_group if topk_group is not None else 0,
        intermediate_size=intermediate_size_per_partition,
        local_expert_offset=ep_rank * local_num_experts,
        local_num_experts=local_num_experts,
        routed_scaling_factor=None,
        routing_method_type=routing_method_type,
        enable_pdl=None,
        output=None,
        tune_max_num_tokens=8192,
    ).to(x.dtype)

    return out
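For reference, a minimal standalone sketch of how the two helpers chain together (all sizes and tensor contents are placeholders; inside vLLM this is driven by process_weights_after_loading and apply_monolithic, and real callers pass a FlashInfer RoutingMethodType for routing_method_type):

import torch

if is_flashinfer_mxint4_moe_available():
    E, H, I = 8, 4096, 1024  # placeholder expert/hidden/intermediate sizes
    dev = "cuda"
    # Placeholder uint4b8-packed weights and bf16 group scales (group_size=32).
    w13 = torch.zeros(E, 2 * I, H // 8, dtype=torch.int32, device=dev)
    s13 = torch.ones(E, 2 * I, H // 32, dtype=torch.bfloat16, device=dev)
    w2 = torch.zeros(E, H, I // 8, dtype=torch.int32, device=dev)
    s2 = torch.ones(E, H, I // 32, dtype=torch.bfloat16, device=dev)
    packed = prepare_static_weights_for_trtllm_mxint4_moe(w13, s13, w2, s2)

    x = torch.randn(16, H, dtype=torch.bfloat16, device=dev)
    logits = torch.randn(16, E, dtype=torch.bfloat16, device=dev)
    out = flashinfer_trtllm_mxint4_moe(
        x,
        logits,
        packed["gemm1_weights"],
        packed["gemm1_scales"],
        packed["gemm2_weights"],
        packed["gemm2_scales"],
        global_num_experts=E,
        top_k=2,
        intermediate_size_per_partition=I,
        local_num_experts=E,
    )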