[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)
Signed-off-by: maral <maralbahari.98@gmail.com>
Signed-off-by: Maral <maralbahari.98@gmail.com>
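Summary of the change, as reflected in the diff below: blockwise FP8 cases no longer go through the ad-hoc `TestBlockFP8Layer` path (backed by the removed `W8A8Fp8BlockLinearOp`); instead they name a concrete block-scaled kernel class (CUTLASS, DeepGEMM, FlashInfer, Triton, or AITER) through the same `force_kernel` mechanism used by the per-tensor and per-token kernels, and both quant keys are built with `create_fp8_quant_key`. A minimal sketch of the new quant-key construction, assembled from the calls visible in this diff (the signature of `create_fp8_quant_key` is inferred from its usage here, not authoritative):

# Sketch only: mirrors the calls in this diff; signatures are inferred from usage.
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    create_fp8_quant_key,
)

group_shape = GroupShape(1, 128)   # per-group (blockwise) activation scaling
block_size = group_shape.col       # 128

# Dynamic per-group activation quantization, static block-wise weight quantization.
act_key = create_fp8_quant_key(static=False, group_shape=group_shape)
weight_key = create_fp8_quant_key(
    static=True, group_shape=GroupShape(block_size, block_size)
)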
@@ -9,7 +9,7 @@ import vllm.config
 import vllm.ir.ops
 import vllm.plugins
 from tests.compile.backend import TestBackend
-from tests.utils import TestBlockFP8Layer, TestFP8Layer
+from tests.utils import TestFP8Layer
 from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
 from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
 from vllm.compilation.passes.fusion.rms_quant_fusion import (
@@ -28,19 +28,23 @@ from vllm.config import (
     VllmConfig,
 )
 from vllm.model_executor.kernels.linear import (
+    AiterFp8BlockScaledMMKernel,
     ChannelWiseTorchFP8ScaledMMLinearKernel,
+    CutlassFp8BlockScaledMMKernel,
     CutlassFP8ScaledMMLinearKernel,
+    DeepGemmFp8BlockScaledMMKernel,
+    FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
     FlashInferFP8ScaledMMLinearKernel,
     FP8ScaledMMLinearKernel,
     PerTensorTorchFP8ScaledMMLinearKernel,
     ROCmFP8ScaledMMLinearKernel,
     RowWiseTorchFP8ScaledMMLinearKernel,
+    TritonFp8BlockScaledMMKernel,
+    _KernelT,
 )
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     QuantKey,
     ScaleDesc,
+    create_fp8_quant_key,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     cutlass_block_fp8_supported,
@@ -66,9 +70,12 @@ CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
     (PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
     # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
     (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
-    # Blockwise group shapes (no kernel abstraction)
-    (None, GroupShape(1, 128)),
-    (None, GroupShape(1, 64)),
+    # Blockwise group shapes
+    (FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)),
+    (CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)),
+    (DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)),
+    (TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
+    (TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
 ]

 # ROCm kernels
@@ -80,8 +87,8 @@ ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
     # ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
     (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
     # Blockwise group shapes (no kernel abstraction)
-    (None, GroupShape(1, 128)),
-    (None, GroupShape(1, 64)),
+    (TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
+    (TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
 ]

 KERNEL_GROUPSHAPE_COMBINATIONS = (
@@ -100,8 +107,8 @@ AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
     # Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
     (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
    (ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
-    # Blockwise (no kernel abstraction)
-    (None, GroupShape(1, 128), True),
+    # Blockwise
+    (AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True),
 ]

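Taken together, the three combination tables above carry the point of the refactor: blockwise group shapes used to be paired with `None` because the old `W8A8Fp8BlockLinearOp` path sat outside the kernel abstraction, and they are now parametrized with concrete block-scaled kernel classes exactly like the per-tensor and per-token cases. A hedged sketch of how such a table is typically consumed (pytest parametrization is an assumption, based on this being a pytest test module):

# Sketch only: illustrative pytest wiring, not copied from this diff.
import pytest

@pytest.mark.parametrize(
    "force_kernel,group_shape", CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
)
def test_example(force_kernel, group_shape):
    ...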
@@ -110,8 +117,9 @@ class TestModel(torch.nn.Module):
         self,
         hidden_size: int,
         eps: float,
-        force_kernel: FP8ScaledMMLinearKernel | None,
+        force_kernel: type[_KernelT] | None,
         group_shape: GroupShape,
+        dtype: torch.dtype,
         use_aiter_fusion: bool = False,
         use_aiter_quant: bool = False,
         *args,
@@ -129,54 +137,42 @@ class TestModel(torch.nn.Module):
         is_blockwise = group_shape.is_per_group()

         if is_blockwise:
-            act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
-            self.activation_quant_key = QuantKey(
-                dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
+            block_size = group_shape.col
+            self.activation_quant_key = create_fp8_quant_key(
+                static=False, group_shape=group_shape
             )
-            self.fp8_linear_layers = [
-                TestBlockFP8Layer(
-                    weight_shape=(hidden_size, hidden_size),
-                    group_shape=group_shape,
-                    cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
-                    use_aiter_and_is_supported=use_aiter_quant,
-                    transpose_weights=use_aiter_fusion,
-                )
-                for _ in range(3)
-            ]
-
-            self.enable_quant_fp8_custom_op = (
-                False
-                if use_aiter_quant
-                else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
+            self.weight_quant_key = create_fp8_quant_key(
+                static=True, group_shape=GroupShape(block_size, block_size)
             )

         else:
             is_static = group_shape == GroupShape.PER_TENSOR
-            act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
-            w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
-            self.activation_quant_key = QuantKey(
-                dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
+            self.activation_quant_key = create_fp8_quant_key(
+                is_static, group_shape=group_shape
             )
-            self.weight_quant_key = QuantKey(
-                dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
+            self.weight_quant_key = create_fp8_quant_key(
+                static=True, group_shape=group_shape
             )
-            self.fp8_linear_layers = [
-                TestFP8Layer(
-                    weight_shape=(hidden_size, hidden_size),
-                    activation_quant_key=self.activation_quant_key,
-                    weight_quant_key=self.weight_quant_key,
-                    force_kernel=force_kernel,
-                )
-                for _ in range(3)
-            ]
-
-            # Enable aiter quantization if requested
-            for layer in self.fp8_linear_layers:
-                layer.kernel.quant_fp8.use_aiter = use_aiter_quant
+        self.fp8_linear_layers = [
+            TestFP8Layer(
+                weight_shape=(hidden_size, hidden_size),
+                activation_quant_key=self.activation_quant_key,
+                weight_quant_key=self.weight_quant_key,
+                force_kernel=force_kernel,
+                transpose_weights=use_aiter_fusion,
+                input_dtype=dtype,
+            )
+            for _ in range(3)
+        ]
+
-            self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
-                0
-            ].is_quant_fp8_enabled()
+        # Enable aiter quantization if requested
+        for layer in self.fp8_linear_layers:
+            layer.kernel.quant_fp8.use_aiter = use_aiter_quant
+
+        self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
+            0
+        ].is_quant_fp8_enabled()

     def forward(self, x):
         # avoid having graph input be an arg to a pattern directly
@@ -354,6 +350,7 @@ def test_fusion_rmsnorm_quant(
         eps=eps,
         force_kernel=force_kernel,
         group_shape=group_shape,
+        dtype=dtype,
         use_aiter_fusion=False,
         use_aiter_quant=False,
     )
@@ -426,6 +423,7 @@ def test_aiter_fusion_rmsnorm_quant(
         eps=eps,
         force_kernel=force_kernel,
         group_shape=group_shape,
+        dtype=dtype,
         use_aiter_fusion=True,  # Always use aiter fusion ops in aiter test
         use_aiter_quant=use_aiter_quant_op,  # Toggle aiter quantization
     )
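For context, the new constructor threads `dtype` through to the layers as `input_dtype`. A hedged usage sketch matching the call sites in the two test hunks above; the `hidden_size` and `eps` values are illustrative, not taken from this diff:

# Sketch only: mirrors the TestModel(...) call sites in this diff.
import torch

model = TestModel(
    hidden_size=64,            # illustrative value
    eps=1e-6,                  # illustrative value
    force_kernel=CutlassFp8BlockScaledMMKernel,  # or None, per the signature
    group_shape=GroupShape(1, 128),
    dtype=torch.bfloat16,      # forwarded to TestFP8Layer as input_dtype
    use_aiter_fusion=False,
    use_aiter_quant=False,
)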