[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Luka Govedič
Committed: 2025-10-17 10:10:23 -04:00 (via GitHub)
Commit: bd7157a071 (parent: be429d0cfd)
28 changed files with 1519 additions and 721 deletions


@@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import PatternMatcherPass
 from torch._ops import OpOverload

-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.platforms import current_platform

 from .inductor_pass import enable_fake_mode
+from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

 logger = init_logger(__name__)
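
The matcher helpers imported above live in .matcher_utils, which this excerpt does not show. A minimal sketch of the interface the patterns below rely on, with the body inferred from call sites (the real matchers can emit either the custom op or the native torch implementation, and the shapes here are illustrative):

import torch

class MatcherRMSNorm:
    # Sketch only: interface inferred from the call sites in this file.
    def __init__(self, epsilon: float):
        self.epsilon = epsilon

    def __call__(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Emit the rms-norm subgraph that register_replacement traces
        # and then searches for in the compiled graph.
        x = input.to(torch.float32)
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.epsilon)
        return (x * weight).to(input.dtype)

    def inputs(self) -> list[torch.Tensor]:
        # Example tensors for tracing; these replace the hand-built
        # empty_bf16(...) lists deleted in this diff.
        return [
            torch.empty(5, 16, device="cuda", dtype=torch.bfloat16),  # input
            torch.empty(16, device="cuda", dtype=torch.bfloat16),  # weight
        ]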
@@ -92,13 +93,19 @@ class RMSNormQuantPattern:
     def __init__(self, epsilon: float, key: FusedRMSQuantKey):
         self.epsilon = epsilon
         self.quant_dtype = key.quant.dtype

-        assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
-        self.QUANT_OP = QUANT_OPS[key.quant]
+        config = get_current_vllm_config()
+        self.model_dtype = config.model_config.dtype if config.model_config else None

         assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
         self.FUSED_OP = FUSED_OPS[key]

+        self.rmsnorm_matcher = (
+            MatcherRMSNorm(epsilon)
+            if not key.fused_add
+            else MatcherFusedAddRMSNorm(epsilon)
+        )
+        self.quant_matcher = MatcherQuantFP8(key.quant)


 class RMSNormStaticQuantPattern(RMSNormQuantPattern):
     def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
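
The new model_dtype field exists because the patterns below can now match the native (non-custom-op) rms-norm implementation, whose traced graph starts and ends with dtype casts that Inductor may fold away; the matched input node can therefore be fp32 rather than the model dtype. Each replacement below defends against this with the same two-line idiom (a sketch, using the fields set up above):

# Re-normalize the dtype before allocating the quantized output buffer.
input = input.to(dtype=self.model_dtype)  # no-op when dtypes already agree
result = torch.empty_like(input, dtype=self.quant_dtype)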
@@ -112,34 +119,18 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
     def register(self, pm_pass: PatternMatcherPass):
         # Cannot use methods, as the self argument affects tracing
-        def pattern(
-            result: torch.Tensor,
-            result_rms: torch.Tensor,
-            input: torch.Tensor,
-            weight: torch.Tensor,
-            scale: torch.Tensor,
-        ):
-            at1 = auto_functionalized(
-                RMS_OP,
-                result=result_rms,
-                input=input,
-                weight=weight,
-                epsilon=self.epsilon,
-            )
-            at2 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at1[1], scale=scale
-            )
-
-            # result
-            return at2[1]
+        def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
+            result_rms = self.rmsnorm_matcher(input, weight)
+            return self.quant_matcher(result_rms, scale)[0]

-        def replacement(
-            result: torch.Tensor,
-            result_rms: torch.Tensor,
-            input: torch.Tensor,
-            weight: torch.Tensor,
-            scale: torch.Tensor,
-        ):
+        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
+            # In case we're matching native rms-norm, conversions might be
+            # optimized out. We convert here just to be safe.
+            input = input.to(dtype=self.model_dtype)
+
+            result = torch.empty(
+                input.shape, device=input.device, dtype=self.quant_dtype
+            )
             at = auto_functionalized(
                 self.FUSED_OP,
                 result=result,
@@ -153,12 +144,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
             return at[1]

         inputs = [
-            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
-            empty_bf16(5, 4),  # result_rms
-            empty_bf16(5, 4),  # input
-            empty_bf16(1, 5),  # weight
-            empty_fp32(1, 1),  # scale
+            # input, weight
+            *self.rmsnorm_matcher.inputs(),
+            self.quant_matcher.inputs()[1],  # scale
         ]

+        pattern(*inputs)
         pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
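
For context on the registration call: register_replacement traces both functions into FX graphs from the example inputs (pm.fwd_only produces a forward-only trace), then structurally matches the pattern graph inside the compiled model and splices in the replacement. A self-contained toy example of the same torch._inductor.pattern_matcher mechanism, with illustrative names:

import torch
from torch._inductor import pattern_matcher as pm

toy_patterns = pm.PatternMatcherPass()

def toy_pattern(x: torch.Tensor, y: torch.Tensor):
    return torch.relu(x + y)

def toy_replacement(x: torch.Tensor, y: torch.Tensor):
    # Semantically equal rewrite; a real pass would call a fused op here.
    return torch.clamp(x + y, min=0.0)

pm.register_replacement(
    toy_pattern,
    toy_replacement,
    [torch.empty(8, 8), torch.empty(8, 8)],  # example inputs for tracing
    pm.fwd_only,
    toy_patterns,
)
# Later, toy_patterns.apply(graph) rewrites every match in the FX graph
# and returns the number of replacements performed.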
@@ -175,33 +165,27 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
     def register(self, pm_pass: PatternMatcherPass):
         def pattern(
-            result: torch.Tensor,
             input: torch.Tensor,
-            residual: torch.Tensor,
             weight: torch.Tensor,
+            residual: torch.Tensor,
             scale: torch.Tensor,
         ):
-            at = auto_functionalized(
-                RMS_ADD_OP,
-                input=input,
-                residual=residual,
-                weight=weight,
-                epsilon=self.epsilon,
-            )
-            at1 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at[1], scale=scale
-            )
+            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
+            result, _ = self.quant_matcher(result_rms, scale)

             # result, residual
-            return at1[1], at[2]
+            return result, residual

         def replacement(
-            result: torch.Tensor,
             input: torch.Tensor,
-            residual: torch.Tensor,
             weight: torch.Tensor,
+            residual: torch.Tensor,
             scale: torch.Tensor,
         ):
+            # In case we're matching native rms-norm, conversions might be
+            # optimized out. We convert here just to be safe.
+            input = input.to(dtype=self.model_dtype)
+
+            result = torch.empty_like(input, dtype=self.quant_dtype)
             at = auto_functionalized(
                 self.FUSED_OP,
                 result=result,
@@ -216,11 +200,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
             return at[1], at[2]

         inputs = [
-            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
-            empty_bf16(5, 4),  # input
-            empty_bf16(5, 4),  # residual
-            empty_bf16(1, 5),  # weight
-            empty_fp32(1, 1),  # scale
+            # input, weight, residual
+            *self.rmsnorm_matcher.inputs(),
+            self.quant_matcher.inputs()[1],  # scale
         ]

         pm.register_replacement(
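
The fused-add variants hand the residual to the matcher and get a pair back. A sketch of the assumed contract, mirroring the MatcherRMSNorm sketch above (the real implementation is in .matcher_utils):

import torch

class MatcherFusedAddRMSNorm:
    # Sketch only: call sites above use
    #   result_rms, residual = matcher(input, weight, residual)
    def __init__(self, epsilon: float):
        self.epsilon = epsilon

    def __call__(self, input, weight, residual):
        added = input.to(torch.float32) + residual.to(torch.float32)
        new_residual = added.to(input.dtype)  # the fused kernel updates this in place
        x = added * torch.rsqrt(added.pow(2).mean(dim=-1, keepdim=True) + self.epsilon)
        return (x * weight).to(input.dtype), new_residual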
@@ -248,34 +230,18 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
         super().__init__(epsilon, key)

     def register(self, pm_pass: PatternMatcherPass):
-        def pattern(
-            result: torch.Tensor,
-            result_rms: torch.Tensor,
-            input: torch.Tensor,
-            weight: torch.Tensor,
-            scale: torch.Tensor,
-        ):
-            at1 = auto_functionalized(
-                RMS_OP,
-                result=result_rms,
-                input=input,
-                weight=weight,
-                epsilon=self.epsilon,
-            )
-            at2 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
-            )
+        def pattern(input: torch.Tensor, weight: torch.Tensor):
+            result_rms = self.rmsnorm_matcher(input, weight)

             # result, scale
-            return at2[1], at2[2]
+            return self.quant_matcher(result_rms)

-        def replacement(
-            result: torch.Tensor,
-            result_rms: torch.Tensor,
-            input: torch.Tensor,
-            weight: torch.Tensor,
-            scale: torch.Tensor,
-        ):
+        def replacement(input: torch.Tensor, weight: torch.Tensor):
+            # In case we're matching native rms-norm, conversions might be
+            # optimized out. We convert here just to be safe.
+            input = input.to(dtype=self.model_dtype)
+
+            result = torch.empty_like(input, dtype=self.quant_dtype)
+            scale = self.quant_matcher.make_scale(input)
             at = auto_functionalized(
                 self.FUSED_OP,
                 result=result,
@@ -290,18 +256,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
             # result, scale
             return at[1], at[2]

-        inputs = [
-            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
-            empty_bf16(5, 4),  # result_rms
-            empty_bf16(5, 4),  # input
-            empty_bf16(1, 5),  # weight
-            empty_fp32(1, 1),  # scale
-        ]
-
         pm.register_replacement(
             pattern,
             replacement,
-            inputs,
+            self.rmsnorm_matcher.inputs(),
             pm.fwd_only,
             pm_pass,
         )
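
The dynamic-quant patterns no longer take scale as a traced input; the matcher derives it (make_scale above). For reference, dynamic per-token fp8 quantization computes one scale per row from that row's maximum magnitude. A sketch of the semantics, not vLLM's fused kernel:

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def dynamic_per_token_fp8_quant_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per token (row), sized so the row's max |value| lands at
    # the edge of the fp8 range. Real kernels also guard all-zero rows.
    scale = (x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / FP8_MAX).clamp(min=1e-12)
    q = (x.to(torch.float32) / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return q, scale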
@@ -323,34 +281,21 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
         super().__init__(epsilon, key)

     def register(self, pm_pass: PatternMatcherPass):
-        def pattern(
-            result: torch.Tensor,
-            input: torch.Tensor,
-            residual: torch.Tensor,
-            weight: torch.Tensor,
-            scale: torch.Tensor,
-        ):
-            at = auto_functionalized(
-                RMS_ADD_OP,
-                input=input,
-                residual=residual,
-                weight=weight,
-                epsilon=self.epsilon,
-            )
-            at1 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
-            )
+        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
+            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
+            result, scale = self.quant_matcher(result_rms)

             # result, residual, scale
-            return at1[1], at[2], at1[2]
+            return result, residual, scale

         def replacement(
-            result: torch.Tensor,
-            input: torch.Tensor,
-            residual: torch.Tensor,
-            weight: torch.Tensor,
-            scale: torch.Tensor,
+            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
         ):
+            # In case we're matching native rms-norm, conversions might be
+            # optimized out. We convert here just to be safe.
+            input = input.to(dtype=self.model_dtype)
+
+            result = torch.empty_like(input, dtype=self.quant_dtype)
+            scale = self.quant_matcher.make_scale(input)
             at = auto_functionalized(
                 self.FUSED_OP,
                 result=result,
@@ -365,18 +310,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
             # result, residual, scale
             return at[1], at[3], at[2]

-        inputs = [
-            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
-            empty_bf16(5, 4),  # input
-            empty_bf16(5, 4),  # residual
-            empty_bf16(1, 5),  # weight
-            empty_fp32(1, 1),  # scale
-        ]
-
         pm.register_replacement(
             pattern,
             replacement,
-            inputs,
+            self.rmsnorm_matcher.inputs(),
             pm.fwd_only,
             pm_pass,
         )
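
A note on the tuple indices used throughout these replacements: auto_functionalized(op, ...) returns the op's own result at index 0, followed by the out-of-place values of each mutated argument in schema order. The return at[1], at[3], at[2] above thus implies the fused dynamic op declares its mutable arguments in the order (result, scale, residual). A sketch with assumed keyword names:

at = auto_functionalized(FUSED_OP, result=result, input=input, weight=weight,
                         residual=residual, scale=scale, epsilon=epsilon)
# at[0] -> the op's return value (None for these in-place custom ops)
# at[1] -> result    (first mutated arg)
# at[2] -> scale     (second mutated arg, inferred from the indices above)
# at[3] -> residual  (third mutated arg)
result, residual, scale = at[1], at[3], at[2]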
@@ -396,23 +333,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
             pass_name="rmsnorm_quant_fusion_pass"
         )

+        # Make sure fused add patterns are before simple rms norm,
+        # as the latter is a subset of the former in torch ops
         for epsilon in [1e-5, 1e-6]:
-            # Fuse rms_norm + static fp8 quant
-            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
-
             # Fuse fused_add_rms_norm + static fp8 quant
             FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                 self.patterns
             )

-            # Fuse rms_norm + dynamic per-token fp8 quant
-            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
+            # Fuse rms_norm + static fp8 quant
+            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

             # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
             FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                 self.patterns
             )

+            # Fuse rms_norm + dynamic per-token fp8 quant
+            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

         self.dump_patterns(config, self.patterns)

     @VllmInductorPass.time_and_log
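
The excerpt cuts off at the decorator. For orientation, the decorated method on a VllmPatternMatcherPass subclass conventionally applies the registered patterns to the FX graph; a sketch, since the body is not shown in this diff:

@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
    self.matched_count = self.patterns.apply(graph)  # number of rewrites performed
    logger.debug("Replaced %s rmsnorm+quant patterns", self.matched_count)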