[compile] Enable sequence parallelism matching w/o custom ops enabled (#27126)

Signed-off-by: angelayi <yiangela7@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
This commit is contained in:
Angela Yi
2025-11-15 03:46:12 -08:00
committed by GitHub
parent 173b356abf
commit f36292dbee
6 changed files with 472 additions and 444 deletions

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
@@ -10,98 +12,28 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .noop_elimination import NoOpEliminationPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
class _RMSNormAndQuantOpHelper:
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
def get_first_out_wrapper(fn):
@functools.wraps(fn)
def wrapper(*args):
return fn(*args)[0]
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: torch._ops.OpOverload | None = None,
**kwargs,
):
self.epsilon = epsilon
self.dtype = dtype
self.device = device
self.quant_op = quant_op
def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
return torch.ops.higher_order.auto_functionalized(
torch.ops._C.rms_norm.default,
result=result_buffer,
input=input_tensor,
weight=weight_tensor,
epsilon=self.epsilon,
)
def _functional_fused_add_rmsnorm(
self, input_tensor, residual_tensor, weight_tensor
):
return torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=input_tensor,
residual=residual_tensor,
weight=weight_tensor,
epsilon=self.epsilon,
)
def _functional_rmsnorm_then_quant(
self,
rmsnorm_result_buffer,
quant_result_buffer,
input_tensor,
weight_tensor,
scale_tensor,
):
if self.quant_op is None:
raise RuntimeError(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
rmsnorm_out_tuple = self._functional_rmsnorm(
rmsnorm_result_buffer, input_tensor, weight_tensor
)
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
self.quant_op,
result=quant_result_buffer,
input=rmsnorm_out_tuple[1],
scale=scale_tensor,
)
return quant_out_tuple
def _functional_fused_add_rmsnorm_then_quant(
self,
quant_result_buffer,
input_tensor,
residual_tensor,
weight_tensor,
scale_tensor,
):
if self.quant_op is None:
raise RuntimeError(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
input_tensor, residual_tensor, weight_tensor
)
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
self.quant_op,
result=quant_result_buffer,
input=fused_add_rmsnorm_out_tuple[1],
scale=scale_tensor,
)
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
return wrapper
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
class _SequenceParallelPatternHelper:
"""Helper for sequence parallelism patterns."""
def __init__(
@@ -109,10 +41,10 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: torch._ops.OpOverload | None = None,
**kwargs,
):
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
self.epsilon = epsilon
self.dtype = dtype
self.device = device
self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -131,36 +63,34 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self):
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
return [input, permute, arg3_1]
return [input, arg3_1]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor,
):
all_reduce = self._all_reduce(input)
rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
return rmsnorm[1], all_reduce
return rmsnorm, all_reduce
def replacement(
input: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor,
):
reduce_scatter = self._reduce_scatter(input)
rmsnorm_result = torch.empty_like(reduce_scatter)
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1)
all_gather = self._all_gather(rmsnorm[1])
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
all_gather = self._all_gather(rmsnorm)
return all_gather, reduce_scatter
pm.register_replacement(
@@ -169,6 +99,10 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
@@ -188,67 +122,34 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
all_reduce, residual, rms_norm_weights
)
return rmsnorm[1], rmsnorm[2]
rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
return rmsnorm[0], rmsnorm[1]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# pattern matcher replaces from top-to-bottom,
# so residual is still the full size here.
# once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
reduce_scatter, residual, rms_norm_weights
)
all_gather = self._all_gather(rmsnorm[1])
return all_gather, rmsnorm[2]
residual = residual[0 : reduce_scatter.size(0), ...]
rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
all_gather = self._all_gather(rmsnorm[0])
# shape of residual changes but that's fine,
# next node is already slicing it, now becomes a noop
return all_gather, rmsnorm[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [
residual,
mm_1,
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
all_reduce, residual, rms_norm_weights
)
return rmsnorm[1]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
reduce_scatter, residual, rms_norm_weights
)
normalized = self._all_gather(rmsnorm[1])
return normalized
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
get_first_out_wrapper(pattern),
get_first_out_wrapper(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
@@ -257,52 +158,41 @@ FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
self,
epsilon: float,
dtype: torch.dtype,
device: str,
):
super().__init__(epsilon, dtype, device, quant_op=op)
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self):
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, rmsnorm_result, quant_result, weight, scale]
return [input, weight, scale]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
all_reduce = self._all_reduce(input)
static_fp8 = self._functional_rmsnorm_then_quant(
rmsnorm_result, quant_result, all_reduce, weight, scale
)
return static_fp8[1], all_reduce
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
def replacement(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
reduce_scatter = self._reduce_scatter(input)
rmsnorm_result = torch.empty_like(
reduce_scatter, dtype=rmsnorm_result.dtype
)
quant_result = torch.empty_like(
rmsnorm_result, # Output of RMSNorm
dtype=quant_result.dtype,
)
static_fp8 = self._functional_rmsnorm_then_quant(
rmsnorm_result, quant_result, reduce_scatter, weight, scale
)
all_gather = self._all_gather(static_fp8[1])
rms = self.rmsnorm_matcher(reduce_scatter, weight)
quant, _ = self.quant_matcher(rms, scale)
all_gather = self._all_gather(quant)
return all_gather, reduce_scatter
@@ -312,118 +202,64 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
):
super().__init__(epsilon, dtype, device, quant_op=op)
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [
result,
residual,
mm_1,
rms_norm_weights,
scale,
]
return [residual, mm_1, rms_norm_weights, scale]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
static_fp8, rmsnorm_residual_out = (
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
result, all_reduce, residual, rms_norm_weights, scale
)
rms, residual_out = self.rmsnorm_matcher(
all_reduce, rms_norm_weights, residual
)
return static_fp8[1], rmsnorm_residual_out
quant, _ = self.quant_matcher(rms, scale)
return quant, residual_out
def replacement(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# pattern matcher replaces from top-to-bottom,
# so residual is still the full size here.
# add a temporary slice which will become a noop
# once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1)
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
static_fp8, rmsnorm_residual_out = (
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
)
residual = residual[0 : reduce_scatter.size(0), ...]
rms, residual_out = self.rmsnorm_matcher(
reduce_scatter, rms_norm_weights, residual
)
all_gather = self._all_gather(static_fp8[1])
return all_gather, rmsnorm_residual_out
quant, _ = self.quant_matcher(rms, scale)
all_gather = self._all_gather(quant)
# shape of residual changes but that's fine,
# next node is already slicing it, now becomes a noop
return all_gather, residual_out
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
):
super().__init__(epsilon, dtype, device, quant_op=op)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [
result,
residual,
mm_1,
rms_norm_weights,
scale,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
result, all_reduce, residual, rms_norm_weights, scale
)
return static_fp8[1]
def replacement(
result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
)
normalized = self._all_gather(static_fp8[1])
return normalized
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
get_first_out_wrapper(pattern),
get_first_out_wrapper(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
@@ -445,27 +281,45 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
significantly reduce communication overhead and improve overall model
performance.
This pass splits up the residual tensor across TP ranks and hence divides its size.
Because the pattern matcher starts at the end of the graph, the replacement
contains a slice that temporarily conforms the input residual to the correct size.
After all patterns have been matched, we use a NoOpEliminationPass to clean up
what have now become no-op slices.
Note that an older version of the pass did not need this as it operated only on
custom rms_norm and fused_rms_norm_add custom ops which did not complain about
mismatched shapes during replacement. So this approach has the same assumption that
correctness is only maintained if all rms_norm operations are split across ranks.
Correctness-wise, this is approach strictly better than before - before,
the graph was incorrect semantically and shape-wise during the pass.
With this approach there's only semantic incorrectness during the pass.
Both approaches restore a correct graph once all patterns are matched.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
# Used to cleanup redundant views created temporarily
# to circumvent residual shape change issues
self.noop_cleanup = NoOpEliminationPass(config)
self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass"
)
for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
epsilon, self.model_dtype, self.device
).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
).register(self.patterns)
LastAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
epsilon, self.model_dtype, self.device
).register(self.patterns)
# Normal RMSNorm patterns
@@ -477,9 +331,6 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
epsilon, self.model_dtype, self.device
).register(self.patterns)
LastAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
self.dump_patterns(config, self.patterns)
def is_applicable(self, shape: int | None) -> bool:
@@ -508,3 +359,5 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
def __call__(self, graph: fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
# Clean up reshape nodes
self.noop_cleanup(graph)