[torch.compile] Compile CustomOp.forward_native for SiluAndMul and QuantFP8 to avoid raw torch ops inside opaque custom ops (#32806)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -222,7 +222,7 @@ def test_fusion_silu_and_mul_quant(
|
|||||||
x = torch.rand(num_tokens, hidden_size * 2)
|
x = torch.rand(num_tokens, hidden_size * 2)
|
||||||
|
|
||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
custom_ops = []
|
custom_ops = ["none"]
|
||||||
if enable_silu_mul_custom_op:
|
if enable_silu_mul_custom_op:
|
||||||
custom_ops.append("+silu_and_mul")
|
custom_ops.append("+silu_and_mul")
|
||||||
if enable_quant_fp8_custom_op:
|
if enable_quant_fp8_custom_op:
|
||||||
@@ -231,6 +231,7 @@ def test_fusion_silu_and_mul_quant(
|
|||||||
compilation_config=CompilationConfig(
|
compilation_config=CompilationConfig(
|
||||||
mode=CompilationMode.VLLM_COMPILE,
|
mode=CompilationMode.VLLM_COMPILE,
|
||||||
custom_ops=custom_ops,
|
custom_ops=custom_ops,
|
||||||
|
backend="eager", # avoid compilation for SiluAndMul and QuantFP8
|
||||||
pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
|
pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ def test_act_and_mul(
|
|||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
|
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
|
||||||
if activation == "silu_and_mul":
|
if activation == "silu_and_mul":
|
||||||
layer = SiluAndMul()
|
layer = SiluAndMul(compile_native=False)
|
||||||
fn = torch.ops._C.silu_and_mul
|
fn = torch.ops._C.silu_and_mul
|
||||||
if activation == "mul_and_silu":
|
if activation == "mul_and_silu":
|
||||||
layer = MulAndSilu()
|
layer = MulAndSilu()
|
||||||
|
|||||||
@@ -336,6 +336,7 @@ class MatcherQuantFP8(MatcherCustomOp):
|
|||||||
quant_key.scale.group_shape,
|
quant_key.scale.group_shape,
|
||||||
column_major_scales=has_col_major_scales,
|
column_major_scales=has_col_major_scales,
|
||||||
use_ue8m0=is_e8m0,
|
use_ue8m0=is_e8m0,
|
||||||
|
compile_native=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_rocm_aiter(
|
def forward_rocm_aiter(
|
||||||
|
|||||||
@@ -1,16 +1,15 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import torch
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.config import get_cached_compilation_config
|
from vllm.config import get_cached_compilation_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.utils import maybe_disable_graph_partition
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Dictionary of all custom ops (classes, indexed by registered name).
|
# Dictionary of all custom ops (classes, indexed by registered name).
|
||||||
# To check if an op with a name is enabled, call .enabled() on the class.
|
# To check if an op with a name is enabled, call .enabled() on the class.
|
||||||
# Examples:
|
# Examples:
|
||||||
@@ -118,10 +117,10 @@ class CustomOp(nn.Module):
|
|||||||
)
|
)
|
||||||
return super().__new__(op_cls_to_instantiate)
|
return super().__new__(op_cls_to_instantiate)
|
||||||
|
|
||||||
def __init__(self, enforce_enable: bool = False):
|
def __init__(self, *, enforce_enable: bool = False, compile_native: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._enforce_enable = enforce_enable
|
self._enforce_enable = enforce_enable
|
||||||
self._forward_method = self.dispatch_forward()
|
self._forward_method = self.dispatch_forward(compile_native=compile_native)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return self._forward_method(*args, **kwargs)
|
return self._forward_method(*args, **kwargs)
|
||||||
@@ -162,7 +161,7 @@ class CustomOp(nn.Module):
|
|||||||
# PyTorch-native implementation.
|
# PyTorch-native implementation.
|
||||||
return self.forward_native(*args, **kwargs)
|
return self.forward_native(*args, **kwargs)
|
||||||
|
|
||||||
def dispatch_forward(self):
|
def dispatch_forward(self, compile_native: bool):
|
||||||
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
||||||
# specific backend. Currently, we do not support dynamic dispatching.
|
# specific backend. Currently, we do not support dynamic dispatching.
|
||||||
compilation_config = get_cached_compilation_config()
|
compilation_config = get_cached_compilation_config()
|
||||||
@@ -180,7 +179,9 @@ class CustomOp(nn.Module):
|
|||||||
compilation_config.disabled_custom_ops.update([self.__class__.name])
|
compilation_config.disabled_custom_ops.update([self.__class__.name])
|
||||||
|
|
||||||
if not enabled:
|
if not enabled:
|
||||||
return self.forward_native
|
# Compile forward_native to avoid eager torch ops if inside
|
||||||
|
# opaque torch custom op (e.g. fused_moe, unified_attention, etc.)
|
||||||
|
return self.maybe_compile(self.forward_native, enable=compile_native)
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
return self.forward_hip
|
return self.forward_hip
|
||||||
@@ -195,6 +196,40 @@ class CustomOp(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return self.forward_cuda
|
return self.forward_cuda
|
||||||
|
|
||||||
|
def maybe_compile(self, fn, *, enable: bool = True):
    """
    Wrap ``fn`` with ``torch.compile`` unless compilation is switched off.

    Intended for CustomOp instances that are called from within a torch
    custom op, where the forward call is hidden from the model-level
    torch.compile.

    ``fn`` is returned unchanged when any of the following holds:
      * ``enable`` is False,
      * the cached compilation config has mode ``CompilationMode.NONE``,
      * the cached compilation config uses the ``"eager"`` backend.

    NOTE: this does not enable fusion across ops, so opaque custom ops
    should still be unwrapped wherever possible.

    :param fn: callable to (possibly) compile
    :param enable: per-call switch; when False no compilation is attempted
    :return: ``fn`` itself, or the ``torch.compile``-wrapped callable
    """
    # Function-scope import, kept local as in the original
    # (presumably to avoid a circular import -- TODO confirm).
    from vllm.config.compilation import CompilationMode

    # Caller opted out: hand the function back untouched.
    if not enable:
        return fn

    # Respect the global settings: no compilation at all, or eager backend.
    config = get_cached_compilation_config()
    if config.mode == CompilationMode.NONE or config.backend == "eager":
        return fn

    # The same platform backend is used for both `backend` and `options`.
    simple_backend = current_platform.simple_compile_backend
    compile_options = maybe_disable_graph_partition(simple_backend)

    # dynamic=True to avoid recompilations
    return torch.compile(
        fn,
        dynamic=True,
        backend=simple_backend,
        options=compile_options,
    )
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enabled(cls) -> bool:
|
def enabled(cls) -> bool:
|
||||||
# if no name, then it was not registered
|
# if no name, then it was not registered
|
||||||
|
|||||||
@@ -75,8 +75,8 @@ class SiluAndMul(CustomOp):
|
|||||||
|
|
||||||
# --8<-- [end:silu_and_mul]
|
# --8<-- [end:silu_and_mul]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, *, compile_native: bool = True):
|
||||||
super().__init__()
|
super().__init__(compile_native=compile_native)
|
||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
self.op = torch.ops._C.silu_and_mul
|
self.op = torch.ops._C.silu_and_mul
|
||||||
elif current_platform.is_xpu():
|
elif current_platform.is_xpu():
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class QuantFP8(CustomOp):
|
|||||||
column_major_scales: bool = False,
|
column_major_scales: bool = False,
|
||||||
tma_aligned_scales: bool = False,
|
tma_aligned_scales: bool = False,
|
||||||
use_ue8m0: bool | None = None, # for Torch compile
|
use_ue8m0: bool | None = None, # for Torch compile
|
||||||
|
compile_native: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param static: static or dynamic quantization
|
:param static: static or dynamic quantization
|
||||||
@@ -49,8 +50,9 @@ class QuantFP8(CustomOp):
|
|||||||
TMA-aligned layout
|
TMA-aligned layout
|
||||||
:param column_major_scales: For group quantization, output scales in
|
:param column_major_scales: For group quantization, output scales in
|
||||||
column major format
|
column major format
|
||||||
|
:param compile_native: Manually compile forward_native if compile mode > None
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__(compile_native=compile_native)
|
||||||
self.static = static
|
self.static = static
|
||||||
self.group_shape = group_shape
|
self.group_shape = group_shape
|
||||||
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
|
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ class ApplyRotaryEmb(CustomOp):
|
|||||||
is_neox_style: bool = True,
|
is_neox_style: bool = True,
|
||||||
enable_fp32_compute: bool = False,
|
enable_fp32_compute: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(enforce_enable)
|
super().__init__(enforce_enable=enforce_enable)
|
||||||
self.is_neox_style = is_neox_style
|
self.is_neox_style = is_neox_style
|
||||||
self.enable_fp32_compute = enable_fp32_compute
|
self.enable_fp32_compute = enable_fp32_compute
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user