[Feature] Add async tensor parallelism using compilation pass (#17882)
Signed-off-by: cascade812 <cascade812@outlook.com>
@@ -243,24 +243,25 @@ class SequenceParallelismPass(VllmInductorPass):
             pass_name="sequence_parallelism_pass")
         for epsilon in [1e-5, 1e-6]:
             EmbeddingAllReduceRMSNormPattern(
-                epsilon, self.dtype, self.device).register(self.patterns)
+                epsilon, self.model_dtype, self.device).register(self.patterns)
 
-            MiddleAllReduceRMSNormPattern(epsilon, self.dtype,
+            MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                           self.device).register(self.patterns)
 
-            LastAllReduceRMSNormPattern(epsilon, self.dtype,
+            LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                         self.device).register(self.patterns)
         # WARNING: This is a hack to clear the pattern matcher cache
         # and allow multiple values of epsilon.
         torch._inductor.pattern_matcher._seen_patterns.clear()
 
     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
         # only do replace for specific shapes
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
 
     def __call__(self, graph: fx.Graph):
         self.begin()
         self.dump_graph(graph, "before_sequence_parallelism_pass")
         count = self.patterns.apply(graph)
         logger.debug("Replaced %s patterns", count)
         self.dump_graph(graph, "after_sequence_parallelism_pass")
         self.end_and_log()
 
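Note on what the patterns do: the three *AllReduceRMSNormPattern rewrites replace an all_reduce followed by RMSNorm with reduce_scatter -> RMSNorm -> all_gather, so each tensor-parallel rank normalizes only its shard of the token dimension; splitting the collective is what gives the async-TP compilation pass something to overlap with surrounding compute. A minimal sketch of the rewrite using plain torch.distributed collectives (illustrative only, not the pass's actual replacement graph; assumes a process group of tp_size ranks is already initialized):

import torch
import torch.distributed as dist

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Reference RMSNorm; eps corresponds to the epsilon values registered above.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

# Before the pass (per rank): y = rms_norm(all_reduce(x), w, eps)
# After the pass (per rank), with the token dimension sharded across tp_size ranks:
def sequence_parallel_norm(x: torch.Tensor, weight: torch.Tensor,
                           eps: float, tp_size: int) -> torch.Tensor:
    shard = x.new_empty(x.shape[0] // tp_size, x.shape[1])
    dist.reduce_scatter_tensor(shard, x)     # replaces the all_reduce
    shard = rms_norm(shard, weight, eps)     # normalize 1/tp_size of the rows
    out = torch.empty_like(x)
    dist.all_gather_into_tensor(out, shard)  # restore the full tensor
    return out

Because reduce_scatter followed by all_gather is mathematically an all_reduce, and RMSNorm is row-wise, the rewrite is exact.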
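is_applicable_for_shape guards that rewrite: reduce_scatter must split the first (token) dimension evenly across the TP group, so the pass only runs for static shapes divisible by tp_size. A worked example with illustrative numbers:

tp_size = 4                    # illustrative TP world size
assert 4096 % tp_size == 0     # 1024 tokens per rank: pass applies
assert 4097 % tp_size != 0     # uneven split: keep the plain all_reduce
# shape is None for dynamic-shape graphs, which the pass skips as well.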
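__call__ is the standard VllmInductorPass entry point: dump the graph, apply the registered patterns, log the replacement count. For reference, a pattern pass of this shape can also be hooked into Inductor outside vLLM's pass manager via the post-grad custom pass hook; a sketch under that assumption (run_pass and its print are stand-ins, not vLLM code):

import torch
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

patterns = PatternMatcherPass(pass_name="sequence_parallelism_pass")
# ... register_replacement() calls would go here, as in the loop above ...

def run_pass(graph: fx.Graph) -> None:
    count = patterns.apply(graph)  # returns the number of rewrites performed
    print(f"Replaced {count} patterns")

# Inductor invokes this callable on the post-grad FX graph.
torch._inductor.config.post_grad_custom_post_pass = run_pass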