[Refactor] Refactor MOE NVFP4 Code Base: ModelOpt + Compressed Tensor (#21631)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-07-27 08:25:21 -04:00
committed by GitHub
parent 3d847a3125
commit bda9d0535f
6 changed files with 75 additions and 106 deletions

View File

@@ -8,8 +8,10 @@ from typing import ClassVar, NamedTuple, Optional
import numpy
import torch
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -592,3 +594,56 @@ def awq_pack(
q_w = q_w.reshape((-1, size_n)).contiguous()
return pack_cols(q_w, num_bits, size_k, size_n)
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
"""
Pad and block-interleave the FP4 block-scales so that they match the data
layout expected by the CUTLASS / FlashInfer kernels.
Parameters
----------
scale: torch.Tensor
Returns
-------
torch.Tensor
The swizzled tensor with the same logical shape as *scale*.
"""
assert scale.dtype == torch.float8_e4m3fn, (
"swizzle_blockscale expects the input tensor to be in "
"torch.float8_e4m3fn format.")
scale_ndim = scale.ndim
if scale_ndim == 2:
scale = scale.unsqueeze(0) # (1, M, K)
assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
B, M, K = scale.shape
def _round_up(x: int, m: int) -> int:
return (x + m - 1) // m * m
M_padded = _round_up(M, 128)
K_padded = _round_up(K, 4)
padded = torch.zeros((B, M_padded, K_padded),
dtype=scale.dtype,
device=scale.device)
padded[:B, :M, :K] = scale
# Reshape / permute to the layout required by the kernel.
padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
if scale_ndim == 2:
return swizzled.reshape(M, K)
return swizzled.reshape(B, M, K)
def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)