[Feature]Add async tensor parallelism using compilation pass (#17882)

Signed-off-by: cascade812 <cascade812@outlook.com>
This commit is contained in:
cascade
2025-05-23 01:03:34 -07:00
committed by GitHub
parent 4c611348a7
commit 71ea614d4a
11 changed files with 472 additions and 56 deletions

View File

@@ -243,24 +243,25 @@ class SequenceParallelismPass(VllmInductorPass):
pass_name="sequence_parallelism_pass")
for epsilon in [1e-5, 1e-6]:
EmbeddingAllReduceRMSNormPattern(
epsilon, self.dtype, self.device).register(self.patterns)
epsilon, self.model_dtype, self.device).register(self.patterns)
MiddleAllReduceRMSNormPattern(epsilon, self.dtype,
MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
LastAllReduceRMSNormPattern(epsilon, self.dtype,
LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
# only do replace for specific shapes
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_sequence_parallelism_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_sequence_parallelism_pass")
self.end_and_log()