[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)

Signed-off-by: maral <maralbahari.98@gmail.com>
Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
Maral
2026-04-09 08:50:39 +08:00
committed by GitHub
parent 8332078cfd
commit 2e9034c998
35 changed files with 1710 additions and 904 deletions

View File

@@ -43,12 +43,10 @@ from vllm.distributed import (
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.kernels.linear import (
FP8ScaledMMLinearKernel,
_KernelT,
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
)
from vllm.model_executor.model_loader import get_model_loader
@@ -1811,31 +1809,52 @@ class TestFP8Layer(torch.nn.Module):
weight_shape: tuple[int, int],
activation_quant_key: QuantKey,
weight_quant_key: QuantKey,
input_dtype: torch.dtype,
out_dtype: torch.dtype | None = None,
transpose_weights: bool = False,
device: torch.device | None = None,
force_kernel: FP8ScaledMMLinearKernel | None = None,
force_kernel: type[_KernelT] | None = None,
):
super().__init__()
per_tensor_weights = weight_quant_key.scale.group_shape.is_per_tensor()
is_static_activation_scale = activation_quant_key.scale.static
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
self.weight_scale = torch.rand(
weight_scale_shape, dtype=torch.float32, device=device
)
self.input_scale = (
torch.rand(1, dtype=torch.float32, device=device)
if is_static_activation_scale
else None
)
self.weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
self.input_scale_ub = None
act_scale_desc = activation_quant_key.scale
weight_scale_desc = weight_quant_key.scale
is_block_wise = act_scale_desc.group_shape.is_per_group()
if is_block_wise:
block_size = weight_scale_desc.group_shape.col
weight_scale_shape = weight_shape[0] // block_size
self.weight_scale_inv = torch.rand(
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
)
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
self.input_scale = None
self.weight_scale = None
if transpose_weights:
self.weight = self.weight.t()
else:
per_tensor_weights = weight_scale_desc.group_shape.is_per_tensor()
is_static_activation_scale = act_scale_desc.static
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
self.weight_scale_inv = None
self.weight_scale = torch.rand(
weight_scale_shape, dtype=torch.float32, device=device
)
self.input_scale = (
torch.rand(1, dtype=torch.float32, device=device)
if is_static_activation_scale
else None
)
self.weight = (
torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
)
self.input_scale_ub = None
out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
self.kernel = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
weight_shape=weight_shape,
input_dtype=input_dtype,
out_dtype=out_dtype,
force_kernel=force_kernel,
)
@@ -1847,61 +1866,3 @@ class TestFP8Layer(torch.nn.Module):
self, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return self.kernel.apply_weights(self, y, bias)
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
# after refactoring W8A8BlockFp8LinearOp.
# https://github.com/vllm-project/vllm/issues/31818
class TestBlockFP8Layer:
"""
Test helper for blockwise FP8 linear operations. Creates random weights
and scales for W8A8BlockFp8LinearOp.
This is a workaround until W8A8BlockFp8LinearOp implements the kernel
abstraction (ScaledMMLinearKernel) for blockwise quantization.
Args:
weight_shape: Shape of the weight tensor (out_features, in_features).
group_shape: Blockwise quantization group shape.
cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
use_aiter_and_is_supported: Whether to use aiter quantization ops.
transpose_weights: Whether to transpose weights after creation.
"""
def __init__(
self,
weight_shape: tuple[int, int],
group_shape: GroupShape,
cutlass_block_fp8_supported: bool = False,
use_aiter_and_is_supported: bool = False,
transpose_weights: bool = False,
):
weight_scale_shape = weight_shape[0] // group_shape[1]
self.weight_scale = torch.rand(
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
)
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
self.input_scale = None
if transpose_weights:
self.weight = self.weight.t()
self.linear_op = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
use_aiter_and_is_supported=use_aiter_and_is_supported,
)
def __call__(
self, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return self.linear_op.apply(
input=y,
weight=self.weight,
weight_scale=self.weight_scale,
input_scale=self.input_scale,
bias=bias,
)
def is_quant_fp8_enabled(self) -> bool:
return self.linear_op.input_quant_op.enabled()