[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Luka Govedič
2025-10-17 10:10:23 -04:00
committed by GitHub
parent be429d0cfd
commit bd7157a071
28 changed files with 1519 additions and 721 deletions

View File

@@ -11,7 +11,13 @@ from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.config import (
CompilationConfig,
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
@@ -48,8 +54,7 @@ class TestSiluMul(torch.nn.Module):
return y
def example_inputs(self, num_tokens=32, hidden_size=128):
dtype = torch.float16 if TEST_FP8 else torch.float32
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
return (torch.rand(num_tokens, hidden_size * 2),)
def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion:
@@ -67,15 +72,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
dtype = torch.float16 if TEST_FP8 else torch.float32
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size), dtype=dtype)
torch.empty((intermediate_size, hidden_size))
)
self.norm = RMSNorm(intermediate_size, 1e-05)
self.norm.weight = torch.nn.Parameter(
torch.ones(intermediate_size, dtype=dtype)
)
self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
torch.nn.init.normal_(self.gate_proj, std=0.02)
@@ -112,9 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
return norm_output, residual_output
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
dtype = torch.float16 if TEST_FP8 else torch.float32
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
hidden_states = torch.randn((batch_size * seq_len, hidden_size))
residual = torch.randn((batch_size * seq_len, hidden_size))
return (hidden_states, residual)
def ops_in_model(self, do_fusion):
@@ -145,10 +145,9 @@ class TestRotaryEmbedding(torch.nn.Module):
return q_rotated, k_rotated
def example_inputs(self, num_tokens=32, head_dim=64):
dtype = torch.float16
positions = torch.arange(num_tokens, dtype=torch.long)
q = torch.randn(num_tokens, head_dim, dtype=dtype)
k = torch.randn(num_tokens, head_dim, dtype=dtype)
q = torch.randn(num_tokens, head_dim)
k = torch.randn(num_tokens, head_dim)
return (positions, q, k)
def ops_in_model(self, do_fusion):
@@ -166,7 +165,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
self.hidden_size = head_dim * num_heads
self.qkv_proj = torch.nn.Linear(
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
self.hidden_size, self.hidden_size * 3, bias=False
)
self.rotary_emb = get_rope(
@@ -190,10 +189,9 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
return qkv_updated
def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
dtype = torch.float16
hidden_size = head_dim * num_heads
positions = torch.arange(num_tokens, dtype=torch.long)
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
hidden_states = torch.randn(num_tokens, hidden_size)
return (positions, hidden_states)
def ops_in_model(self, do_fusion):
@@ -211,48 +209,58 @@ MODELS = [
]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
def test_fix_functionalization(
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
custom_ops=["all"],
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
),
)
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
with set_current_vllm_config(vllm_config):
assert RMSNorm.enabled()
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
if is_func(node, op):
found[op] = True
for op in model.ops_not_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model(do_fusion):
if is_func(node, op):
found[op] = True
for op in model.ops_not_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())