[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,8 @@ from vllm.config import (
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
@@ -42,9 +44,7 @@ prompts = [
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
|
||||
):
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
@@ -95,13 +95,11 @@ class TestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestQuantModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
|
||||
):
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.vllm_config = vllm_config
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size)), requires_grad=False
|
||||
)
|
||||
@@ -266,76 +264,84 @@ def sequence_parallelism_pass_on_test_model(
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(
|
||||
enable_sequence_parallelism=True,
|
||||
enable_fusion=enable_fusion,
|
||||
enable_noop=True,
|
||||
)
|
||||
) # NoOp needed for fusion
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
|
||||
vllm_config.model_config = ModelConfig(
|
||||
model_config = ModelConfig(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
device_config=device_config,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
|
||||
with set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
passes_for_backend: list[VllmInductorPass] = [
|
||||
noop_pass,
|
||||
sequence_parallelism_pass,
|
||||
]
|
||||
|
||||
if enable_fusion:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
if enable_fusion:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
|
||||
passes_for_backend.append(cleanup_pass)
|
||||
passes_for_backend.append(cleanup_pass)
|
||||
|
||||
backend_no_func = TestBackend(*passes_for_backend)
|
||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||
backend_no_func = TestBackend(*passes_for_backend)
|
||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||
|
||||
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
|
||||
model = test_model_cls(hidden_size, hidden_size * 2)
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
|
||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||
compiled_model_no_func(hidden_states, residual)
|
||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
||||
compiled_model_func(hidden_states, residual)
|
||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||
compiled_model_no_func(hidden_states, residual)
|
||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
||||
compiled_model_func(hidden_states, residual)
|
||||
|
||||
assert sequence_parallelism_pass.matched_count == 1
|
||||
assert sequence_parallelism_pass.matched_count == 1
|
||||
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
backend_no_func.check_before_ops(model.ops_in_model_before())
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
backend_no_func.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
# In post-nodes, reduce scatter and all gather should be there,
|
||||
# all reduce should not
|
||||
backend_no_func.check_after_ops(model.ops_in_model_after())
|
||||
# In post-nodes, reduce scatter and all gather should be there,
|
||||
# all reduce should not
|
||||
backend_no_func.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model():
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model())
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
for op in model.ops_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model())
|
||||
|
||||
Reference in New Issue
Block a user