[ROCm][FEAT] Support AITER RMSNorm quantization fusion pass (#26575)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
vllmellm
2025-12-23 11:07:54 +01:00
committed by GitHub
parent 6b16fff01b
commit f32cfd7d97
5 changed files with 662 additions and 218 deletions

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest
import torch
@@ -53,37 +52,61 @@ class TestModel(torch.nn.Module):
hidden_size: int,
eps: float,
group_shape: GroupShape,
cuda_force_torch: bool,
use_aiter: bool = False,
cuda_force_torch: bool = False,
use_aiter_quant_op: bool = True,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.use_aiter = use_aiter
self.use_aiter_quant_op = use_aiter_quant_op
self.cuda_force_torch = cuda_force_torch
self.group_shape = group_shape
self.enable_quant_fp8_custom_op = None # Will be set later if applicable
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
if group_shape.is_per_group():
self.wscale = [
torch.rand(
(hidden_size // group_shape[1], hidden_size // group_shape[1]),
dtype=torch.float32,
)
for _ in range(3)
]
else:
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
static = group_shape == GroupShape.PER_TENSOR
# Setup quantization scale descriptor
static = group_shape == GroupShape.PER_TENSOR and not use_aiter
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
# Setup scales
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else:
self.scale = [None for _ in range(3)]
# Setup weights
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
]
if not group_shape.is_per_group():
if not group_shape.is_per_group() or use_aiter:
self.w = [self.w[0].t() for _ in range(3)]
# Setup weight scales
if group_shape.is_per_group():
scale_size = (
(hidden_size + 128 - 1) // 128
if use_aiter
else hidden_size // group_shape[1]
)
wscale_shape: tuple[int, ...] = (scale_size, scale_size)
else:
wscale_shape = (1,)
self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]
# Setup FP8 linear operation
is_per_group = group_shape.is_per_group()
if is_per_group and use_aiter:
self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=group_shape,
use_aiter_and_is_supported=use_aiter_quant_op,
)
# AITER blockwise doesn't use enable_quant_fp8_custom_op
elif is_per_group:
self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape,
@@ -91,6 +114,13 @@ class TestModel(torch.nn.Module):
use_aiter_and_is_supported=False,
)
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
elif use_aiter:
self.fp8_linear = Fp8LinearOp(
act_quant_static=False,
act_quant_group_shape=group_shape,
)
self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
else:
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
@@ -100,7 +130,6 @@ class TestModel(torch.nn.Module):
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.group_shape = group_shape
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
@@ -126,19 +155,49 @@ class TestModel(torch.nn.Module):
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_before(self):
if (
self.use_aiter
and self.group_shape.is_per_group()
and current_platform.is_fp8_fnuz()
):
return [rocm_aiter_ops.get_group_quant_op()]
if self.use_aiter and self.group_shape.is_per_group():
return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
if self.use_aiter and self.use_aiter_quant_op:
return [rocm_aiter_ops.get_per_token_quant_op()]
if self.use_aiter:
return [QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op:
return [QUANT_OPS[self.quant_key]]
return [torch.ops.aten.reciprocal]
def ops_in_model_after(self):
if self.use_aiter and self.group_shape.is_per_group():
from vllm.compilation.rocm_aiter_fusion import (
AiterFusedAddRMSFp8GroupQuantPattern,
AiterRMSFp8GroupQuantPattern,
)
return [
AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
AiterRMSFp8GroupQuantPattern.FUSED_OP,
]
if self.use_aiter:
from vllm.compilation.rocm_aiter_fusion import (
AiterFusedAddRMSNormDynamicQuantPattern,
AiterRMSNormDynamicQuantPattern,
)
return [
AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
AiterRMSNormDynamicQuantPattern.FUSED_OP,
]
return [
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
]
def ops_in_model_before(self):
return (
[QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal]
)
def ops_in_model_before_partial(self):
return (
[RMS_OP, RMS_ADD_OP]
@@ -155,67 +214,45 @@ GROUP_SHAPES = [
]
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, **kwargs):
super().__init__()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(3)
]
def _run_fusion_test(
model,
fusion_pass,
vllm_config,
dtype,
hidden_size,
num_tokens,
):
"""Helper function for common fusion test logic.
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = [
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
for _ in range(3)
]
Must be called within vllm_config context.
"""
noop_pass = NoOpEliminationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
self.eps = eps
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0])
# make sure resid is used for replacement to work
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
x2, resid, self.norm_weight[1], self.eps
)
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1])
model_unfused = torch.compile(model, backend=backend2)
result_unfused = model_unfused(x)
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
x3, resid, self.norm_weight[2], self.eps
)
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2])
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
x4, resid, self.norm_weight[3], self.eps
)
return y4
assert fusion_pass.matched_count == 3
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())
def ops_in_model_before(self):
return [
torch.ops.vllm.rocm_aiter_rms_norm,
torch.ops.vllm.rocm_aiter_group_fp8_quant,
]
def ops_in_model_before_partial(self):
return []
def ops_in_model_after(self):
return [
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
]
return backend, backend2
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -223,11 +260,8 @@ class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize(
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
list(itertools.product([TestModel], [True, False], [True, False]))
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
@@ -242,23 +276,13 @@ def test_fusion_rmsnorm_quant(
num_tokens,
eps,
group_shape,
model_class,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
):
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
# Skip test for 64-bit group shape when running with cutlass or deepgemm
if group_shape == GroupShape(1, 64) and (
cutlass_block_fp8_supported() or is_deep_gemm_supported()
):
@@ -269,6 +293,7 @@ def test_fusion_rmsnorm_quant(
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
@@ -279,60 +304,97 @@ def test_fusion_rmsnorm_quant(
),
),
)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
if model_class is TestRmsnormGroupFp8QuantModel:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
)
# Setup device before model creation
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity()
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
else:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = model_class(
fusion_pass = RMSNormQuantFusionPass(vllm_config)
model = TestModel(
hidden_size=hidden_size,
eps=eps,
group_shape=group_shape,
use_aiter=False,
cuda_force_torch=cuda_force_torch,
)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
model_unfused = torch.compile(model, backend=backend2)
result_unfused = model_unfused(x)
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 3
backend.check_before_ops(model.ops_in_model_before())
backend, _ = _run_fusion_test(
model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
)
backend.check_before_ops(
model.ops_in_model_before_partial(), fully_replaced=False
)
backend.check_after_ops(model.ops_in_model_after())
# If RMSNorm custom op is disabled (native/torch impl used),
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if (
not enable_rms_norm_custom_op
and model_class is not TestRmsnormGroupFp8QuantModel
):
if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7
assert n_add_nodes(backend.graph_post_pass) == 2
GROUP_SHAPE_QUANT_OPS_MATCHS = [
(GroupShape.PER_TOKEN, True),
(GroupShape.PER_TOKEN, False),
(GroupShape(1, 128), True),
]
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
"group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS
)
@pytest.mark.skipif(
(not current_platform.is_rocm() or not IS_AITER_FOUND),
reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
dtype: torch.dtype,
hidden_size: int,
num_tokens: int,
eps: float,
group_shape: GroupShape,
use_aiter_quant_op: bool,
monkeypatch: pytest.MonkeyPatch,
):
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm", "+quant_fp8"],
pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
),
)
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass
m.setenv("VLLM_ROCM_USE_AITER", "1")
rocm_aiter_ops.refresh_env_variables()
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity()
fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
model = TestModel(
hidden_size=hidden_size,
eps=eps,
group_shape=group_shape,
use_aiter=True,
use_aiter_quant_op=use_aiter_quant_op,
)
_run_fusion_test(
model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
)