[torch.compile] Compile CustomOp.forward_native for SiluAndMul and QuantFP8 to avoid raw torch ops inside opaque custom ops (#32806)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -222,7 +222,7 @@ def test_fusion_silu_and_mul_quant(
|
||||
x = torch.rand(num_tokens, hidden_size * 2)
|
||||
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
custom_ops = []
|
||||
custom_ops = ["none"]
|
||||
if enable_silu_mul_custom_op:
|
||||
custom_ops.append("+silu_and_mul")
|
||||
if enable_quant_fp8_custom_op:
|
||||
@@ -231,6 +231,7 @@ def test_fusion_silu_and_mul_quant(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
backend="eager", # avoid compilation for SiluAndMul and QuantFP8
|
||||
pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -57,7 +57,7 @@ def test_act_and_mul(
|
||||
torch.set_default_device(device)
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
|
||||
if activation == "silu_and_mul":
|
||||
layer = SiluAndMul()
|
||||
layer = SiluAndMul(compile_native=False)
|
||||
fn = torch.ops._C.silu_and_mul
|
||||
if activation == "mul_and_silu":
|
||||
layer = MulAndSilu()
|
||||
|
||||
@@ -336,6 +336,7 @@ class MatcherQuantFP8(MatcherCustomOp):
|
||||
quant_key.scale.group_shape,
|
||||
column_major_scales=has_col_major_scales,
|
||||
use_ue8m0=is_e8m0,
|
||||
compile_native=False,
|
||||
)
|
||||
|
||||
def forward_rocm_aiter(
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_cached_compilation_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.utils import maybe_disable_graph_partition
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Dictionary of all custom ops (classes, indexed by registered name).
|
||||
# To check if an op with a name is enabled, call .enabled() on the class.
|
||||
# Examples:
|
||||
@@ -118,10 +117,10 @@ class CustomOp(nn.Module):
|
||||
)
|
||||
return super().__new__(op_cls_to_instantiate)
|
||||
|
||||
def __init__(self, enforce_enable: bool = False):
|
||||
def __init__(self, *, enforce_enable: bool = False, compile_native: bool = False):
|
||||
super().__init__()
|
||||
self._enforce_enable = enforce_enable
|
||||
self._forward_method = self.dispatch_forward()
|
||||
self._forward_method = self.dispatch_forward(compile_native=compile_native)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self._forward_method(*args, **kwargs)
|
||||
@@ -162,7 +161,7 @@ class CustomOp(nn.Module):
|
||||
# PyTorch-native implementation.
|
||||
return self.forward_native(*args, **kwargs)
|
||||
|
||||
def dispatch_forward(self):
|
||||
def dispatch_forward(self, compile_native: bool):
|
||||
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
||||
# specific backend. Currently, we do not support dynamic dispatching.
|
||||
compilation_config = get_cached_compilation_config()
|
||||
@@ -180,7 +179,9 @@ class CustomOp(nn.Module):
|
||||
compilation_config.disabled_custom_ops.update([self.__class__.name])
|
||||
|
||||
if not enabled:
|
||||
return self.forward_native
|
||||
# Compile forward_native to avoid eager torch ops if inside
|
||||
# opaque torch custom op (e.g. fused_moe, unified_attention, etc.)
|
||||
return self.maybe_compile(self.forward_native, enable=compile_native)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
return self.forward_hip
|
||||
@@ -195,6 +196,40 @@ class CustomOp(nn.Module):
|
||||
else:
|
||||
return self.forward_cuda
|
||||
|
||||
def maybe_compile(self, fn, *, enable: bool = True):
|
||||
"""
|
||||
Compile fn if compilation is enabled.
|
||||
Useful for CustomOp instances called from within a torch custom op,
|
||||
meaning the forward call is hidden from the model-level torch.compile.
|
||||
|
||||
NOTE: this does not enable fusion across ops, so opaque custom ops
|
||||
should still be unwrapped wherever possible.
|
||||
"""
|
||||
# Do not compile if compilation disabled
|
||||
from vllm.config.compilation import CompilationMode
|
||||
|
||||
if not enable:
|
||||
return fn
|
||||
|
||||
# Do not compile if global compilation disabled
|
||||
compilation_config = get_cached_compilation_config()
|
||||
if compilation_config.mode == CompilationMode.NONE:
|
||||
return fn
|
||||
|
||||
# If eager backend is used, do not compile either
|
||||
if compilation_config.backend == "eager":
|
||||
return fn
|
||||
|
||||
# dynamic=True to avoid recompilations
|
||||
return torch.compile(
|
||||
fn,
|
||||
dynamic=True,
|
||||
backend=current_platform.simple_compile_backend,
|
||||
options=maybe_disable_graph_partition(
|
||||
current_platform.simple_compile_backend
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
# if no name, then it was not registered
|
||||
|
||||
@@ -75,8 +75,8 @@ class SiluAndMul(CustomOp):
|
||||
|
||||
# --8<-- [end:silu_and_mul]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, *, compile_native: bool = True):
|
||||
super().__init__(compile_native=compile_native)
|
||||
if current_platform.is_cuda_alike():
|
||||
self.op = torch.ops._C.silu_and_mul
|
||||
elif current_platform.is_xpu():
|
||||
|
||||
@@ -38,6 +38,7 @@ class QuantFP8(CustomOp):
|
||||
column_major_scales: bool = False,
|
||||
tma_aligned_scales: bool = False,
|
||||
use_ue8m0: bool | None = None, # for Torch compile
|
||||
compile_native: bool = True,
|
||||
):
|
||||
"""
|
||||
:param static: static or dynamic quantization
|
||||
@@ -49,8 +50,9 @@ class QuantFP8(CustomOp):
|
||||
TMA-aligned layout
|
||||
:param column_major_scales: For group quantization, output scales in
|
||||
column major format
|
||||
:param compile_native: Manually compile forward_native if the compilation mode is not NONE
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(compile_native=compile_native)
|
||||
self.static = static
|
||||
self.group_shape = group_shape
|
||||
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
|
||||
|
||||
@@ -129,7 +129,7 @@ class ApplyRotaryEmb(CustomOp):
|
||||
is_neox_style: bool = True,
|
||||
enable_fp32_compute: bool = False,
|
||||
) -> None:
|
||||
super().__init__(enforce_enable)
|
||||
super().__init__(enforce_enable=enforce_enable)
|
||||
self.is_neox_style = is_neox_style
|
||||
self.enable_fp32_compute = enable_fp32_compute
|
||||
|
||||
|
||||
Reference in New Issue
Block a user