[torch.compile] Compile CustomOp.forward_native for SiluAndMul and QuantFP8 to avoid raw torch ops inside opaque custom ops (#32806)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Luka Govedič
2026-01-22 22:52:26 -05:00
committed by GitHub
parent f61c9da711
commit 5e4e0e51f4
7 changed files with 52 additions and 13 deletions

View File

@@ -222,7 +222,7 @@ def test_fusion_silu_and_mul_quant(
x = torch.rand(num_tokens, hidden_size * 2)
# Reshape pass is needed for the fusion pass to work
custom_ops = []
custom_ops = ["none"]
if enable_silu_mul_custom_op:
custom_ops.append("+silu_and_mul")
if enable_quant_fp8_custom_op:
@@ -231,6 +231,7 @@ def test_fusion_silu_and_mul_quant(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
backend="eager", # avoid compilation for SiluAndMul and QuantFP8
pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
),
)

View File

@@ -57,7 +57,7 @@ def test_act_and_mul(
torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
if activation == "silu_and_mul":
layer = SiluAndMul()
layer = SiluAndMul(compile_native=False)
fn = torch.ops._C.silu_and_mul
if activation == "mul_and_silu":
layer = MulAndSilu()

View File

@@ -336,6 +336,7 @@ class MatcherQuantFP8(MatcherCustomOp):
quant_key.scale.group_shape,
column_major_scales=has_col_major_scales,
use_ue8m0=is_e8m0,
compile_native=False,
)
def forward_rocm_aiter(

View File

@@ -1,16 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import get_cached_compilation_config
from vllm.logger import init_logger
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
logger = init_logger(__name__)
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
@@ -118,10 +117,10 @@ class CustomOp(nn.Module):
)
return super().__new__(op_cls_to_instantiate)
def __init__(self, enforce_enable: bool = False):
def __init__(self, *, enforce_enable: bool = False, compile_native: bool = False):
super().__init__()
self._enforce_enable = enforce_enable
self._forward_method = self.dispatch_forward()
self._forward_method = self.dispatch_forward(compile_native=compile_native)
def forward(self, *args, **kwargs):
return self._forward_method(*args, **kwargs)
@@ -162,7 +161,7 @@ class CustomOp(nn.Module):
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def dispatch_forward(self):
def dispatch_forward(self, compile_native: bool):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
compilation_config = get_cached_compilation_config()
@@ -180,7 +179,9 @@ class CustomOp(nn.Module):
compilation_config.disabled_custom_ops.update([self.__class__.name])
if not enabled:
return self.forward_native
# Compile forward_native to avoid eager torch ops if inside
# opaque torch custom op (e.g. fused_moe, unified_attention, etc.)
return self.maybe_compile(self.forward_native, enable=compile_native)
if current_platform.is_rocm():
return self.forward_hip
@@ -195,6 +196,40 @@ class CustomOp(nn.Module):
else:
return self.forward_cuda
def maybe_compile(self, fn, *, enable: bool = True):
    """
    Return ``fn`` wrapped in ``torch.compile`` when compilation is on,
    otherwise return ``fn`` untouched.

    Useful for CustomOp instances called from within a torch custom op,
    meaning the forward call is hidden from the model-level torch.compile.

    NOTE: this does not enable fusion across ops, so opaque custom ops
    should still be unwrapped wherever possible.
    """
    # Imported lazily here (not at module scope) to mirror the original
    # local-import style of this method.
    from vllm.config.compilation import CompilationMode

    # Caller explicitly opted out of compiling this op.
    if not enable:
        return fn

    compilation_config = get_cached_compilation_config()
    # Global compilation disabled, or the eager backend is selected:
    # either way, hand back the plain Python function.
    if (
        compilation_config.mode == CompilationMode.NONE
        or compilation_config.backend == "eager"
    ):
        return fn

    # dynamic=True to avoid recompilations on shape changes.
    simple_backend = current_platform.simple_compile_backend
    return torch.compile(
        fn,
        dynamic=True,
        backend=simple_backend,
        options=maybe_disable_graph_partition(simple_backend),
    )
@classmethod
def enabled(cls) -> bool:
# if no name, then it was not registered

View File

@@ -75,8 +75,8 @@ class SiluAndMul(CustomOp):
# --8<-- [end:silu_and_mul]
def __init__(self):
super().__init__()
def __init__(self, *, compile_native: bool = True):
super().__init__(compile_native=compile_native)
if current_platform.is_cuda_alike():
self.op = torch.ops._C.silu_and_mul
elif current_platform.is_xpu():

View File

@@ -38,6 +38,7 @@ class QuantFP8(CustomOp):
column_major_scales: bool = False,
tma_aligned_scales: bool = False,
use_ue8m0: bool | None = None, # for Torch compile
compile_native: bool = True,
):
"""
:param static: static or dynamic quantization
@@ -49,8 +50,9 @@ class QuantFP8(CustomOp):
TMA-aligned layout
:param column_major_scales: For group quantization, output scales in
column major format
:param compile_native: Manually compile forward_native if compile mode > None
"""
super().__init__()
super().__init__(compile_native=compile_native)
self.static = static
self.group_shape = group_shape
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN

View File

@@ -129,7 +129,7 @@ class ApplyRotaryEmb(CustomOp):
is_neox_style: bool = True,
enable_fp32_compute: bool = False,
) -> None:
super().__init__(enforce_enable)
super().__init__(enforce_enable=enforce_enable)
self.is_neox_style = is_neox_style
self.enable_fp32_compute = enable_fp32_compute