135 lines
4.2 KiB
Python
135 lines
4.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from enum import Enum
|
|
|
|
import torch
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.utils.torch_utils import direct_register_custom_op
|
|
|
|
# Module-level logger named after this module; it is not referenced by any
# other code in the visible portion of this file.
logger = init_logger(__name__)
|
|
|
|
|
|
class Mxfp8LinearBackend(Enum):
    """Backends that can be selected for an MXFP8 linear layer."""

    # The only backend declared here. Mxfp8LinearOp.apply serves it by
    # dequantizing the weight to BF16 and calling torch.nn.functional.linear.
    EMULATION = "emulation"
|
|
|
|
|
|
# MXFP8 constants
# Dtype of the quantized values (FP8 E4M3).
MXFP8_VALUE_DTYPE = torch.float8_e4m3fn
# Storage dtype of the per-block scales. dequant_mxfp8_to_bf16 decodes a
# stored scale s as the multiplier 2 ** (s - 127).
MXFP8_SCALE_DTYPE = torch.uint8
# Number of consecutive last-dimension elements that share one scale.
MXFP8_BLOCK_SIZE = 32
|
|
|
|
|
|
def _mxfp8_e4m3_quantize_impl(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MXFP8 by delegating to flashinfer's ``mxfp8_quantize``.

    Args:
        x: Tensor to quantize.
        is_sf_swizzled_layout: Forwarded unchanged to flashinfer to select
            the scale-factor layout.

    Returns:
        ``(quantized, scales)`` as produced by flashinfer. When the scales
        come back 1-D for a 2-D ``x`` in the non-swizzled layout, they are
        viewed as ``(x.size(0), -1)`` so there is one row of scales per row
        of ``x``.
    """
    # Function-scope import: flashinfer is only needed once this op runs
    # (presumably so importing this module does not require it -- confirm).
    from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize

    quantized, scale_factors = flashinfer_mxfp8_quantize(
        x, is_sf_swizzled_layout=is_sf_swizzled_layout
    )

    # NOTE(review): only 2-D inputs are reshaped here, while
    # mxfp8_e4m3_quantize_fake reports (B, M, K) scales for 3-D non-swizzled
    # inputs -- confirm what flashinfer returns for 3-D inputs.
    wants_row_major_scales = (
        not is_sf_swizzled_layout and x.ndim == 2 and scale_factors.ndim == 1
    )
    if not wants_row_major_scales:
        return quantized, scale_factors

    return quantized, scale_factors.view(x.size(0), -1)
|
|
|
|
|
|
def mxfp8_e4m3_quantize(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MXFP8 through the ``vllm.mxfp8_quantize`` custom op.

    Args:
        x: Tensor to quantize.
        is_sf_swizzled_layout: Passed positionally to the custom op; selects
            the scale-factor layout.

    Returns:
        The ``(quantized, scales)`` pair returned by the custom op.
    """
    # The op is looked up at call time, so it only has to be registered
    # before the first call, not before this function is defined.
    quantize_op = torch.ops.vllm.mxfp8_quantize
    return quantize_op(x, is_sf_swizzled_layout)
|
|
|
|
|
|
def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Dequantize MXFP8 tensor to BF16.

    The last dimension of ``x`` is split into blocks of ``MXFP8_BLOCK_SIZE``
    elements. Every block is multiplied by ``2 ** (scale - 127)``, where
    ``scale`` is the matching entry of ``scales``.

    Args:
        x: Quantized values; the last dimension must be a multiple of
            ``MXFP8_BLOCK_SIZE``.
        scales: Per-block scales whose last dimension equals
            ``x.shape[-1] // MXFP8_BLOCK_SIZE``. Leading dimensions must
            broadcast against those of ``x``.

    Returns:
        A ``torch.bfloat16`` tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``x`` has no dimensions, its last dimension is not a
            multiple of ``MXFP8_BLOCK_SIZE``, or ``scales`` does not carry
            exactly one entry per block.
    """
    # Fail early with a readable message instead of an opaque view/broadcast
    # error further down.
    if x.ndim == 0 or x.shape[-1] % MXFP8_BLOCK_SIZE != 0:
        raise ValueError(
            "MXFP8 dequantization requires the last dimension of x to be a "
            f"multiple of {MXFP8_BLOCK_SIZE}, got shape {tuple(x.shape)}."
        )

    num_blocks = x.shape[-1] // MXFP8_BLOCK_SIZE

    # A trailing dimension of 1 would broadcast silently and apply a single
    # scale to every block, so the block count is checked explicitly.
    if scales.ndim == 0 or scales.shape[-1] != num_blocks:
        raise ValueError(
            f"MXFP8 dequantization expects {num_blocks} scale(s) along the "
            f"last dimension of scales, got shape {tuple(scales.shape)}."
        )

    # Do the arithmetic in float32 and cast to BF16 only at the end.
    x_blocked = x.to(torch.float32).view(
        *x.shape[:-1], num_blocks, MXFP8_BLOCK_SIZE
    )

    # Stored scales are biased exponents: multiplier = 2 ** (scale - 127).
    descale = torch.exp2(scales.to(torch.float32) - 127.0)

    # unsqueeze(-1) broadcasts each scale across the elements of its block.
    dequantized = x_blocked * descale.unsqueeze(-1)

    # reshape (not view) so merging the block dimensions back does not depend
    # on the memory layout of the intermediate result.
    return dequantized.reshape(x.shape).to(torch.bfloat16)
|
|
|
|
|
|
def mxfp8_e4m3_quantize_fake(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation for torch.compile tracing.

    Allocates uninitialized outputs with the shapes and dtypes the quantize
    op is expected to produce; no quantization is performed.

    Args:
        x: Tensor that would be quantized.
        is_sf_swizzled_layout: Whether the swizzled scale layout is requested.

    Returns:
        ``(values, scales)`` where ``values`` has the shape of ``x`` with
        dtype ``MXFP8_VALUE_DTYPE`` and ``scales`` has dtype
        ``MXFP8_SCALE_DTYPE``:

        * swizzled layout, 2-D or 3-D ``x``: a flat 1-D tensor sized with the
          row count padded to a multiple of 128 and the per-row block count
          padded to a multiple of 4 (times the batch size for 3-D);
        * every other case: the shape of ``x`` with the last dimension
          replaced by the number of blocks.
    """
    values = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE)

    # One scale per block along the last dimension, rounded up.
    num_blocks = (x.shape[-1] + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE

    if is_sf_swizzled_layout and x.ndim in (2, 3):
        rows = x.shape[-2]
        padded_rows = ((rows + 127) // 128) * 128
        padded_blocks = ((num_blocks + 3) // 4) * 4
        flat_size = padded_rows * padded_blocks
        if x.ndim == 3:
            flat_size = x.shape[0] * flat_size
        scale_shape = (flat_size,)
    else:
        # Non-swizzled layout, or a rank other than 2/3 (where the swizzle
        # flag has no effect on the reported shape).
        scale_shape = (*x.shape[:-1], num_blocks)

    scales = torch.empty(scale_shape, dtype=MXFP8_SCALE_DTYPE, device=x.device)
    return values, scales
|
|
|
|
|
|
# Register the flashinfer-backed quantizer as the custom op "mxfp8_quantize".
# mxfp8_e4m3_quantize calls it as torch.ops.vllm.mxfp8_quantize (presumably
# the namespace direct_register_custom_op registers into -- confirm).
# The fake impl only allocates outputs of the expected shape and dtype, for
# use during torch.compile tracing.
direct_register_custom_op(
    op_name="mxfp8_quantize",
    op_func=_mxfp8_e4m3_quantize_impl,
    fake_impl=mxfp8_e4m3_quantize_fake,
)
|
|
|
|
|
|
class Mxfp8LinearOp:
    """Linear operation for weights stored in MXFP8.

    The weight is dequantized to BF16 with ``dequant_mxfp8_to_bf16`` and then
    passed to ``torch.nn.functional.linear``; no fused MXFP8 kernel is used.
    """

    def __init__(self, backend: Mxfp8LinearBackend):
        """Store the backend to use.

        Args:
            backend: A member of ``Mxfp8LinearBackend``.

        Raises:
            ValueError: If ``backend`` is not a ``Mxfp8LinearBackend`` member.
        """
        # isinstance instead of `backend not in Mxfp8LinearBackend`: before
        # Python 3.12 the containment test raises TypeError for non-members,
        # and from 3.12 on it also accepts raw values such as "emulation",
        # which would leave a plain str in self.backend.
        if not isinstance(backend, Mxfp8LinearBackend):
            raise ValueError(f"Unsupported backend: {backend}")

        self.backend = backend

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``linear(input, dequantized weight, bias)``.

        Args:
            input: Activation tensor passed to ``torch.nn.functional.linear``.
            weight: MXFP8-quantized weight.
            weight_scale: Per-block weight scales; must be 2-D with dtype
                ``MXFP8_SCALE_DTYPE``.
            out_dtype: Dtype the result is cast to.
            bias: Optional bias passed to ``torch.nn.functional.linear``.

        Returns:
            The linear output cast to ``out_dtype``.

        Raises:
            ValueError: If ``weight_scale`` has the wrong dtype or is not 2-D.
        """
        # Name the configured backend in the errors; the enum has no "TORCH"
        # member, which the previous messages referred to.
        backend_name = self.backend.name

        if weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"{backend_name} backend requires {MXFP8_SCALE_DTYPE} "
                f"weight_scale dtype, got {weight_scale.dtype}."
            )
        if weight_scale.ndim != 2:
            raise ValueError(
                f"{backend_name} backend requires 2D weight_scale, "
                f"got {weight_scale.ndim}D. "
                "Ensure process_weights_after_loading was called."
            )

        weight_bf16 = dequant_mxfp8_to_bf16(weight, weight_scale)

        # NOTE(review): the weight is always BF16 here, so this presumably
        # requires `input` (and `bias`) to be BF16 as well -- confirm callers.
        output = torch.nn.functional.linear(input, weight_bf16, bias)
        return output.to(out_dtype)
|