[W8A8 Block Linear Refactor][2/N] Remove W8A8BlockFp8LinearOp and adopt Fp8 block linear kernel selections. (#33892)

Signed-off-by: maral <maralbahari.98@gmail.com>
Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
Maral
2026-04-09 08:50:39 +08:00
committed by GitHub
parent 8332078cfd
commit 2e9034c998
35 changed files with 1710 additions and 904 deletions

View File

@@ -39,7 +39,9 @@ from vllm.utils.torch_utils import set_random_seed
class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -78,7 +80,9 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -88,6 +92,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=dtype,
)
for i in range(3)
]
@@ -127,7 +132,9 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -314,7 +321,7 @@ def all_reduce_fusion_pass_on_test_model(
)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
model = test_model_cls(hidden_size, token_num, dtype=dtype)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)

View File

@@ -109,6 +109,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=self.vllm_config.model_config.dtype,
)
for i in range(3)
]

View File

@@ -23,6 +23,7 @@ from vllm.config import (
ModelConfig,
PassConfig,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
@@ -49,6 +50,7 @@ class TestSiluMul(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=get_current_vllm_config().model_config.dtype,
)
def forward(self, x):
@@ -92,6 +94,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
weight_shape=(hidden_size, intermediate_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=get_current_vllm_config().model_config.dtype,
)
def forward(self, hidden_states, residual):

View File

@@ -9,7 +9,7 @@ import vllm.config
import vllm.ir.ops
import vllm.plugins
from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer
from tests.utils import TestFP8Layer
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.rms_quant_fusion import (
@@ -28,19 +28,23 @@ from vllm.config import (
VllmConfig,
)
from vllm.model_executor.kernels.linear import (
AiterFp8BlockScaledMMKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
CutlassFp8BlockScaledMMKernel,
CutlassFP8ScaledMMLinearKernel,
DeepGemmFp8BlockScaledMMKernel,
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
FlashInferFP8ScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
ROCmFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
TritonFp8BlockScaledMMKernel,
_KernelT,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
create_fp8_quant_key,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
@@ -66,9 +70,12 @@ CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
(PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes (no kernel abstraction)
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
# Blockwise group shapes
(FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)),
(CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)),
(DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]
# ROCm kernels
@@ -80,8 +87,8 @@ ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]
KERNEL_GROUPSHAPE_COMBINATIONS = (
@@ -100,8 +107,8 @@ AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
# Blockwise (no kernel abstraction)
(None, GroupShape(1, 128), True),
# Blockwise
(AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True),
]
@@ -110,8 +117,9 @@ class TestModel(torch.nn.Module):
self,
hidden_size: int,
eps: float,
force_kernel: FP8ScaledMMLinearKernel | None,
force_kernel: type[_KernelT] | None,
group_shape: GroupShape,
dtype: torch.dtype,
use_aiter_fusion: bool = False,
use_aiter_quant: bool = False,
*args,
@@ -129,54 +137,42 @@ class TestModel(torch.nn.Module):
is_blockwise = group_shape.is_per_group()
if is_blockwise:
act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
block_size = group_shape.col
self.activation_quant_key = create_fp8_quant_key(
static=False, group_shape=group_shape
)
self.fp8_linear_layers = [
TestBlockFP8Layer(
weight_shape=(hidden_size, hidden_size),
group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=use_aiter_quant,
transpose_weights=use_aiter_fusion,
)
for _ in range(3)
]
self.enable_quant_fp8_custom_op = (
False
if use_aiter_quant
else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(block_size, block_size)
)
else:
is_static = group_shape == GroupShape.PER_TENSOR
act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
self.activation_quant_key = create_fp8_quant_key(
is_static, group_shape=group_shape
)
self.weight_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=group_shape
)
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
force_kernel=force_kernel,
)
for _ in range(3)
]
# Enable aiter quantization if requested
for layer in self.fp8_linear_layers:
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
force_kernel=force_kernel,
transpose_weights=use_aiter_fusion,
input_dtype=dtype,
)
for _ in range(3)
]
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()
# Enable aiter quantization if requested
for layer in self.fp8_linear_layers:
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
@@ -354,6 +350,7 @@ def test_fusion_rmsnorm_quant(
eps=eps,
force_kernel=force_kernel,
group_shape=group_shape,
dtype=dtype,
use_aiter_fusion=False,
use_aiter_quant=False,
)
@@ -426,6 +423,7 @@ def test_aiter_fusion_rmsnorm_quant(
eps=eps,
force_kernel=force_kernel,
group_shape=group_shape,
dtype=dtype,
use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
)

View File

@@ -66,6 +66,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.attn = Attention(
num_heads=self.num_qo_heads,
@@ -155,6 +156,7 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
input_dtype=self.dtype,
)
w = kwargs.get("w")

View File

@@ -74,6 +74,7 @@ class MLAAttentionQuantPatternModel(torch.nn.Module):
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
# Create kv_b_proj (ColumnParallelLinear) on device.
# Reuse weights from prior model instance when available, because
@@ -190,6 +191,7 @@ class TestMLAAttentionFp8StaticQuantPatternModel(MLAAttentionQuantPatternModel):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
input_dtype=self.dtype,
)
w = kwargs.get("w")

View File

@@ -36,9 +36,9 @@ from vllm.model_executor.kernels.linear import (
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
kFp8Dynamic128Sym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
@@ -58,7 +58,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(
self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs
self,
hidden_size: int,
force_kernel: FP8ScaledMMLinearKernel,
dtype: torch.dtype,
**kwargs,
):
super().__init__()
self.silu_and_mul = SiluAndMul()
@@ -68,6 +72,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
force_kernel=force_kernel,
input_dtype=dtype,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
@@ -137,14 +142,20 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, **kwargs):
act_quant_key = kFp8Dynamic128Sym
def __init__(self, hidden_size: int, dtype: torch.dtype, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(hidden_size, hidden_size)
)
self.w8a8_block_fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
weight_quant_key=self.weight_quant_key,
activation_quant_key=self.act_quant_key,
input_dtype=dtype,
)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
@@ -157,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
x2 = self.w8a8_block_fp8_linear(y, self.w, self.wscale)
return x2
def ops_in_model_before(self):
@@ -324,7 +335,9 @@ def test_fusion_silu_and_mul_quant(
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x)
model = model_class(
hidden_size=hidden_size, force_kernel=force_kernel, x=x, dtype=dtype
)
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)