[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)
Signed-off-by: maral <maralbahari.98@gmail.com> Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
113
tests/utils.py
113
tests/utils.py
@@ -43,12 +43,10 @@ from vllm.distributed import (
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.cli.serve import ServeSubcommand
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
_KernelT,
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
@@ -1811,31 +1809,52 @@ class TestFP8Layer(torch.nn.Module):
|
||||
weight_shape: tuple[int, int],
|
||||
activation_quant_key: QuantKey,
|
||||
weight_quant_key: QuantKey,
|
||||
input_dtype: torch.dtype,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
transpose_weights: bool = False,
|
||||
device: torch.device | None = None,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None = None,
|
||||
force_kernel: type[_KernelT] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
per_tensor_weights = weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
is_static_activation_scale = activation_quant_key.scale.static
|
||||
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
|
||||
|
||||
self.weight_scale = torch.rand(
|
||||
weight_scale_shape, dtype=torch.float32, device=device
|
||||
)
|
||||
self.input_scale = (
|
||||
torch.rand(1, dtype=torch.float32, device=device)
|
||||
if is_static_activation_scale
|
||||
else None
|
||||
)
|
||||
self.weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
|
||||
self.input_scale_ub = None
|
||||
act_scale_desc = activation_quant_key.scale
|
||||
weight_scale_desc = weight_quant_key.scale
|
||||
is_block_wise = act_scale_desc.group_shape.is_per_group()
|
||||
if is_block_wise:
|
||||
block_size = weight_scale_desc.group_shape.col
|
||||
weight_scale_shape = weight_shape[0] // block_size
|
||||
self.weight_scale_inv = torch.rand(
|
||||
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
|
||||
)
|
||||
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
|
||||
self.input_scale = None
|
||||
self.weight_scale = None
|
||||
if transpose_weights:
|
||||
self.weight = self.weight.t()
|
||||
else:
|
||||
per_tensor_weights = weight_scale_desc.group_shape.is_per_tensor()
|
||||
is_static_activation_scale = act_scale_desc.static
|
||||
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
|
||||
self.weight_scale_inv = None
|
||||
self.weight_scale = torch.rand(
|
||||
weight_scale_shape, dtype=torch.float32, device=device
|
||||
)
|
||||
self.input_scale = (
|
||||
torch.rand(1, dtype=torch.float32, device=device)
|
||||
if is_static_activation_scale
|
||||
else None
|
||||
)
|
||||
self.weight = (
|
||||
torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
|
||||
)
|
||||
self.input_scale_ub = None
|
||||
|
||||
out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
|
||||
|
||||
self.kernel = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=weight_quant_key,
|
||||
weight_shape=weight_shape,
|
||||
input_dtype=input_dtype,
|
||||
out_dtype=out_dtype,
|
||||
force_kernel=force_kernel,
|
||||
)
|
||||
@@ -1847,61 +1866,3 @@ class TestFP8Layer(torch.nn.Module):
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(self, y, bias)
|
||||
|
||||
|
||||
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
|
||||
# after refactoring W8A8BlockFp8LinearOp.
|
||||
# https://github.com/vllm-project/vllm/issues/31818
|
||||
class TestBlockFP8Layer:
|
||||
"""
|
||||
Test helper for blockwise FP8 linear operations. Creates random weights
|
||||
and scales for W8A8BlockFp8LinearOp.
|
||||
|
||||
This is a workaround until W8A8BlockFp8LinearOp implements the kernel
|
||||
abstraction (ScaledMMLinearKernel) for blockwise quantization.
|
||||
|
||||
Args:
|
||||
weight_shape: Shape of the weight tensor (out_features, in_features).
|
||||
group_shape: Blockwise quantization group shape.
|
||||
cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
|
||||
use_aiter_and_is_supported: Whether to use aiter quantization ops.
|
||||
transpose_weights: Whether to transpose weights after creation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_shape: tuple[int, int],
|
||||
group_shape: GroupShape,
|
||||
cutlass_block_fp8_supported: bool = False,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
transpose_weights: bool = False,
|
||||
):
|
||||
weight_scale_shape = weight_shape[0] // group_shape[1]
|
||||
self.weight_scale = torch.rand(
|
||||
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
|
||||
)
|
||||
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
|
||||
self.input_scale = None
|
||||
if transpose_weights:
|
||||
self.weight = self.weight.t()
|
||||
|
||||
self.linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||
act_quant_group_shape=group_shape,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=use_aiter_and_is_supported,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.linear_op.apply(
|
||||
input=y,
|
||||
weight=self.weight,
|
||||
weight_scale=self.weight_scale,
|
||||
input_scale=self.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def is_quant_fp8_enabled(self) -> bool:
|
||||
return self.linear_op.input_quant_op.enabled()
|
||||
|
||||
Reference in New Issue
Block a user