[compile] Enable sequence parallelism matching w/o custom ops enabled (#27126)
Signed-off-by: angelayi <yiangela7@gmail.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: ProExpertProg <lgovedic@redhat.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
@@ -10,98 +12,28 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _RMSNormAndQuantOpHelper:
|
||||
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
|
||||
def get_first_out_wrapper(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args):
|
||||
return fn(*args)[0]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: torch._ops.OpOverload | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.quant_op = quant_op
|
||||
|
||||
def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.rms_norm.default,
|
||||
result=result_buffer,
|
||||
input=input_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
def _functional_fused_add_rmsnorm(
|
||||
self, input_tensor, residual_tensor, weight_tensor
|
||||
):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
input=input_tensor,
|
||||
residual=residual_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
def _functional_rmsnorm_then_quant(
|
||||
self,
|
||||
rmsnorm_result_buffer,
|
||||
quant_result_buffer,
|
||||
input_tensor,
|
||||
weight_tensor,
|
||||
scale_tensor,
|
||||
):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
rmsnorm_out_tuple = self._functional_rmsnorm(
|
||||
rmsnorm_result_buffer, input_tensor, weight_tensor
|
||||
)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor,
|
||||
)
|
||||
return quant_out_tuple
|
||||
|
||||
def _functional_fused_add_rmsnorm_then_quant(
|
||||
self,
|
||||
quant_result_buffer,
|
||||
input_tensor,
|
||||
residual_tensor,
|
||||
weight_tensor,
|
||||
scale_tensor,
|
||||
):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
|
||||
input_tensor, residual_tensor, weight_tensor
|
||||
)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor,
|
||||
)
|
||||
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
|
||||
return wrapper
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
class _SequenceParallelPatternHelper:
|
||||
"""Helper for sequence parallelism patterns."""
|
||||
|
||||
def __init__(
|
||||
@@ -109,10 +41,10 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: torch._ops.OpOverload | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
@@ -131,36 +63,34 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [input, permute, arg3_1]
|
||||
return [input, arg3_1]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
permute: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
all_reduce = self._all_reduce(input)
|
||||
rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)
|
||||
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
|
||||
|
||||
return rmsnorm[1], all_reduce
|
||||
return rmsnorm, all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
permute: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm_result = torch.empty_like(reduce_scatter)
|
||||
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1)
|
||||
|
||||
all_gather = self._all_gather(rmsnorm[1])
|
||||
|
||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
|
||||
all_gather = self._all_gather(rmsnorm)
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
pm.register_replacement(
|
||||
@@ -169,6 +99,10 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
@@ -188,67 +122,34 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
all_reduce, residual, rms_norm_weights
|
||||
)
|
||||
return rmsnorm[1], rmsnorm[2]
|
||||
rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
|
||||
return rmsnorm[0], rmsnorm[1]
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# pattern matcher replaces from top-to-bottom,
|
||||
# so residual is still the full size here.
|
||||
# once the seqpar pattern with the previous rmsnorm is replaced
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
reduce_scatter, residual, rms_norm_weights
|
||||
)
|
||||
all_gather = self._all_gather(rmsnorm[1])
|
||||
return all_gather, rmsnorm[2]
|
||||
residual = residual[0 : reduce_scatter.size(0), ...]
|
||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
|
||||
all_gather = self._all_gather(rmsnorm[0])
|
||||
# shape of residual changes but that's fine,
|
||||
# next node is already slicing it, now becomes a noop
|
||||
return all_gather, rmsnorm[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
all_reduce, residual, rms_norm_weights
|
||||
)
|
||||
return rmsnorm[1]
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
reduce_scatter, residual, rms_norm_weights
|
||||
)
|
||||
normalized = self._all_gather(rmsnorm[1])
|
||||
return normalized
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
get_first_out_wrapper(pattern),
|
||||
get_first_out_wrapper(replacement),
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
@@ -257,52 +158,41 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(
|
||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, rmsnorm_result, quant_result, weight, scale]
|
||||
return [input, weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
all_reduce = self._all_reduce(input)
|
||||
static_fp8 = self._functional_rmsnorm_then_quant(
|
||||
rmsnorm_result, quant_result, all_reduce, weight, scale
|
||||
)
|
||||
return static_fp8[1], all_reduce
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm_result = torch.empty_like(
|
||||
reduce_scatter, dtype=rmsnorm_result.dtype
|
||||
)
|
||||
quant_result = torch.empty_like(
|
||||
rmsnorm_result, # Output of RMSNorm
|
||||
dtype=quant_result.dtype,
|
||||
)
|
||||
static_fp8 = self._functional_rmsnorm_then_quant(
|
||||
rmsnorm_result, quant_result, reduce_scatter, weight, scale
|
||||
)
|
||||
all_gather = self._all_gather(static_fp8[1])
|
||||
rms = self.rmsnorm_matcher(reduce_scatter, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
all_gather = self._all_gather(quant)
|
||||
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
@@ -312,118 +202,64 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(
|
||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [
|
||||
result,
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
scale,
|
||||
]
|
||||
return [residual, mm_1, rms_norm_weights, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
static_fp8, rmsnorm_residual_out = (
|
||||
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
||||
result, all_reduce, residual, rms_norm_weights, scale
|
||||
)
|
||||
rms, residual_out = self.rmsnorm_matcher(
|
||||
all_reduce, rms_norm_weights, residual
|
||||
)
|
||||
return static_fp8[1], rmsnorm_residual_out
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, residual_out
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# pattern matcher replaces from top-to-bottom,
|
||||
# so residual is still the full size here.
|
||||
# add a temporary slice which will become a noop
|
||||
# once the seqpar pattern with the previous rmsnorm is replaced
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
|
||||
static_fp8, rmsnorm_residual_out = (
|
||||
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
|
||||
)
|
||||
residual = residual[0 : reduce_scatter.size(0), ...]
|
||||
rms, residual_out = self.rmsnorm_matcher(
|
||||
reduce_scatter, rms_norm_weights, residual
|
||||
)
|
||||
all_gather = self._all_gather(static_fp8[1])
|
||||
return all_gather, rmsnorm_residual_out
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
all_gather = self._all_gather(quant)
|
||||
# shape of residual changes but that's fine,
|
||||
# next node is already slicing it, now becomes a noop
|
||||
return all_gather, residual_out
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(
|
||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [
|
||||
result,
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
scale,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
|
||||
result, all_reduce, residual, rms_norm_weights, scale
|
||||
)
|
||||
return static_fp8[1]
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
|
||||
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
|
||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
|
||||
)
|
||||
normalized = self._all_gather(static_fp8[1])
|
||||
return normalized
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
get_first_out_wrapper(pattern),
|
||||
get_first_out_wrapper(replacement),
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
@@ -445,27 +281,45 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
|
||||
significantly reduce communication overhead and improve overall model
|
||||
performance.
|
||||
|
||||
|
||||
This pass splits up the residual tensor across TP ranks and hence divides its size.
|
||||
Because the pattern matcher starts at the end of the graph, the replacement
|
||||
contains a slice that temporarily conforms the input residual to the correct size.
|
||||
After all patterns have been matched, we use a NoOpEliminationPass to clean up
|
||||
what have now become no-op slices.
|
||||
|
||||
Note that an older version of the pass did not need this as it operated only on
|
||||
custom rms_norm and fused_rms_norm_add custom ops which did not complain about
|
||||
mismatched shapes during replacement. So this approach has the same assumption that
|
||||
correctness is only maintained if all rms_norm operations are split across ranks.
|
||||
|
||||
Correctness-wise, this is approach strictly better than before - before,
|
||||
the graph was incorrect semantically and shape-wise during the pass.
|
||||
With this approach there's only semantic incorrectness during the pass.
|
||||
Both approaches restore a correct graph once all patterns are matched.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
# Used to cleanup redundant views created temporarily
|
||||
# to circumvent residual shape change issues
|
||||
self.noop_cleanup = NoOpEliminationPass(config)
|
||||
self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="sequence_parallelism_pass"
|
||||
)
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# RMSNorm + Static FP8 quantization patterns
|
||||
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
|
||||
FirstAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device, fp8_quant_op
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
MiddleAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device, fp8_quant_op
|
||||
).register(self.patterns)
|
||||
LastAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device, fp8_quant_op
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
|
||||
# Normal RMSNorm patterns
|
||||
@@ -477,9 +331,6 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
|
||||
LastAllReduceRMSNormPattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable(self, shape: int | None) -> bool:
|
||||
@@ -508,3 +359,5 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
def __call__(self, graph: fx.Graph):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
# Clean up reshape nodes
|
||||
self.noop_cleanup(graph)
|
||||
|
||||
Reference in New Issue
Block a user