[Feature]Add async tensor parallelism using compilation pass (#17882)
Signed-off-by: cascade812 <cascade812@outlook.com>
This commit is contained in:
@@ -5,9 +5,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
|
||||
find_specified_fn,
|
||||
find_specified_fn_maybe, is_func)
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
||||
PassConfig, VllmConfig)
|
||||
@@ -21,17 +19,6 @@ from vllm.utils import update_environment_variables
|
||||
from ..utils import multi_gpu_test
|
||||
from .backend import TestBackend
|
||||
|
||||
OPS_IN_MODEL_BEFORE = [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
]
|
||||
|
||||
OPS_IN_MODEL_AFTER = [
|
||||
torch.ops.vllm.reduce_scatter.default,
|
||||
torch.ops.vllm.all_gather.default,
|
||||
]
|
||||
|
||||
OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default]
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@@ -78,6 +65,18 @@ class TestModel(torch.nn.Module):
|
||||
|
||||
return norm_output, residual_output
|
||||
|
||||
def ops_in_model_before(self):
    """Return the ops expected in the graph before the pass runs.

    The test hands this list to the backend's ``check_before_ops``;
    for this model that is the single all-reduce custom op.
    """
    all_reduce = torch.ops.vllm.all_reduce.default
    return [all_reduce]
|
||||
|
||||
def ops_in_model_after(self):
    """Return the ops expected in the graph after the pass runs.

    The test hands this list to the backend's ``check_after_ops``:
    the reduce-scatter and all-gather custom ops, in that order.
    """
    reduce_scatter = torch.ops.vllm.reduce_scatter.default
    all_gather = torch.ops.vllm.all_gather.default
    return [reduce_scatter, all_gather]
|
||||
|
||||
def ops_in_model(self):
    """Return the ops the test uses for its functionalization checks.

    The test searches the post-pass graphs for each op listed here,
    both in auto-functionalized form and as a direct call.
    """
    fused_add_rms_norm = torch.ops._C.fused_add_rms_norm.default
    return [fused_add_rms_norm]
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@@ -156,26 +155,16 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
|
||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
||||
compiled_model_func(hidden_states, residual)
|
||||
|
||||
# Check substitution worked
|
||||
pre_nodes = backend_no_func.graph_pre_pass.nodes
|
||||
post_nodes = backend_no_func.graph_post_pass.nodes
|
||||
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
for op in OPS_IN_MODEL_BEFORE:
|
||||
find_specified_fn(pre_nodes, op)
|
||||
for op in OPS_IN_MODEL_AFTER:
|
||||
assert find_specified_fn_maybe(pre_nodes, op) is None
|
||||
backend_no_func.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
# In post-nodes, reduce scatter and all gather should be there,
|
||||
# all reduce should not
|
||||
for op in OPS_IN_MODEL_AFTER:
|
||||
find_specified_fn(post_nodes, op)
|
||||
for op in OPS_IN_MODEL_BEFORE:
|
||||
assert find_specified_fn_maybe(post_nodes, op) is None
|
||||
backend_no_func.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in OPS_IN_MODEL:
|
||||
for op in model.ops_in_model():
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
|
||||
op) is None # noqa: E501
|
||||
@@ -183,7 +172,7 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
for op in OPS_IN_MODEL:
|
||||
for op in model.ops_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in OPS_IN_MODEL)
|
||||
assert all(found[op] for op in model.ops_in_model())
|
||||
|
||||
Reference in New Issue
Block a user