[Hardware/NVIDIA/Kernel] Enable nvidia/DeepSeek-R1-FP4 Model (#16362)

This commit is contained in:
Pavani Majety
2025-05-09 16:24:41 -07:00
committed by GitHub
parent 3b602cdea7
commit 0c0fdae84f
16 changed files with 1994 additions and 112 deletions

View File

@@ -745,10 +745,11 @@ def get_cutlass_moe_mm_data(
- output_permutation: Permutation that must be used to shuffle the output
after executing the MMs.
"""
torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
problem_sizes1, problem_sizes2,
input_permutation, output_permutation,
num_experts, n, k)
return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
problem_sizes1, problem_sizes2,
input_permutation,
output_permutation,
num_experts, n, k)
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
@@ -767,9 +768,41 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
MMs used in the fused MoE operation.
- a/b/c_strides: The data strides passed to grouped matrix multiplication.
"""
torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales,
b_scales, expert_offsets, problem_sizes,
a_strides, b_strides, c_strides)
return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors,
a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides,
c_strides)
def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
a_scales: torch.Tensor, b_scales: torch.Tensor,
alphas: torch.Tensor, problem_sizes: torch.Tensor,
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
out_dtype: torch.dtype, device: torch.device):
"""
An FP4 block-scaled grouped GEMM that takes a_tensors and b_tensors and runs
the GEMMs for each expert based on the specified problem sizes.
This is used as the MoE GEMM during the NVFP4-quantized FusedMoE forward.
- a/b_tensors: the NVFP4-quantized input activations and expert weights.
- a_/b_scales: The block scales in FP8-E4M3 precision.
- expert_offsets/sf_offsets: Indices that mark at which token index
each expert begins its computation. The number of tokens
computed with expert E is expert_offsets[E + 1] -
expert_offsets[E], and the scale-factor count per expert is
sf_offsets[E + 1] - sf_offsets[E].
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation.
"""
m_topk = a_tensors.shape[0]
n = b_tensors.shape[1]
c_shape = (m_topk, n)
c = torch.empty(c_shape, device=device, dtype=out_dtype)
torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales,
b_scales, alphas, problem_sizes,
expert_offsets, sf_offsets)
return c.to(out_dtype)
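# Illustrative example of the offset convention (hypothetical values): with
# expert_offsets = [0, 3, 5], expert 0 computes rows 0..2 of a_tensors
# (3 tokens) and expert 1 computes rows 3..4 (2 tokens); sf_offsets
# delimits each expert's region of the block-scale tensors the same way.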
# aqlm
@@ -960,6 +993,57 @@ def scaled_fp4_quant(
return output, output_scale
def scaled_fp4_experts_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor,
topk: int,
expert_map: Optional[torch.Tensor] = None,
MAX_TOKENS_PER_EXPERT: int = 163840,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize the input tensor to FP4 and return the quantized tensor and block
scales, for packed MoE inputs.
Args:
input_tensor: The input tensor to be quantized to FP4
input_global_scale: A scalar scaling factor for the entire tensor.
expert_offsets: The expert offsets tensor
blockscale_offsets: The blockscale offsets tensor
topk: The number of experts selected per token
expert_map: Optional tensor used to gather the input rows per expert
Outputs:
output: The quantized tensor in packed FP4 (uint8)
output_scales: The blockscale tensor in FP8-E4M3
"""
assert not current_platform.is_rocm()
assert input_tensor.ndim == 2, (
f'input.ndim needs to be == 2, but got {input_tensor.ndim}.')
input_tensor = input_tensor[
expert_map] if expert_map is not None else input_tensor
m_numtopk, k = input_tensor.shape
assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
f"m_numtopk must not exceed MAX_TOKENS_PER_EXPERT * topk for the"
f" scaled_fp4_experts_quant kernel, observed m_numtopk = {m_numtopk}")
scales_k = k // 16
padded_k = (scales_k + (4 - 1)) // 4
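# output_scales below is allocated as int32, and each int32 word packs four
# FP8-E4M3 block scales, so padded_k rounds the per-row scale count up to a
# whole number of int32 words.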
# The output is uint8, with two FP4 values packed per byte.
output = torch.empty(m_numtopk,
k // 2,
device=input_tensor.device,
dtype=torch.uint8)
output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk,
padded_k,
dtype=torch.int32,
device=input_tensor.device)
torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor,
input_global_scale, expert_offsets,
blockscale_offsets)
output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales
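# Shape check with illustrative numbers: for k = 7168, scales_k = 448 and
# padded_k = 112 int32 words; output is (m_numtopk, 3584) packed uint8 and
# output_scales, viewed as float8_e4m3fn, carries 448 scales per row.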
# fp8
def scaled_fp8_quant(
input: torch.Tensor,

View File

@@ -36,7 +36,7 @@ if HAS_TRITON:
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
cutlass_moe_fp4, cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
@@ -48,4 +48,5 @@ if HAS_TRITON:
"get_config_file_name",
"grouped_topk",
"cutlass_moe_fp8",
"cutlass_moe_fp4",
]

View File

@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
""" CUTLASS based Fused MoE kernels."""
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.scalar_type import scalar_types
# TODO: make the grouped gemm kernel consistent with the scaled gemm kernel
@@ -178,3 +179,126 @@ def cutlass_moe_fp8(
if not apply_router_weight_on_input:
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
return c2.sum(dim=1)
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
MAX_TOKENS_PER_EXPERT = 65536
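# For reference: FLOAT4_E2M1_MAX is 6.0 and FLOAT8_E4M3_MAX is 448.0;
# MAX_TOKENS_PER_EXPERT bounds the number of input rows m accepted by
# cutlass_moe_fp4 (asserted below) and sizes the block-scale buffers.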
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
device: torch.device):
"""
MoE implementation for FP4 Inputs
# Gemm 1
a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32)
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
(Note: `n` is the up projection output dim, `k` is the input dim in
full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4)
# Gemm 2
a2_gscale: Activation scale per expert: [e]
w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
topk_weights: [m, topk] dtype: float32
topk_ids: [m, topk] dtype: int
m, n, k: Unquantized weight shapes, dtype: int
e: number of experts, dtype: int
Assumes that topk < k < n, as required by the up/down projections.
"""
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3
and w2_blockscale.ndim
== 3), ("All Weights must be of rank 3 for cutlass_moe_fp4")
m_a, k_a = a.shape
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert (e_w1 == e_w2 and e_w1 == e), (
"Number of experts must match between weights.")
assert (k_a // 2 == half_k_w1
and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
"expected `n`")
assert (m == m_a), "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (topk_weights.shape[0] == m and topk_ids.shape[0]
== m), ("topk must be provided for each row of a")
assert (m <= MAX_TOKENS_PER_EXPERT), (
f"m must not exceed MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m = {m}")
out_dtype = a.dtype
num_topk = topk_ids.shape[1]
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,2n,k))
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,n,k))
problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, e, n, k)
tokens_per_expert = problem_sizes1[:, 0]
rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
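# Each expert's start offset into the block-scale tensor is rounded up to a
# multiple of 128 rows, matching the 128-row tiling assumed by the swizzled
# block-scale layout (see swizzle_blockscale in the ModelOpt NVFP4 method).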
rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
a,
a1_gscale,
expert_offsets,
blockscale_offsets,
num_topk,
expert_map=a_map,
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1,
expert_offsets[:-1], blockscale_offsets[:-1],
out_dtype, device)
del rep_a_fp4, rep_a_blockscale
# silu_and_mul fuses the gate and up halves, halving the hidden dimension.
intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
device=device,
dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
intermediate,
a2_gscale,
expert_offsets,
blockscale_offsets,
num_topk,
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
w2_alphas, problem_sizes2, expert_offsets[:-1],
blockscale_offsets[:-1], out_dtype, device)
del int_fp4, int_blockscale
out = (c2[c_map].view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
return out.to(dtype=out_dtype)
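# Shape flow (illustrative): a [m, k] -> rep_a_fp4 [m * topk, k // 2] ->
# c1 [m * topk, 2 * n] -> silu_and_mul -> [m * topk, n] -> c2 [m * topk, k]
# -> gathered via c_map, weighted by topk_weights and reduced to out [m, k].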

View File

@@ -643,7 +643,7 @@ class FusedMoE(torch.nn.Module):
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return
quant_method_name = self.quant_method.__class__.__name__
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
@@ -697,8 +697,9 @@ class FusedMoE(torch.nn.Module):
# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
if ("compressed" in quant_method_name.lower()
and param.data[expert_id] != 1
and (param.data[expert_id] - loaded_weight).abs() > 1e-5):
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
@@ -718,6 +719,22 @@ class FusedMoE(torch.nn.Module):
tp_rank=self.tp_rank)
return
if "ModelOpt" in quant_method_name:
if ('weight_scale_2' in weight_name
or 'input_scale' in weight_name):
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
elif "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return
# Case weight scales, zero_points and offset
if ("scale" in weight_name or "zero" in weight_name
or "offset" in weight_name):

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch.nn import Module
@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -210,25 +212,37 @@ class ModelOptNvFp4Config(QuantizationConfig):
"`hf_quant_config.json` file for your model's "
"quant configuration.")
is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
group_size = quant_config["group_size"]
exclude_modules = quant_config["exclude_modules"]
if not (group_size and kv_cache_quant_algo and exclude_modules):
required_keys = ("group_size", "kv_cache_quant_algo", "exclude_modules")
if not all(key in quant_config for key in required_keys):
raise ValueError("NVFP4 quantization requires group_size, "
"kv_cache_quant_algo and exclude_modules to be "
"specified in hf_quant_config.json")
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
group_size = quant_config["group_size"]
exclude_modules = quant_config["exclude_modules"]
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)
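# A minimal sketch of the fields consumed above (illustrative values only):
# the quant method string contains "NVFP4", and quant_config provides e.g.
# {"kv_cache_quant_algo": "FP8", "group_size": 16,
#  "exclude_modules": ["lm_head"]}.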
def is_layer_excluded(self, prefix: str, exclude_modules: List):
import re
for pattern in exclude_modules:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
if re.fullmatch(regex_str, prefix):
return True
return False
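# For example, a hypothetical exclude pattern "model.layers.*.mlp" becomes
# the regex r"model\.layers\..*\.mlp" and excludes any layer prefix that it
# fully matches.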
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules):
if (is_layer_skipped(prefix, self.exclude_modules)
or self.is_layer_excluded(prefix, self.exclude_modules)):
return UnquantizedLinearMethod()
return ModelOptNvFp4LinearMethod(self)
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoE(self)
return None
@@ -409,3 +423,235 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
if bias is not None:
out = out + bias
return out.view(*output_shape)
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"""
MoE Method for FP4 Quantization.
Args:
quant_config: NVFP4 Quant Config
"""
def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError("NVFP4 quantization was selected, but "
"dynamic quantization is not supported.")
layer.quant_config = self.quant_config
weight_dtype = torch.uint8
weight_scale_dtype = torch.float8_e4m3fn
weight_loader = extra_weight_attrs.get("weight_loader")
# GEMM 1
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
dtype=weight_dtype),
input_dim=1,
output_dim=2,
weight_loader=weight_loader)
layer.register_parameter("w13_weight", w13_weight)
# GEMM 2
w2_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
dtype=weight_dtype),
input_dim=1,
output_dim=2,
weight_loader=weight_loader)
layer.register_parameter("w2_weight", w2_weight)
w13_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.quant_config.group_size,
dtype=weight_scale_dtype),
input_dim=1,
output_dim=2,
weight_loader=weight_loader)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition //
self.quant_config.group_size,
dtype=weight_scale_dtype),
input_dim=1,
output_dim=2,
weight_loader=weight_loader)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
w13_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(num_experts, 2, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
w2_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(num_experts, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
w13_input_scale = PerTensorScaleParameter(data=torch.empty(
num_experts, 2, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = PerTensorScaleParameter(data=torch.empty(
num_experts, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("w2_input_scale", w2_input_scale)
def swizzle_blockscale(self, scale: torch.Tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
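# Worked example (illustrative): a (1, 256, 8) scale tensor needs no padding,
# is reshaped to (1, 2, 4, 32, 2, 4), permuted to (1, 2, 2, 32, 4, 4) and
# returned as (1, 256, 8) with rows interleaved in the 128x4 swizzled order.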
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1
assert torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
"Expected w1_weight_scale_2 to equal w3_weight_scale_2")
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
requires_grad=False)
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
torch.float32)
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False)
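# g1_alphas folds the per-expert input and weight global scales into the
# alpha applied inside the grouped GEMM, dequantizing the FP4 GEMM output.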
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
"Expected weight_scale.shape[2] to be divisible by 16")
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Blockscale must be represented as FP8-E4M3")
w13_blockscale_swizzled = self.swizzle_blockscale(
layer.w13_weight_scale)
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
requires_grad=False)
# This is for quantization, so we need to invert it.
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
# GEMM 2
layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False)
# This is for quantization, so we need to invert it.
layer.w2_input_scale_quant = Parameter(
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
"Expected weight_scale.shape[2] to be divisible by 16")
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Blockscale must be represented as FP8-E4M3")
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
requires_grad=False)
return
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
assert activation == "silu", "Only SiLU activation is supported."
assert not apply_router_weight_on_input, (
"Router weight on input is not "
"supported for ModelOptNvFp4FusedMoE.")
assert expert_map is None, ("Expert parallelism (expert_map) "
"is currently not supported for "
"ModelOptNvFp4FusedMoE.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(a=x,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
w2_fp4=layer.w2_weight,
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
device=x.device).to(x.dtype)
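# Hedged end-to-end usage sketch (illustrative; assumes the checkpoint ships
# an NVFP4 hf_quant_config.json so FusedMoE layers pick up
# ModelOptNvFp4FusedMoE automatically):
#     from vllm import LLM
#     llm = LLM(model="nvidia/DeepSeek-R1-FP4", tensor_parallel_size=8)
#     print(llm.generate("Hello")[0].outputs[0].text)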