[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)
Signed-off-by: maral <maralbahari.98@gmail.com> Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
@@ -39,7 +39,9 @@ from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
def __init__(
|
||||
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@@ -78,7 +80,9 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
def __init__(
|
||||
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@@ -88,6 +92,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=dtype,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
@@ -127,7 +132,9 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
def __init__(
|
||||
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@@ -314,7 +321,7 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
)
|
||||
|
||||
token_num = batch_size * seq_len
|
||||
model = test_model_cls(hidden_size, token_num)
|
||||
model = test_model_cls(hidden_size, token_num, dtype=dtype)
|
||||
|
||||
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
|
||||
|
||||
@@ -109,6 +109,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=self.vllm_config.model_config.dtype,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
@@ -23,6 +23,7 @@ from vllm.config import (
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@@ -49,6 +50,7 @@ class TestSiluMul(torch.nn.Module):
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=get_current_vllm_config().model_config.dtype,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -92,6 +94,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
weight_shape=(hidden_size, intermediate_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=get_current_vllm_config().model_config.dtype,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
|
||||
@@ -9,7 +9,7 @@ import vllm.config
|
||||
import vllm.ir.ops
|
||||
import vllm.plugins
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.utils import TestBlockFP8Layer, TestFP8Layer
|
||||
from tests.utils import TestFP8Layer
|
||||
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
|
||||
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.passes.fusion.rms_quant_fusion import (
|
||||
@@ -28,19 +28,23 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
AiterFp8BlockScaledMMKernel,
|
||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||
CutlassFp8BlockScaledMMKernel,
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
DeepGemmFp8BlockScaledMMKernel,
|
||||
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearKernel,
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
RowWiseTorchFP8ScaledMMLinearKernel,
|
||||
TritonFp8BlockScaledMMKernel,
|
||||
_KernelT,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
create_fp8_quant_key,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
@@ -66,9 +70,12 @@ CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
||||
(PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
|
||||
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
|
||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||
# Blockwise group shapes (no kernel abstraction)
|
||||
(None, GroupShape(1, 128)),
|
||||
(None, GroupShape(1, 64)),
|
||||
# Blockwise group shapes
|
||||
(FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)),
|
||||
(CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||
(DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
|
||||
]
|
||||
|
||||
# ROCm kernels
|
||||
@@ -80,8 +87,8 @@ ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
||||
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
|
||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||
# Blockwise group shapes (no kernel abstraction)
|
||||
(None, GroupShape(1, 128)),
|
||||
(None, GroupShape(1, 64)),
|
||||
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
|
||||
]
|
||||
|
||||
KERNEL_GROUPSHAPE_COMBINATIONS = (
|
||||
@@ -100,8 +107,8 @@ AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
||||
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
|
||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
|
||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
|
||||
# Blockwise (no kernel abstraction)
|
||||
(None, GroupShape(1, 128), True),
|
||||
# Blockwise
|
||||
(AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True),
|
||||
]
|
||||
|
||||
|
||||
@@ -110,8 +117,9 @@ class TestModel(torch.nn.Module):
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None,
|
||||
force_kernel: type[_KernelT] | None,
|
||||
group_shape: GroupShape,
|
||||
dtype: torch.dtype,
|
||||
use_aiter_fusion: bool = False,
|
||||
use_aiter_quant: bool = False,
|
||||
*args,
|
||||
@@ -129,54 +137,42 @@ class TestModel(torch.nn.Module):
|
||||
is_blockwise = group_shape.is_per_group()
|
||||
|
||||
if is_blockwise:
|
||||
act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
|
||||
self.activation_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
|
||||
block_size = group_shape.col
|
||||
self.activation_quant_key = create_fp8_quant_key(
|
||||
static=False, group_shape=group_shape
|
||||
)
|
||||
self.fp8_linear_layers = [
|
||||
TestBlockFP8Layer(
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
group_shape=group_shape,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
||||
use_aiter_and_is_supported=use_aiter_quant,
|
||||
transpose_weights=use_aiter_fusion,
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.enable_quant_fp8_custom_op = (
|
||||
False
|
||||
if use_aiter_quant
|
||||
else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
|
||||
self.weight_quant_key = create_fp8_quant_key(
|
||||
static=True, group_shape=GroupShape(block_size, block_size)
|
||||
)
|
||||
|
||||
else:
|
||||
is_static = group_shape == GroupShape.PER_TENSOR
|
||||
act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
|
||||
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
|
||||
self.activation_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
|
||||
self.activation_quant_key = create_fp8_quant_key(
|
||||
is_static, group_shape=group_shape
|
||||
)
|
||||
self.weight_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
|
||||
self.weight_quant_key = create_fp8_quant_key(
|
||||
static=True, group_shape=group_shape
|
||||
)
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
force_kernel=force_kernel,
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
# Enable aiter quantization if requested
|
||||
for layer in self.fp8_linear_layers:
|
||||
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
force_kernel=force_kernel,
|
||||
transpose_weights=use_aiter_fusion,
|
||||
input_dtype=dtype,
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
||||
0
|
||||
].is_quant_fp8_enabled()
|
||||
# Enable aiter quantization if requested
|
||||
for layer in self.fp8_linear_layers:
|
||||
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
|
||||
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
||||
0
|
||||
].is_quant_fp8_enabled()
|
||||
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
@@ -354,6 +350,7 @@ def test_fusion_rmsnorm_quant(
|
||||
eps=eps,
|
||||
force_kernel=force_kernel,
|
||||
group_shape=group_shape,
|
||||
dtype=dtype,
|
||||
use_aiter_fusion=False,
|
||||
use_aiter_quant=False,
|
||||
)
|
||||
@@ -426,6 +423,7 @@ def test_aiter_fusion_rmsnorm_quant(
|
||||
eps=eps,
|
||||
force_kernel=force_kernel,
|
||||
group_shape=group_shape,
|
||||
dtype=dtype,
|
||||
use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
|
||||
use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
|
||||
)
|
||||
|
||||
@@ -66,6 +66,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.device = device
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
self.attn = Attention(
|
||||
num_heads=self.num_qo_heads,
|
||||
@@ -155,6 +156,7 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
device=self.device,
|
||||
input_dtype=self.dtype,
|
||||
)
|
||||
|
||||
w = kwargs.get("w")
|
||||
|
||||
@@ -74,6 +74,7 @@ class MLAAttentionQuantPatternModel(torch.nn.Module):
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.device = device
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
# Create kv_b_proj (ColumnParallelLinear) on device.
|
||||
# Reuse weights from prior model instance when available, because
|
||||
@@ -190,6 +191,7 @@ class TestMLAAttentionFp8StaticQuantPatternModel(MLAAttentionQuantPatternModel):
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
device=self.device,
|
||||
input_dtype=self.dtype,
|
||||
)
|
||||
|
||||
w = kwargs.get("w")
|
||||
|
||||
@@ -36,9 +36,9 @@ from vllm.model_executor.kernels.linear import (
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
create_fp8_quant_key,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
@@ -58,7 +58,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs
|
||||
self,
|
||||
hidden_size: int,
|
||||
force_kernel: FP8ScaledMMLinearKernel,
|
||||
dtype: torch.dtype,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@@ -68,6 +72,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
force_kernel=force_kernel,
|
||||
input_dtype=dtype,
|
||||
)
|
||||
|
||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||
@@ -137,14 +142,20 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, **kwargs):
|
||||
act_quant_key = kFp8Dynamic128Sym
|
||||
|
||||
def __init__(self, hidden_size: int, dtype: torch.dtype, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(128, 128),
|
||||
act_quant_group_shape=GroupShape(1, 128),
|
||||
cutlass_block_fp8_supported=False,
|
||||
use_aiter_and_is_supported=True,
|
||||
self.weight_quant_key = create_fp8_quant_key(
|
||||
static=True, group_shape=GroupShape(hidden_size, hidden_size)
|
||||
)
|
||||
|
||||
self.w8a8_block_fp8_linear = TestFP8Layer(
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
activation_quant_key=self.act_quant_key,
|
||||
input_dtype=dtype,
|
||||
)
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
|
||||
@@ -157,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
|
||||
x2 = self.w8a8_block_fp8_linear(y, self.w, self.wscale)
|
||||
return x2
|
||||
|
||||
def ops_in_model_before(self):
|
||||
@@ -324,7 +335,9 @@ def test_fusion_silu_and_mul_quant(
|
||||
|
||||
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
|
||||
backend = TestBackend(*passes)
|
||||
model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x)
|
||||
model = model_class(
|
||||
hidden_size=hidden_size, force_kernel=force_kernel, x=x, dtype=dtype
|
||||
)
|
||||
|
||||
# First dimension dynamic
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
Reference in New Issue
Block a user