[FP8] Refactor apply_fp8_linear and apply_fp8_linear_generic into an object (#14390)

Signed-off-by: luka <luka@neuralmagic.com>
Author: Luka Govedič
Date: 2025-03-07 00:20:16 -05:00
Committed by: GitHub
Parent: dae6896977
Commit: e1744502c2

11 changed files with 257 additions and 231 deletions
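In effect, the stateless helper `apply_fp8_linear_generic`, which threaded `cutlass_fp8_supported` / `cutlass_block_fp8_supported` flags through every call, becomes an `Fp8LinearGenericOp` object that resolves those flags once in `__init__` and delegates the non-blockwise path to a wrapped `Fp8LinearOp`. A before/after sketch of a call site; the placeholder tensors and group shapes below are illustrative, not taken from this diff, and need a CUDA build of vLLM to actually run:

import torch

M, K, N = 16, 4096, 4096
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)  # (out, in)
w_s = torch.ones(N // 128, K // 128, device="cuda")           # block scales

# Before: a plain function; capability flags default to import-time constants.
out = apply_fp8_linear_generic(x, w, w_s,
                               input_group_shape=(1, 128),
                               weight_group_shape=(128, 128))

# After: construct once (flags captured at construction), reuse per forward.
fp8_linear_generic = Fp8LinearGenericOp()
out = fp8_linear_generic.apply(x, w, w_s,
                               input_group_shape=(1, 128),
                               weight_group_shape=(128, 128))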


@@ -15,7 +15,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, Fp8LinearOp, cutlass_block_fp8_supported,
+    cutlass_fp8_supported)
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
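Note the shape of the import change: the module-level constants (`CUTLASS_FP8_SUPPORTED`) are traded for helper functions (`cutlass_fp8_supported()`), plus the new `Fp8LinearOp` class. A toy sketch of the distinction, with a hypothetical `_probe()` standing in for the real CUTLASS capability check:

def _probe() -> bool:
    # Hypothetical stand-in for the real platform/CUTLASS capability check.
    return False

# Constant: evaluated exactly once, as a side effect of importing the module.
CUTLASS_FP8_SUPPORTED: bool = _probe()

# Function: evaluated on demand, e.g. when an Fp8LinearGenericOp is built.
def cutlass_fp8_supported() -> bool:
    return _probe()

One Python caveat: because `cutlass_fp8_supported()` appears as a default argument value in `Fp8LinearGenericOp.__init__`, it is still evaluated once when the class body is defined; callers wanting a fresh probe must pass the flag explicitly.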
@@ -32,6 +33,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
 
 
+# TODO fix ROCm->Triton custom path:
+# https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -49,6 +52,7 @@ def apply_w8a8_block_fp8_linear(
     shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
                                   and weight.shape[1] % 128 == 0)
     if current_platform.is_rocm():
+        # TODO this is never used, as cutlass_block_fp8_supported is False
         scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
                          input_2d.shape[:-1])[::-1]
         scale_b_shape = (weight_scale.view(-1, 1)
@@ -104,43 +108,55 @@ direct_register_custom_op(
 # Unify the interface between `apply_w8a8_block_fp8_linear` and
 # `apply_fp8_linear`
-# NOTE(lucas): this is quite messy, we should think through this more formally
-def apply_fp8_linear_generic(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_scale: torch.Tensor,
-    input_group_shape: Tuple[int, int],
-    weight_group_shape: Tuple[int, int],
-    input_scale: Optional[torch.Tensor] = None,  # static scale if one
-    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
-    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
-) -> torch.Tensor:
-    # View input as 2D matrix for fp8 methods
-    input = input.view(-1, input.shape[-1])
-
-    weight_group_shape = _normalize_quant_group_shape( \
-        weight, weight_group_shape)
-    input_group_shape = _normalize_quant_group_shape(input, input_group_shape)
-
-    def is_dim_blocked(dim, shape, group_shape):
-        return group_shape < shape[dim] and group_shape > 1
-
-    if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
-        and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
-            input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(
-            input,
-            weight,
-            list(weight_group_shape),
-            weight_scale,
-            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
-    else:
-        # Despite having linear in the it doesn't conform to
-        # `torch.nn.functional.linear` which is defined as `input @ weight.T`
-        # so we explicitly transpose the weight matrix here
-        return apply_fp8_linear(input, weight.T, weight_scale.T,
-                                cutlass_fp8_supported=cutlass_fp8_supported,
-                                use_per_token_if_dynamic=\
-                                    (input_group_shape == (1, input.shape[1])))
+# TODO(luka): unify this better
+# https://github.com/vllm-project/vllm/issues/14397
+class Fp8LinearGenericOp:
+
+    def __init__(
+        self,
+        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
+        cutlass_block_fp8_supported: bool = cutlass_block_fp8_supported(),
+    ):
+        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported
+        self.fp8_linear = Fp8LinearOp(
+            cutlass_fp8_supported=cutlass_fp8_supported)
+
+    def apply(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_group_shape: Tuple[int, int],
+        weight_group_shape: Tuple[int, int],
+        input_scale: Optional[torch.Tensor] = None,  # static scale if one
+    ) -> torch.Tensor:
+        # View input as 2D matrix for fp8 methods
+        input = input.view(-1, input.shape[-1])
+
+        weight_group_shape = _normalize_quant_group_shape( \
+            weight, weight_group_shape)
+        input_group_shape = _normalize_quant_group_shape(
+            input, input_group_shape)
+
+        def is_dim_blocked(dim, shape, group_shape):
+            return group_shape < shape[dim] and group_shape > 1
+
+        if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
+            and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
+                input_group_shape == (1, weight_group_shape[1]):
+            return apply_w8a8_block_fp8_linear(
+                input,
+                weight,
+                list(weight_group_shape),
+                weight_scale,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported)
+        else:
+            # Despite having linear in the name it doesn't conform to
+            # `torch.nn.functional.linear` which is defined as
+            # `input @ weight.T` so we explicitly transpose the weight matrix
+            return self.fp8_linear.apply(input, weight.T, weight_scale.T,
+                                         use_per_token_if_dynamic=\
+                                             (input_group_shape == (1, input.shape[1])))
 
 
 def input_to_float8(
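The dispatch logic in `apply` is unchanged by the refactor: the blockwise kernel is used only when the weight is blocked along both dimensions and the input's quantization groups line up with the weight's column blocks; everything else falls back to `Fp8LinearOp.apply` with the weight transposed into `input @ weight` layout. A self-contained sketch of that predicate (shapes are illustrative; `is_dim_blocked` is copied from the diff):

from typing import Tuple


def is_dim_blocked(dim: int, shape: Tuple[int, ...], group_shape: int) -> bool:
    # "Blocked" = the group is strictly smaller than the whole dimension
    # but bigger than a single element (neither per-tensor nor per-channel).
    return group_shape < shape[dim] and group_shape > 1


weight_shape = (4096, 4096)      # illustrative (out_features, in_features)
weight_group_shape = (128, 128)  # blockwise weight quantization
input_group_shape = (1, 128)     # per-token groups of 128 input channels

use_blockwise = (is_dim_blocked(0, weight_shape, weight_group_shape[0])
                 and is_dim_blocked(1, weight_shape, weight_group_shape[1])
                 and input_group_shape == (1, weight_group_shape[1]))
print(use_blockwise)  # True -> apply_w8a8_block_fp8_linear

# A per-tensor weight scale normalizes to group shape (4096, 4096), so
# is_dim_blocked is False and the Fp8LinearOp path is taken instead, with
# use_per_token_if_dynamic=True iff input_group_shape == (1, K).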