[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
127
tests/utils.py
127
tests/utils.py
@@ -42,6 +42,17 @@ from vllm.distributed import (
|
||||
)
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.cli.serve import ServeSubcommand
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
@@ -50,6 +61,8 @@ from vllm.utils.mem_constants import GB_bytes
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from amdsmi import (
|
||||
amdsmi_get_gpu_vram_usage,
|
||||
@@ -1332,3 +1345,117 @@ def flat_product(*iterables: Iterable[Any]):
|
||||
for element in itertools.product(*iterables):
|
||||
normalized = (e if isinstance(e, tuple) else (e,) for e in element)
|
||||
yield tuple(itertools.chain(*normalized))
|
||||
|
||||
|
||||
class TestFP8Layer(torch.nn.Module):
|
||||
"""
|
||||
Test helper for FP8 linear operations. Creates random weights and scales
|
||||
based on quantization configuration.
|
||||
|
||||
Args:
|
||||
weight_shape: Shape of the weight tensor (out_features, in_features).
|
||||
activation_quant_key: Activation quantization configuration.
|
||||
weight_quant_key: Weight quantization configuration.
|
||||
out_dtype: Output dtype. Defaults to current default dtype.
|
||||
force_kernel: Optional kernel to force use of specific implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_shape: tuple[int, int],
|
||||
activation_quant_key: QuantKey,
|
||||
weight_quant_key: QuantKey,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
per_tensor_weights = weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
is_static_activation_scale = activation_quant_key.scale.static
|
||||
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
|
||||
|
||||
self.weight_scale = torch.rand(
|
||||
weight_scale_shape, dtype=torch.float32, device=device
|
||||
)
|
||||
self.input_scale = (
|
||||
torch.rand(1, dtype=torch.float32, device=device)
|
||||
if is_static_activation_scale
|
||||
else None
|
||||
)
|
||||
self.weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
|
||||
self.input_scale_ub = None
|
||||
|
||||
out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
|
||||
|
||||
self.kernel = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=weight_quant_key,
|
||||
out_dtype=out_dtype,
|
||||
force_kernel=force_kernel,
|
||||
)
|
||||
|
||||
def is_quant_fp8_enabled(self) -> bool:
|
||||
return self.kernel.quant_fp8.enabled()
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(self, y, bias)
|
||||
|
||||
|
||||
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
|
||||
# after refactoring W8A8BlockFp8LinearOp.
|
||||
# https://github.com/vllm-project/vllm/issues/31818
|
||||
class TestBlockFP8Layer:
|
||||
"""
|
||||
Test helper for blockwise FP8 linear operations. Creates random weights
|
||||
and scales for W8A8BlockFp8LinearOp.
|
||||
|
||||
This is a workaround until W8A8BlockFp8LinearOp implements the kernel
|
||||
abstraction (ScaledMMLinearKernel) for blockwise quantization.
|
||||
|
||||
Args:
|
||||
weight_shape: Shape of the weight tensor (out_features, in_features).
|
||||
group_shape: Blockwise quantization group shape.
|
||||
cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
|
||||
use_aiter_and_is_supported: Whether to use aiter quantization ops.
|
||||
transpose_weights: Whether to transpose weights after creation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_shape: tuple[int, int],
|
||||
group_shape: GroupShape,
|
||||
cutlass_block_fp8_supported: bool = False,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
transpose_weights: bool = False,
|
||||
):
|
||||
weight_scale_shape = weight_shape[0] // group_shape[1]
|
||||
self.weight_scale = torch.rand(
|
||||
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
|
||||
)
|
||||
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
|
||||
self.input_scale = None
|
||||
if transpose_weights:
|
||||
self.weight = self.weight.t()
|
||||
|
||||
self.linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||
act_quant_group_shape=group_shape,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=use_aiter_and_is_supported,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.linear_op.apply(
|
||||
input=y,
|
||||
weight=self.weight,
|
||||
weight_scale=self.weight_scale,
|
||||
input_scale=self.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def is_quant_fp8_enabled(self) -> bool:
|
||||
return self.linear_op.input_quant_op.enabled()
|
||||
|
||||
Reference in New Issue
Block a user