[Feature] Add async tensor parallelism using compilation pass (#17882)

Signed-off-by: cascade812 <cascade812@outlook.com>
Author: cascade
Date: 2025-05-23 01:03:34 -07:00
Committed by: GitHub
Parent: 4c611348a7
Commit: 71ea614d4a
11 changed files with 472 additions and 56 deletions
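
For context on what the updated tests assert: the sequence-parallelism rewrite replaces a tensor-parallel all_reduce with a reduce_scatter before the norm and an all_gather after it, so each rank normalizes only its own shard and the two smaller collectives can be scheduled to overlap with neighboring compute. A minimal sketch of that pattern, written with PyTorch's functional collectives rather than vLLM's custom ops (illustrative only, not vLLM's pass; `group` is assumed to be an already-initialized process group):

import torch
import torch.distributed._functional_collectives as funcol

def before_pass(x: torch.Tensor, group) -> torch.Tensor:
    # Original tensor-parallel pattern: reduce the full tensor, then norm it.
    y = funcol.all_reduce(x, reduceOp="sum", group=group)
    return torch.nn.functional.rms_norm(y, y.shape[-1:])

def after_pass(x: torch.Tensor, group) -> torch.Tensor:
    # Rewritten pattern: reduce-scatter, norm the local shard, gather back.
    y = funcol.reduce_scatter_tensor(x, reduceOp="sum", scatter_dim=0, group=group)
    y = torch.nn.functional.rms_norm(y, y.shape[-1:])
    return funcol.all_gather_tensor(y, gather_dim=0, group=group)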


@@ -5,9 +5,7 @@ import torch
 import vllm.envs as envs
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
-                                       find_specified_fn,
-                                       find_specified_fn_maybe, is_func)
+from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.sequence_parallelism import SequenceParallelismPass
 from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                          PassConfig, VllmConfig)
 
@@ -21,17 +19,6 @@ from vllm.utils import update_environment_variables
 from ..utils import multi_gpu_test
 from .backend import TestBackend
 
-OPS_IN_MODEL_BEFORE = [
-    torch.ops.vllm.all_reduce.default,
-]
-
-OPS_IN_MODEL_AFTER = [
-    torch.ops.vllm.reduce_scatter.default,
-    torch.ops.vllm.all_gather.default,
-]
-
-OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default]
-
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -78,6 +65,18 @@ class TestModel(torch.nn.Module):
         return norm_output, residual_output
 
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_reduce.default]
+
+    def ops_in_model_after(self):
+        return [
+            torch.ops.vllm.reduce_scatter.default,
+            torch.ops.vllm.all_gather.default
+        ]
+
+    def ops_in_model(self):
+        return [torch.ops._C.fused_add_rms_norm.default]
+
 
 
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("batch_size", [8])
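
The hunk above moves the expected-op lists from module-level constants onto the test model itself, and (next hunk) the per-op graph scanning moves behind two new TestBackend helpers. backend.py is among the 11 changed files but its hunk is not shown here, so the sketch below is an inference reconstructed from the loops this commit deletes, not the verbatim implementation:

from vllm.compilation.fx_utils import (find_specified_fn,
                                       find_specified_fn_maybe)

class TestBackend:  # sketch: only the two inferred check helpers
    def check_before_ops(self, ops):
        # Each op must be in the graph before the pass and gone after it.
        for op in ops:
            find_specified_fn(self.graph_pre_pass.nodes, op)
            assert find_specified_fn_maybe(self.graph_post_pass.nodes,
                                           op) is None

    def check_after_ops(self, ops):
        # Each op must be absent before the pass and present after it.
        for op in ops:
            assert find_specified_fn_maybe(self.graph_pre_pass.nodes,
                                           op) is None
            find_specified_fn(self.graph_post_pass.nodes, op)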
@@ -156,26 +155,16 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
     compiled_model_func = torch.compile(model, backend=backend_func)
     compiled_model_func(hidden_states, residual)
 
     # Check substitution worked
-    pre_nodes = backend_no_func.graph_pre_pass.nodes
-    post_nodes = backend_no_func.graph_post_pass.nodes
-    # In pre-nodes, all reduce should be there,
-    # reduce scatter and all gather should not
-    for op in OPS_IN_MODEL_BEFORE:
-        find_specified_fn(pre_nodes, op)
-    for op in OPS_IN_MODEL_AFTER:
-        assert find_specified_fn_maybe(pre_nodes, op) is None
+    backend_no_func.check_before_ops(model.ops_in_model_before())
 
     # In post-nodes, reduce scatter and all gather should be there,
     # all reduce should not
-    for op in OPS_IN_MODEL_AFTER:
-        find_specified_fn(post_nodes, op)
-    for op in OPS_IN_MODEL_BEFORE:
-        assert find_specified_fn_maybe(post_nodes, op) is None
+    backend_no_func.check_after_ops(model.ops_in_model_after())
 
     # check if the functionalization pass is applied
-    for op in OPS_IN_MODEL:
+    for op in model.ops_in_model():
         find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
         assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
                                   op) is None  # noqa: E501
 
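
On the find_auto_fn assertions just above: until FixFunctionalizationPass runs, a mutating custom op such as fused_add_rms_norm appears in the traced graph wrapped in PyTorch's auto_functionalized higher-order op, and find_auto_fn locates that wrapper. Roughly (a sketch of the check, assuming PyTorch's usual import path for the HOP):

from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch.fx import Node

def is_auto_fn(node: Node, op) -> bool:
    # The functionalized form of `op` is a call_function node targeting the
    # auto_functionalized HOP with `op` as its first argument.
    return (node.op == "call_function"
            and node.target == auto_functionalized
            and node.args[0] == op)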
@@ -183,7 +172,7 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
     # make sure the ops were all de-functionalized
     found = dict()
     for node in backend_func.graph_post_pass.nodes:
-        for op in OPS_IN_MODEL:
+        for op in model.ops_in_model():
             if is_func(node, op):
                 found[op] = True
-    assert all(found[op] for op in OPS_IN_MODEL)
+    assert all(found[op] for op in model.ops_in_model())
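
Finally, end-to-end usage: the async-TP pass is enabled through the compilation pass config. The flag name below (enable_async_tp) and the model are assumptions, since the config hunk is not part of the excerpt shown above:

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical; any TP model
    tensor_parallel_size=2,
    # enable_async_tp is assumed from this PR's title; verify against the
    # PassConfig changes in the full diff.
    compilation_config={"pass_config": {"enable_async_tp": True}},
)
outputs = llm.generate(["Hello, my name is"])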