[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Luka Govedič
2025-10-17 10:10:23 -04:00
committed by GitHub
parent be429d0cfd
commit bd7157a071
28 changed files with 1519 additions and 721 deletions

View File

@@ -18,6 +18,8 @@ from vllm.config import (
ModelConfig,
PassConfig,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
@@ -42,9 +44,7 @@ prompts = [
class TestModel(torch.nn.Module):
def __init__(
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
@@ -95,13 +95,11 @@ class TestModel(torch.nn.Module):
class TestQuantModel(torch.nn.Module):
def __init__(
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.vllm_config = vllm_config
self.vllm_config = get_current_vllm_config()
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)), requires_grad=False
)
@@ -266,76 +264,84 @@ def sequence_parallelism_pass_on_test_model(
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
compilation_config = CompilationConfig(
pass_config=PassConfig(
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True,
)
) # NoOp needed for fusion
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
vllm_config.model_config = ModelConfig(
model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
vllm_config = VllmConfig(
model_config=model_config,
device_config=device_config,
compilation_config=compilation_config,
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
with set_current_vllm_config(vllm_config):
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
passes_for_backend: list[VllmInductorPass] = [
noop_pass,
sequence_parallelism_pass,
]
if enable_fusion:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)
if enable_fusion:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)
passes_for_backend.append(cleanup_pass)
passes_for_backend.append(cleanup_pass)
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
model = test_model_cls(hidden_size, hidden_size * 2)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
assert sequence_parallelism_pass.matched_count == 1
assert sequence_parallelism_pass.matched_count == 1
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
backend_no_func.check_after_ops(model.ops_in_model_after())
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
backend_no_func.check_after_ops(model.ops_in_model_after())
# check if the functionalization pass is applied
for op in model.ops_in_model():
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
# check if the functionalization pass is applied
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())