Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -9,8 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -23,12 +22,14 @@ logger = init_logger(__name__)
class _RMSNormAndQuantOpHelper:
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
def __init__(self,
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: Optional[torch._ops.OpOverload] = None,
**kwargs):
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: Optional[torch._ops.OpOverload] = None,
**kwargs,
):
self.epsilon = epsilon
self.dtype = dtype
self.device = device
@@ -40,60 +41,78 @@ class _RMSNormAndQuantOpHelper:
result=result_buffer,
input=input_tensor,
weight=weight_tensor,
epsilon=self.epsilon)
epsilon=self.epsilon,
)
def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor,
weight_tensor):
def _functional_fused_add_rmsnorm(
self, input_tensor, residual_tensor, weight_tensor
):
return torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=input_tensor,
residual=residual_tensor,
weight=weight_tensor,
epsilon=self.epsilon)
epsilon=self.epsilon,
)
def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer,
quant_result_buffer, input_tensor,
weight_tensor, scale_tensor):
def _functional_rmsnorm_then_quant(
self,
rmsnorm_result_buffer,
quant_result_buffer,
input_tensor,
weight_tensor,
scale_tensor,
):
if self.quant_op is None:
raise RuntimeError(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer,
input_tensor,
weight_tensor)
rmsnorm_out_tuple = self._functional_rmsnorm(
rmsnorm_result_buffer, input_tensor, weight_tensor
)
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
self.quant_op,
result=quant_result_buffer,
input=rmsnorm_out_tuple[1],
scale=scale_tensor)
scale=scale_tensor,
)
return quant_out_tuple
def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer,
input_tensor, residual_tensor,
weight_tensor, scale_tensor):
def _functional_fused_add_rmsnorm_then_quant(
self,
quant_result_buffer,
input_tensor,
residual_tensor,
weight_tensor,
scale_tensor,
):
if self.quant_op is None:
raise RuntimeError(
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
)
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
input_tensor, residual_tensor, weight_tensor)
input_tensor, residual_tensor, weight_tensor
)
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
self.quant_op,
result=quant_result_buffer,
input=fused_add_rmsnorm_out_tuple[1],
scale=scale_tensor)
scale=scale_tensor,
)
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
"""Helper for sequence parallelism patterns."""
def __init__(self,
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: Optional[torch._ops.OpOverload] = None,
**kwargs):
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
quant_op: Optional[torch._ops.OpOverload] = None,
**kwargs,
):
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -103,21 +122,16 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.reduce_scatter.default(
x,
dim=0,
world_size=self.tp_size,
group_name=self.tp_group.unique_name)
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
)
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.all_gather.default(
x,
dim=0,
world_size=self.tp_size,
group_name=self.tp_group.unique_name)
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
)
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def get_inputs(self):
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
@@ -126,7 +140,6 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
return [input, permute, arg3_1]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
permute: torch.Tensor,
@@ -145,26 +158,23 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
reduce_scatter = self._reduce_scatter(input)
rmsnorm_result = torch.empty_like(reduce_scatter)
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter,
arg3_1)
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1)
all_gather = self._all_gather(rmsnorm[1])
return all_gather, reduce_scatter
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [
residual,
@@ -173,7 +183,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
@@ -181,7 +190,8 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
all_reduce, residual, rms_norm_weights)
all_reduce, residual, rms_norm_weights
)
return rmsnorm[1], rmsnorm[2]
def replacement(
@@ -191,23 +201,22 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
reduce_scatter, residual, rms_norm_weights)
reduce_scatter, residual, rms_norm_weights
)
all_gather = self._all_gather(rmsnorm[1])
return all_gather, rmsnorm[2]
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [
residual,
@@ -216,7 +225,6 @@ class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
@@ -224,7 +232,8 @@ class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
all_reduce, residual, rms_norm_weights)
all_reduce, residual, rms_norm_weights
)
return rmsnorm[1]
def replacement(
@@ -234,37 +243,34 @@ class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
rmsnorm = self._functional_fused_add_rmsnorm(
reduce_scatter, residual, rms_norm_weights)
reduce_scatter, residual, rms_norm_weights
)
normalized = self._all_gather(rmsnorm[1])
return normalized
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
op: torch._ops.OpOverload):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
):
super().__init__(epsilon, dtype, device, quant_op=op)
def get_inputs(self):
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
rmsnorm_result = torch.empty([1, 8, 4],
device=self.device,
dtype=self.dtype)
quant_result = torch.empty([1, 8, 4],
device=self.device,
dtype=FP8_DTYPE)
rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, rmsnorm_result, quant_result, weight, scale]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
@@ -274,7 +280,8 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
):
all_reduce = self._all_reduce(input)
static_fp8 = self._functional_rmsnorm_then_quant(
rmsnorm_result, quant_result, all_reduce, weight, scale)
rmsnorm_result, quant_result, all_reduce, weight, scale
)
return static_fp8[1], all_reduce
def replacement(
@@ -286,34 +293,36 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
):
reduce_scatter = self._reduce_scatter(input)
rmsnorm_result = torch.empty_like(reduce_scatter,
dtype=rmsnorm_result.dtype)
rmsnorm_result = torch.empty_like(
reduce_scatter, dtype=rmsnorm_result.dtype
)
quant_result = torch.empty_like(
rmsnorm_result, # Output of RMSNorm
dtype=quant_result.dtype)
dtype=quant_result.dtype,
)
static_fp8 = self._functional_rmsnorm_then_quant(
rmsnorm_result, quant_result, reduce_scatter, weight, scale)
rmsnorm_result, quant_result, reduce_scatter, weight, scale
)
all_gather = self._all_gather(static_fp8[1])
return all_gather, reduce_scatter
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
op: torch._ops.OpOverload):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
):
super().__init__(epsilon, dtype, device, quant_op=op)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
@@ -326,7 +335,6 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
residual: torch.Tensor,
@@ -335,8 +343,11 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
result, all_reduce, residual, rms_norm_weights, scale)
static_fp8, rmsnorm_residual_out = (
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
result, all_reduce, residual, rms_norm_weights, scale
)
)
return static_fp8[1], rmsnorm_residual_out
def replacement(
@@ -347,31 +358,31 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
quant_result_buf = torch.empty_like(reduce_scatter,
dtype=result.dtype)
static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
quant_result_buf, reduce_scatter, residual, rms_norm_weights,
scale)
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
static_fp8, rmsnorm_residual_out = (
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
)
)
all_gather = self._all_gather(static_fp8[1])
return all_gather, rmsnorm_residual_out
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
op: torch._ops.OpOverload):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
):
super().__init__(epsilon, dtype, device, quant_op=op)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
@@ -384,7 +395,6 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
residual: torch.Tensor,
@@ -394,7 +404,8 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
result, all_reduce, residual, rms_norm_weights, scale)
result, all_reduce, residual, rms_norm_weights, scale
)
return static_fp8[1]
def replacement(
@@ -405,16 +416,16 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(mm_1)
quant_result_buf = torch.empty_like(reduce_scatter,
dtype=result.dtype)
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
quant_result_buf, reduce_scatter, residual, rms_norm_weights,
scale)
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
)
normalized = self._all_gather(static_fp8[1])
return normalized
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class SequenceParallelismPass(VllmPatternMatcherPass):
@@ -442,30 +453,34 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass")
pass_name="sequence_parallelism_pass"
)
for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
epsilon, self.model_dtype, self.device, fp8_quant_op
).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
epsilon, self.model_dtype, self.device, fp8_quant_op
).register(self.patterns)
LastAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
epsilon, self.model_dtype, self.device, fp8_quant_op
).register(self.patterns)
# Normal RMSNorm patterns
FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
FirstAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
MiddleAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
LastAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)
self.dump_patterns(config, self.patterns)
def is_applicable_for_shape(self, shape: Optional[int]) -> bool: