[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -92,13 +93,19 @@ class RMSNormQuantPattern:
|
||||
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
|
||||
assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
|
||||
self.QUANT_OP = QUANT_OPS[key.quant]
|
||||
config = get_current_vllm_config()
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
|
||||
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
|
||||
self.FUSED_OP = FUSED_OPS[key]
|
||||
|
||||
self.rmsnorm_matcher = (
|
||||
MatcherRMSNorm(epsilon)
|
||||
if not key.fused_add
|
||||
else MatcherFusedAddRMSNorm(epsilon)
|
||||
)
|
||||
self.quant_matcher = MatcherQuantFP8(key.quant)
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
|
||||
@@ -112,34 +119,18 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
# Cannot use methods, as the self argument affects tracing
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at1[1], scale=scale
|
||||
)
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
return self.quant_matcher(result_rms, scale)[0]
|
||||
|
||||
# result
|
||||
return at2[1]
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
result = torch.empty(
|
||||
input.shape, device=input.device, dtype=self.quant_dtype
|
||||
)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@@ -153,12 +144,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # result_rms
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
# input, weight
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
pattern(*inputs)
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
|
||||
@@ -175,33 +165,27 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at[1], scale=scale
|
||||
)
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, _ = self.quant_matcher(result_rms, scale)
|
||||
|
||||
# result, residual
|
||||
return at1[1], at[2]
|
||||
return result, residual
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@@ -216,11 +200,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
# input, weight, residual
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
@@ -248,34 +230,18 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
|
||||
)
|
||||
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor):
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
# result, scale
|
||||
return at2[1], at2[2]
|
||||
return self.quant_matcher(result_rms)
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(input)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@@ -290,18 +256,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
# result, scale
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # result_rms
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
@@ -323,34 +281,21 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
|
||||
)
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
# result, residual, scale
|
||||
return at1[1], at[2], at1[2]
|
||||
return result, residual, scale
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(input)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@@ -365,18 +310,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
# result, residual, scale
|
||||
return at[1], at[3], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
@@ -396,23 +333,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
pass_name="rmsnorm_quant_fusion_pass"
|
||||
)
|
||||
|
||||
# Make sure fused add patterns are before simple rms norm,
|
||||
# as the latter is a subset of the former in torch ops
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + static fp8 quant
|
||||
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
|
||||
Reference in New Issue
Block a user