Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -12,8 +12,15 @@ from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    QuantKey,
    ScaleDesc,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
    kStaticTensorScale,
)
from vllm.platforms import current_platform

from .inductor_pass import enable_fake_mode
@@ -40,12 +47,9 @@ RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Map each supported quantization scheme to the custom op implementing it.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}

# The NVFP4 quant op is only registered on CUDA builds that compiled it,
# so guard on both the platform and the op's presence.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
@@ -57,80 +61,93 @@ class FusedRMSQuantKey(NamedTuple):
quant: type of quantization
fused_add: does the op also perform the residual add
"""
quant: QuantKey
fused_add: bool
def __str__(self):
return (f"FusedQuantKey({self.quant}, with"
f"{'' if self.fused_add else 'out'} residual)")
return (
f"FusedQuantKey({self.quant}, with"
f"{'' if self.fused_add else 'out'} residual)"
)
# Map (quant scheme, fused residual add?) -> fused RMSNorm+quant custom op.
# Note: the dynamic per-token op handles both the fused-add and plain cases.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(
        kFp8StaticTensorSym, False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8StaticTensorSym, True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(
        kFp8DynamicTokenSym, True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
}
class RMSNormQuantPattern:
    """Base class for RMSNorm+quant fusion patterns.

    Resolves the standalone quant op and the fused RMSNorm+quant op for the
    given key; subclasses register the actual pattern-matcher replacements.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype

        assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
        self.QUANT_OP = QUANT_OPS[key.quant]

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
    """Pattern for rms_norm followed by static per-tensor quant (no residual)."""
    fused_key = FusedRMSQuantKey(
        fused_add=False,
        quant=QuantKey(
            dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
        ),
    )
    super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass):
# Cannot use methods, as the self argument affects tracing
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon)
at2 = auto_functionalized(self.QUANT_OP,
result=result,
input=at1[1],
scale=scale)
def pattern(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon,
)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale
)
# result
return at2[1]
def replacement(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon)
def replacement(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
)
# result
return at[1]
@@ -140,53 +157,60 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
empty_fp32(1, 1), # scale
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
pm_pass)
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
    """Pattern for fused_add_rms_norm followed by static per-tensor quant."""
    key = FusedRMSQuantKey(
        fused_add=True,
        quant=QuantKey(
            dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
        ),
    )
    super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon)
at1 = auto_functionalized(self.QUANT_OP,
result=result,
input=at[1],
scale=scale)
def pattern(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
at1 = auto_functionalized(
self.QUANT_OP, result=result, input=at[1], scale=scale
)
# result, residual
return at1[1], at[2]
def replacement(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
residual=residual,
weight=weight,
scale=scale,
epsilon=self.epsilon)
def replacement(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
residual=residual,
weight=weight,
scale=scale,
epsilon=self.epsilon,
)
# result, residual
return at[1], at[2]
@@ -196,7 +220,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
empty_fp32(1, 1), # scale
]
pm.register_replacement(
@@ -209,49 +233,59 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape = GroupShape.PER_TOKEN,
    symmetric=True,
):
    """Pattern for rms_norm followed by dynamic (default per-token) quant."""
    # Dynamic scale: computed at runtime (fp32), grouped per `group_shape`.
    scale = ScaleDesc(torch.float32, False, group_shape)
    key = FusedRMSQuantKey(
        fused_add=False,
        quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
    )
    super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon)
at2 = auto_functionalized(self.QUANT_OP,
result=result,
input=at1[1],
scale=scale,
scale_ub=None)
def pattern(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon,
)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
)
# result, scale
return at2[1], at2[2]
def replacement(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=None)
def replacement(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=None,
)
# result, scale
return at[1], at[2]
@@ -261,7 +295,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
empty_fp32(1, 1), # scale
]
pm.register_replacement(
@@ -274,49 +308,59 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape = GroupShape.PER_TOKEN,
    symmetric=True,
):
    """Pattern for fused_add_rms_norm followed by dynamic (per-token) quant."""
    # Dynamic scale: computed at runtime (fp32), grouped per `group_shape`.
    scale = ScaleDesc(torch.float32, False, group_shape)
    key = FusedRMSQuantKey(
        fused_add=True,
        quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
    )
    super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon)
at1 = auto_functionalized(self.QUANT_OP,
result=result,
input=at[1],
scale=scale,
scale_ub=None)
def pattern(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
at1 = auto_functionalized(
self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
)
# result, residual, scale
return at1[1], at[2], at1[2]
def replacement(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=residual)
def replacement(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=residual,
)
# result, residual, scale
return at[1], at[3], at[2]
@@ -326,7 +370,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
empty_fp32(1, 1), # scale
]
pm.register_replacement(
@@ -349,24 +393,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rmsnorm_quant_fusion_pass")
pass_name="rmsnorm_quant_fusion_pass"
)
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns)
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns)
self.patterns
)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns)
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns)
self.patterns
)
self.dump_patterns(config, self.patterns)
@@ -376,8 +421,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any:
    """Identifier for this pass, derived from the source of all pattern
    classes so it changes whenever any pattern implementation changes."""
    return self.hash_source(
        self,
        RMSNormQuantPattern,
        RMSNormStaticQuantPattern,
        RMSNormDynamicQuantPattern,
        FusedAddRMSNormStaticQuantPattern,
        FusedAddRMSNormDynamicQuantPattern,
    )