Use aiter triton fused_add_rmsnorm_pad for gpt-oss (#30976)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
131
tests/compile/test_fuse_act_padding.py
Normal file
131
tests/compile/test_fuse_act_padding.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.config
|
||||
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
    """Toy model mimicking gpt-oss layer structure for the fusion test:
    per layer, fused-add RMSNorm -> router GEMM -> constant pad, so that
    RocmAiterTritonAddRMSNormPadFusionPass has a pattern to match."""

    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        num_local_experts: int,
        x_pad_to_multiple: int,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.x_pad_to_multiple = x_pad_to_multiple
        # Padding appended to the hidden dim each layer.
        # NOTE(review): when hidden_size is an exact multiple of
        # x_pad_to_multiple this pads a full extra multiple instead of 0;
        # fine for the tested 2880/256 combination — confirm if new shapes
        # are parametrized.
        self.pad_dim = x_pad_to_multiple - (hidden_size % x_pad_to_multiple)

        self.norm = [RMSNorm(hidden_size, eps=1e-5) for _ in range(num_layers)]
        # Bug fix: was hard-coded `range(4)`. forward() indexes
        # self.router[layer] for layer in range(num_layers), which would
        # raise IndexError whenever num_layers > 4.
        self.router = [
            torch.nn.Linear(hidden_size, num_local_experts)
            for _ in range(num_layers)
        ]

    def forward(self, x):
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        all_router_logits = []
        for layer in range(self.num_layers):
            # Undo the previous iteration's padding before the norm.
            x = x[:, : self.hidden_size]
            x, resid = self.norm[layer](x, resid)
            # Router GEMM consumes the unpadded rmsnorm output.
            router_logits = rocm_unquantized_gemm(
                self, x, self.router[layer].weight, self.router[layer].bias
            )
            # Main path is zero-padded up to the multiple.
            x = torch.nn.functional.pad(
                x, (0, self.pad_dim), mode="constant", value=0.0
            )
            all_router_logits.append(router_logits)

        return x, resid, *all_router_logits

    def ops_in_model_before(self):
        # Ops expected in the graph before fusion: the AITER fused-add
        # rmsnorm op and the aten pad produced by F.pad.
        return [
            rocm_aiter_ops.get_rmsnorm_fused_add_op(),
            torch.ops.aten.constant_pad_nd,
        ]

    def ops_in_model_after(self):
        # After fusion only the single fused triton op should remain.
        return [rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_layers", [3])
@pytest.mark.parametrize("hidden_size", [2880])
@pytest.mark.parametrize("num_local_experts", [128])
@pytest.mark.parametrize("x_pad_to_multiple", [256])
@pytest.mark.skipif(
    not is_aiter_found_and_supported(),
    reason="Only test on ROCm with AITER installed and supported",
)
def test_fuse_act_padding(
    dtype: torch.dtype,
    num_layers: int,
    hidden_size: int,
    num_local_experts: int,
    x_pad_to_multiple: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """End-to-end check that the add-rmsnorm-pad fusion pass rewrites the
    fused-add RMSNorm + pad pattern into the single AITER triton op without
    changing numerics."""
    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            # Keep the custom rms_norm op un-decomposed so the pattern can
            # match the AITER fused-add rmsnorm op in the traced graph.
            custom_ops=["+rms_norm"],
            pass_config=PassConfig(fuse_act_padding=True, eliminate_noops=True),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
        # Deferred import — NOTE(review): presumably the pass module depends
        # on config/AITER state established above; confirm it cannot be a
        # top-level import.
        from vllm.compilation.rocm_aiter_fusion import (
            RocmAiterTritonAddRMSNormPadFusionPass,
        )

        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        # Force-enable AITER and refresh cached env-derived flags.
        m.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()

        fusion_pass = RocmAiterTritonAddRMSNormPadFusionPass(vllm_config)
        passes = [
            NoOpEliminationPass(vllm_config),
            fusion_pass,
            PostCleanupPass(vllm_config),
        ]
        backend = TestBackend(*passes)
        model = TestModel(num_layers, hidden_size, num_local_experts, x_pad_to_multiple)

        x = torch.rand(1, hidden_size)
        # Mark the token dim dynamic so the match is shape-agnostic.
        torch._dynamo.mark_dynamic(x, 0)

        # Eager reference outputs (unfused path).
        outputs_unfused = model(x)

        model_fused = torch.compile(model, backend=backend)
        outputs_fused = model_fused(x)

        # Fusion must be numerically equivalent to the eager path.
        torch.testing.assert_close(outputs_unfused, outputs_fused)

        # Exactly one fusion per layer is expected.
        assert fusion_pass.matched_count == num_layers

        # Unfused ops present before the pass, only the fused op after.
        backend.check_before_ops(model.ops_in_model_before())
        backend.check_after_ops(model.ops_in_model_after())
|
||||
@@ -410,7 +410,7 @@ def test_aiter_fusion_rmsnorm_quant(
|
||||
)
|
||||
|
||||
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass
|
||||
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormQuantFusionPass
|
||||
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
@@ -420,7 +420,7 @@ def test_aiter_fusion_rmsnorm_quant(
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
|
||||
fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
|
||||
fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)
|
||||
|
||||
model = TestModel(
|
||||
hidden_size=hidden_size,
|
||||
|
||||
@@ -790,6 +790,41 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
|
||||
return x_fp8, out_bs
|
||||
|
||||
|
||||
def _rocm_aiter_triton_add_rmsnorm_pad_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    residual: torch.Tensor,
    x_pad_to_multiple: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation: fused residual-add + RMSNorm whose output is
    right-padded along the hidden dim to a multiple of ``x_pad_to_multiple``.

    Delegates to the AITER Triton kernel and returns
    ``(padded_normed_output, updated_residual)``.
    """
    # Imported lazily so this module stays importable without AITER.
    from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad

    kernel = fused_add_rmsnorm_pad
    result = kernel(
        x,
        weight,
        variance_epsilon,
        residual,
        x_pad_to_multiple=x_pad_to_multiple,
    )
    return result
|
||||
|
||||
|
||||
def _rocm_aiter_triton_add_rmsnorm_pad_fake(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
residual: torch.Tensor,
|
||||
x_pad_to_multiple: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
M, N = x.shape
|
||||
if x_pad_to_multiple > 0:
|
||||
N_out = (N + x_pad_to_multiple - 1) // x_pad_to_multiple * x_pad_to_multiple
|
||||
else:
|
||||
N_out = N
|
||||
out = torch.empty((M, N_out), dtype=x.dtype, device=x.device)
|
||||
residual_out = torch.empty_like(residual)
|
||||
return out, residual_out
|
||||
|
||||
|
||||
# Global flag to ensure ops are registered only once
|
||||
_OPS_REGISTERED = False
|
||||
|
||||
@@ -1108,6 +1143,13 @@ class rocm_aiter_ops:
|
||||
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_triton_add_rmsnorm_pad",
|
||||
op_func=_rocm_aiter_triton_add_rmsnorm_pad_impl,
|
||||
fake_impl=_rocm_aiter_triton_add_rmsnorm_pad_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_group_fp8_quant",
|
||||
op_func=_rocm_aiter_group_fp8_quant_impl,
|
||||
@@ -1175,6 +1217,10 @@ class rocm_aiter_ops:
|
||||
def get_act_mul_fused_fp8_group_quant_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
|
||||
|
||||
@staticmethod
|
||||
def get_triton_add_rmsnorm_pad_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad.default
|
||||
|
||||
@staticmethod
|
||||
def rms_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
|
||||
@@ -18,8 +18,9 @@ from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormFusionPass,
|
||||
RocmAiterRMSNormQuantFusionPass,
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
RocmAiterTritonAddRMSNormPadFusionPass,
|
||||
)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
@@ -123,13 +124,16 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
||||
self.passes += [RMSNormQuantFusionPass(config)]
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
self.passes += [
|
||||
RocmAiterRMSNormFusionPass(config),
|
||||
RocmAiterRMSNormQuantFusionPass(config),
|
||||
]
|
||||
if self.pass_config.fuse_act_quant:
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
|
||||
self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_attn_quant:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
|
||||
@@ -266,7 +266,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
)
|
||||
|
||||
|
||||
class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
|
||||
class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
|
||||
into a fused rms_norm_quant op.
|
||||
@@ -399,3 +399,106 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
AiterSiluMulFp8GroupQuantPattern,
|
||||
]
|
||||
return VllmInductorPass.hash_source(self, *fusion_patterns)
|
||||
|
||||
|
||||
class AddAiterRMSNormPadPattern:
    """
    This pattern replaces an aiter_rmsnorm_with_add & a pad op
    with a custom triton_add_rmsnorm_pad op from AITER.
    """

    # Resolved once at class-definition time.
    # NOTE(review): assumes the custom op is already registered when this
    # module is imported — confirm registration ordering.
    AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()

    def __init__(
        self,
        epsilon: float,
        hidden_size: int,
        x_pad_to_multiple: int,
    ):
        # RMSNorm epsilon this pattern instance is specialized for.
        self.epsilon = epsilon
        self.hidden_size = hidden_size
        self.x_pad_to_multiple = x_pad_to_multiple
        # Emits the AITER fused-add RMSNorm call inside pattern()/tracing.
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example tensors the pattern matcher traces the pattern with."""
        input, weight, residual = self.rmsnorm_matcher.inputs()
        # Small dummy router weights; only graph topology matters here.
        router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device)
        router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device)
        return [input, weight, residual, router_weight, router_bias]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern -> replacement rewrite on `pm_pass`."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # NOTE(review): when hidden_size is an exact multiple of
            # x_pad_to_multiple this yields a full-multiple pad rather than
            # 0 — matches the 2880 shapes this pass registers for.
            pad_size = self.x_pad_to_multiple - (
                self.hidden_size % self.x_pad_to_multiple
            )
            result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
            # Router GEMM consumes the *unpadded* rmsnorm output...
            router_logits = torch.ops.vllm.rocm_unquantized_gemm(
                result_rms, router_weight, router_bias
            )
            # ...while the main path is zero-padded up to the multiple.
            result = torch.nn.functional.pad(
                result_rms, (0, pad_size), mode="constant", value=0.0
            )
            return result, residual_out, router_logits

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Single fused op produces the padded output and new residual.
            at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                residual=residual,
                x_pad_to_multiple=self.x_pad_to_multiple,
            )
            result_padded = at[0]
            # Router GEMM still sees only the first hidden_size columns.
            router_logits = torch.ops.vllm.rocm_unquantized_gemm(
                result_padded[:, : self.hidden_size], router_weight, router_bias
            )
            residual_out = at[1]
            return result_padded, residual_out, router_logits

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
    """
    This pass replaces an AITER CK RMSNorm + residual add and a pad op
    with an triton_add_rmsnorm_pad op from AITER.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
        )

        # gpt-oss has hidden size 2880
        # padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
        hidden_size = 2880
        # Register one pattern per (epsilon, pad-multiple) combination so
        # whichever variant appears in the compiled graph can match.
        for epsilon in [1e-5, 1e-6]:
            for x_pad_to_multiple in [128, 256]:
                AddAiterRMSNormPadPattern(
                    epsilon, hidden_size, x_pad_to_multiple
                ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        # Apply all registered rewrites; matched_count is asserted by tests.
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        # Hash this pass plus the pattern class source so the inductor
        # cache invalidates when either is edited.
        return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)
|
||||
|
||||
@@ -126,6 +126,10 @@ class PassConfig:
|
||||
fuse_allreduce_rms: bool = Field(default=None)
|
||||
"""Enable flashinfer allreduce fusion."""
|
||||
|
||||
# ROCm/AITER specific fusions
|
||||
fuse_act_padding: bool = Field(default=None)
|
||||
"""Fuse the custom RMSNorm + padding ops."""
|
||||
|
||||
fi_allreduce_fusion_max_size_mb: float | None = None
|
||||
"""The threshold of the communicated tensor sizes under which
|
||||
vllm should use flashinfer fused allreduce. Specified as a
|
||||
@@ -194,6 +198,7 @@ class PassConfig:
|
||||
"enable_sp",
|
||||
"fuse_gemm_comms",
|
||||
"fuse_allreduce_rms",
|
||||
"fuse_act_padding",
|
||||
mode="wrap",
|
||||
)
|
||||
@classmethod
|
||||
@@ -222,12 +227,23 @@ class PassConfig:
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"Allreduce + rms norm + quant (fp8) fusion might not work"
|
||||
)
|
||||
if self.fuse_act_padding:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"RMSNorm + padding fusion might not work"
|
||||
)
|
||||
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
|
||||
logger.warning_once(
|
||||
"QK Norm + RoPE fusion enabled but the current platform is not "
|
||||
"CUDA or ROCm. The fusion will be disabled."
|
||||
)
|
||||
self.enable_qk_norm_rope_fusion = False
|
||||
if self.fuse_act_padding and not current_platform.is_rocm():
|
||||
logger.warning_once(
|
||||
"Padding fusion enabled but the current platform is not ROCm. "
|
||||
"The fusion will be disabled."
|
||||
)
|
||||
self.fuse_act_padding = False
|
||||
|
||||
|
||||
class DynamicShapesType(str, enum.Enum):
|
||||
|
||||
@@ -102,6 +102,18 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool:
|
||||
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
|
||||
|
||||
|
||||
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
    """Enable if using AITER RMSNorm and AITER Triton GEMMs
    and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""

    # Guard-clause form of the original and-chain: all AITER knobs must be
    # on, and the model must be gpt-oss-shaped (hidden size 2880).
    if not envs.VLLM_ROCM_USE_AITER:
        return False
    if not envs.VLLM_ROCM_USE_AITER_RMSNORM:
        return False
    if not envs.VLLM_ROCM_USE_AITER_TRITON_GEMM:
        return False
    return cfg.model_config.get_hidden_size() == 2880
|
||||
|
||||
|
||||
OPTIMIZATION_LEVEL_00 = {
|
||||
"compilation_config": {
|
||||
"pass_config": {
|
||||
@@ -112,6 +124,7 @@ OPTIMIZATION_LEVEL_00 = {
|
||||
"fuse_attn_quant": False,
|
||||
"enable_sp": False,
|
||||
"fuse_gemm_comms": False,
|
||||
"fuse_act_padding": False,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.NONE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@@ -127,6 +140,7 @@ OPTIMIZATION_LEVEL_01 = {
|
||||
"fuse_attn_quant": False,
|
||||
"enable_sp": False,
|
||||
"fuse_gemm_comms": False,
|
||||
"fuse_act_padding": enable_norm_pad_fusion,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@@ -142,6 +156,7 @@ OPTIMIZATION_LEVEL_02 = {
|
||||
"fuse_attn_quant": IS_QUANTIZED,
|
||||
"enable_sp": IS_DENSE,
|
||||
"fuse_gemm_comms": IS_DENSE,
|
||||
"fuse_act_padding": enable_norm_pad_fusion,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@@ -157,6 +172,7 @@ OPTIMIZATION_LEVEL_03 = {
|
||||
"fuse_attn_quant": IS_QUANTIZED,
|
||||
"enable_sp": IS_DENSE,
|
||||
"fuse_gemm_comms": IS_DENSE,
|
||||
"fuse_act_padding": enable_norm_pad_fusion,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
|
||||
@@ -137,6 +137,11 @@ def rocm_unquantized_gemm_impl(
|
||||
|
||||
import math
|
||||
|
||||
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||
|
||||
return gemm_a16w16(x, weight, bias)
|
||||
|
||||
use_skinny_reduce_counting = (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_gfx950()
|
||||
@@ -155,11 +160,6 @@ def rocm_unquantized_gemm_impl(
|
||||
out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
|
||||
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||
|
||||
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||
|
||||
return gemm_a16w16(x, weight, bias)
|
||||
|
||||
use_skinny = (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_gfx9()
|
||||
|
||||
@@ -187,7 +187,7 @@ class MLPBlock(torch.nn.Module):
|
||||
)
|
||||
else:
|
||||
g = self.router(x)
|
||||
x = self.experts(hidden_states=x, router_logits=g)
|
||||
x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
|
||||
|
||||
Reference in New Issue
Block a user