[Refactor] Refactor MOE NVFP4 Code Base: ModelOpt + Compressed Tensor (#21631)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -8,8 +8,10 @@ from typing import ClassVar, NamedTuple, Optional
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||
from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
|
||||
@@ -592,3 +594,56 @@ def awq_pack(
|
||||
q_w = q_w.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return pack_cols(q_w, num_bits, size_k, size_n)
|
||||
|
||||
|
||||
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Pad and block-interleave the FP4 block-scales so that they match the data
|
||||
layout expected by the CUTLASS / FlashInfer kernels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scale: torch.Tensor
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The swizzled tensor with the same logical shape as *scale*.
|
||||
"""
|
||||
assert scale.dtype == torch.float8_e4m3fn, (
|
||||
"swizzle_blockscale expects the input tensor to be in "
|
||||
"torch.float8_e4m3fn format.")
|
||||
|
||||
scale_ndim = scale.ndim
|
||||
if scale_ndim == 2:
|
||||
scale = scale.unsqueeze(0) # (1, M, K)
|
||||
assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
|
||||
|
||||
B, M, K = scale.shape
|
||||
|
||||
def _round_up(x: int, m: int) -> int:
|
||||
return (x + m - 1) // m * m
|
||||
|
||||
M_padded = _round_up(M, 128)
|
||||
K_padded = _round_up(K, 4)
|
||||
|
||||
padded = torch.zeros((B, M_padded, K_padded),
|
||||
dtype=scale.dtype,
|
||||
device=scale.device)
|
||||
padded[:B, :M, :K] = scale
|
||||
|
||||
# Reshape / permute to the layout required by the kernel.
|
||||
padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
|
||||
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
|
||||
|
||||
if scale_ndim == 2:
|
||||
return swizzled.reshape(M, K)
|
||||
return swizzled.reshape(B, M, K)
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
Reference in New Issue
Block a user