[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -11,7 +11,13 @@ from vllm.compilation.fusion import RMSNormQuantFusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.config import (
+    CompilationConfig,
+    ModelConfig,
+    PassConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
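The single-line vllm.config import becomes a multi-name import so the test can build a full VllmConfig (including a ModelConfig that carries the dtype) and activate it globally. A minimal sketch of how these imports fit together, assuming only the constructor arguments that appear later in this diff:

import torch

from vllm.config import (
    CompilationConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)

# Build a config whose dtype and custom-op settings the compilation passes
# and custom-op enablement checks can observe.
config = VllmConfig(
    model_config=ModelConfig(dtype=torch.bfloat16),
    compilation_config=CompilationConfig(
        custom_ops=["all"],
        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
    ),
)

# Inside this context, code that looks up the "current" vLLM config
# (for example the RMSNorm.enabled() check below) sees `config`.
with set_current_vllm_config(config):
    ...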
@@ -48,8 +54,7 @@ class TestSiluMul(torch.nn.Module):
         return y
 
     def example_inputs(self, num_tokens=32, hidden_size=128):
-        dtype = torch.float16 if TEST_FP8 else torch.float32
-        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
+        return (torch.rand(num_tokens, hidden_size * 2),)
 
     def ops_in_model(self, do_fusion):
         if TEST_FP8 and do_fusion:
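The explicit dtype= arguments disappear from example_inputs here, and in the analogous hunks for the other test modules below, because the rewritten test sets the process-wide default dtype before building inputs. A small standalone sketch of the underlying PyTorch behavior (illustrative shapes and dtypes only):

import torch

# Floating-point factory functions pick up the current default dtype.
torch.set_default_dtype(torch.float16)
assert torch.rand(32, 256).dtype == torch.float16

torch.set_default_dtype(torch.bfloat16)
assert torch.randn(32, 256).dtype == torch.bfloat16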
@@ -67,15 +72,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
 
-        dtype = torch.float16 if TEST_FP8 else torch.float32
-
         self.gate_proj = torch.nn.Parameter(
-            torch.empty((intermediate_size, hidden_size), dtype=dtype)
+            torch.empty((intermediate_size, hidden_size))
         )
         self.norm = RMSNorm(intermediate_size, 1e-05)
-        self.norm.weight = torch.nn.Parameter(
-            torch.ones(intermediate_size, dtype=dtype)
-        )
+        self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
 
         torch.nn.init.normal_(self.gate_proj, std=0.02)
 
@@ -112,9 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
         return norm_output, residual_output
 
     def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
-        dtype = torch.float16 if TEST_FP8 else torch.float32
-        hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
-        residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
+        hidden_states = torch.randn((batch_size * seq_len, hidden_size))
+        residual = torch.randn((batch_size * seq_len, hidden_size))
         return (hidden_states, residual)
 
     def ops_in_model(self, do_fusion):
@@ -145,10 +145,9 @@ class TestRotaryEmbedding(torch.nn.Module):
         return q_rotated, k_rotated
 
    def example_inputs(self, num_tokens=32, head_dim=64):
-        dtype = torch.float16
         positions = torch.arange(num_tokens, dtype=torch.long)
-        q = torch.randn(num_tokens, head_dim, dtype=dtype)
-        k = torch.randn(num_tokens, head_dim, dtype=dtype)
+        q = torch.randn(num_tokens, head_dim)
+        k = torch.randn(num_tokens, head_dim)
         return (positions, q, k)
 
     def ops_in_model(self, do_fusion):
@@ -166,7 +165,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
         self.hidden_size = head_dim * num_heads
 
         self.qkv_proj = torch.nn.Linear(
-            self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
+            self.hidden_size, self.hidden_size * 3, bias=False
         )
 
         self.rotary_emb = get_rope(
@@ -190,10 +189,9 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
         return qkv_updated
 
     def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
-        dtype = torch.float16
         hidden_size = head_dim * num_heads
         positions = torch.arange(num_tokens, dtype=torch.long)
-        hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
+        hidden_states = torch.randn(num_tokens, hidden_size)
         return (positions, hidden_states)
 
     def ops_in_model(self, do_fusion):
@@ -211,48 +209,58 @@ MODELS = [
 ]
 
 
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("model_class", MODELS)
 @pytest.mark.parametrize("do_fusion", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
-def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
+def test_fix_functionalization(
+    model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
+):
     torch.set_default_device("cuda")
+    torch.set_default_dtype(dtype)
 
-    vllm_config = VllmConfig()
-    vllm_config.compilation_config = CompilationConfig(
-        pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
+    vllm_config = VllmConfig(
+        model_config=ModelConfig(dtype=dtype),
+        compilation_config=CompilationConfig(
+            custom_ops=["all"],
+            pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
+        ),
     )
-    noop_pass = NoOpEliminationPass(vllm_config)
-    fusion_pass = RMSNormQuantFusionPass(vllm_config)
-    cleanup_pass = PostCleanupPass(vllm_config)
-    act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
-
-    passes = (
-        [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
-        if do_fusion
-        else [noop_pass, cleanup_pass]
-    )
-    func_pass = FixFunctionalizationPass(vllm_config)
-
-    backend_func = TestBackend(*passes, func_pass)
-    backend_no_func = TestBackend(*passes)
-
-    model = model_class()
-    torch.compile(model, backend=backend_func)(*model.example_inputs())
-    torch.compile(model, backend=backend_no_func)(*model.example_inputs())
-
-    # check if the functionalization pass is applied
-    for op in model.ops_in_model(do_fusion):
-        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
-
-    # make sure the ops were all de-functionalized
-    found = dict()
-    for node in backend_func.graph_post_pass.nodes:
-        for op in model.ops_in_model(do_fusion):
-            if is_func(node, op):
-                found[op] = True
-        for op in model.ops_not_in_model():
-            if is_func(node, op):
-                found[op] = True
-    assert all(found[op] for op in model.ops_in_model(do_fusion))
-    assert all(not found.get(op) for op in model.ops_not_in_model())
+    with set_current_vllm_config(vllm_config):
+        assert RMSNorm.enabled()
+        noop_pass = NoOpEliminationPass(vllm_config)
+        fusion_pass = RMSNormQuantFusionPass(vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
+        act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
+
+        passes = (
+            [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
+            if do_fusion
+            else [noop_pass, cleanup_pass]
+        )
+        func_pass = FixFunctionalizationPass(vllm_config)
+
+        backend_func = TestBackend(*passes, func_pass)
+        backend_no_func = TestBackend(*passes)
+
+        model = model_class()
+        torch.compile(model, backend=backend_func)(*model.example_inputs())
+        torch.compile(model, backend=backend_no_func)(*model.example_inputs())
+
+        # check if the functionalization pass is applied
+        for op in model.ops_in_model(do_fusion):
+            find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
+            assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
+
+        # make sure the ops were all de-functionalized
+        found = dict()
+        for node in backend_func.graph_post_pass.nodes:
+            for op in model.ops_in_model(do_fusion):
+                if is_func(node, op):
+                    found[op] = True
+            for op in model.ops_not_in_model():
+                if is_func(node, op):
+                    found[op] = True
+        assert all(found[op] for op in model.ops_in_model(do_fusion))
+        assert all(not found.get(op) for op in model.ops_not_in_model())
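Taken together, torch.set_default_device("cuda") and torch.set_default_dtype(dtype) mean the tensors built in each module's example_inputs already land on the GPU in the parametrized dtype, so no per-call device= or dtype= plumbing is needed. A hedged sketch of that pytest pattern, using a hypothetical stand-in module rather than the real MODELS entries:

import pytest
import torch


class _TinyModule(torch.nn.Module):
    # Hypothetical stand-in for the test modules above, for illustration only.
    def example_inputs(self, num_tokens: int = 8, hidden_size: int = 16):
        # No dtype/device arguments: the defaults set in the test apply here.
        return (torch.randn(num_tokens, hidden_size),)

    def forward(self, x):
        return torch.nn.functional.silu(x)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_defaults_apply(dtype: torch.dtype):
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)

    (x,) = _TinyModule().example_inputs()
    assert x.dtype == dtype and x.is_cuda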