[Perf] Add opt-in SM100 Oink RMSNorm custom-op path (#31828)
Signed-off-by: Laura Wang <3700467+Laurawly@users.noreply.github.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
74
tests/model_executor/test_oink_integration.py
Normal file
74
tests/model_executor/test_oink_integration.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import types
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def _load_oink_ops_module():
    """Import and return vLLM's `_oink_ops` helper module.

    vllm is installed as an editable package in CI, so a plain import is
    sufficient — no path manipulation needed.
    """
    from vllm import _oink_ops

    return _oink_ops
|
||||
|
||||
|
||||
def test_oink_availability_checks(monkeypatch: pytest.MonkeyPatch):
    """Exercise the device/op availability probes with a faked oink namespace."""
    oink_ops = _load_oink_ops_module()

    # Install a fake `oink` op namespace that only exposes `rmsnorm`, so the
    # probes have something mutable to look at.
    rmsnorm_only = types.SimpleNamespace(rmsnorm=lambda x, w, eps: x)
    monkeypatch.setattr(torch.ops, "oink", rmsnorm_only, raising=False)

    # Case 1: no CUDA at all -> never available.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    assert oink_ops.is_oink_available_for_device(0) is False

    # Case 2: CUDA present but compute capability below SM100 -> unavailable.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(torch.cuda, "get_device_capability", lambda idx: (9, 0))
    assert oink_ops.is_oink_available_for_device(0) is False

    # Case 3: SM100 device and `rmsnorm` registered -> available.
    monkeypatch.setattr(torch.cuda, "get_device_capability", lambda idx: (10, 0))
    assert oink_ops.is_oink_available_for_device(0) is True

    # The fused-op probe only reports True once the op actually exists.
    assert oink_ops.has_fused_add_rms_norm() is False
    both_ops = types.SimpleNamespace(
        rmsnorm=lambda x, w, eps: x,
        fused_add_rms_norm=lambda x, residual, w, eps: None,
    )
    monkeypatch.setattr(torch.ops, "oink", both_ops, raising=False)
    assert oink_ops.has_fused_add_rms_norm() is True
|
||||
|
||||
|
||||
def test_can_view_as_2d_stride_guard():
    """Verify `_can_view_as_2d` agrees with torch's actual view semantics."""
    from vllm.model_executor.layers.layernorm import _can_view_as_2d

    # Plain contiguous 3D tensor: trivially viewable as (-1, H).
    t = torch.zeros((2, 3, 4))
    assert _can_view_as_2d(t) is True

    # Padding makes stride(0) != stride(1) * size(1), but the middle dim has
    # size 1, so view(-1, H) is still legal — size-1 dims must be ignored.
    padded = torch.zeros((2, 10, 4))
    singleton = padded[:, :1, :]
    singleton.view(-1, singleton.shape[-1])  # sanity check: must not raise
    assert _can_view_as_2d(singleton) is True

    # Slicing the middle dim breaks the stride chain: view must fail, and
    # the guard must predict that.
    sliced = t[:, ::2, :]
    with pytest.raises(RuntimeError):
        sliced.view(-1, sliced.shape[-1])
    assert _can_view_as_2d(sliced) is False
|
||||
96
vllm/_oink_ops.py
Normal file
96
vllm/_oink_ops.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Small helper wrappers for external Oink Blackwell custom ops.
|
||||
|
||||
vLLM does not depend on the external Oink repository/package. When an external
|
||||
plugin registers torch.library.custom_op entrypoints under the `oink::`
|
||||
namespace (e.g. via vLLM's general_plugins mechanism) and
|
||||
`VLLM_USE_OINK_OPS=1` is set, vLLM can route eligible calls to those ops.
|
||||
|
||||
This module provides:
|
||||
- A single place to probe Oink op availability at module init time
|
||||
(outside torch.compile tracing), and
|
||||
- Thin wrappers around the torch.ops entrypoints for use in CUDA fast paths,
|
||||
without introducing graph breaks.
|
||||
|
||||
Important:
|
||||
Do not call the availability helpers in a compiled region. They may call
|
||||
functions decorated with `torch._dynamo.disable` to safely check
|
||||
conditions that should not be traced.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
# Prefer the real dynamo `disable` decorator so the availability probes are
# never traced by torch.compile; fall back to a no-op identity decorator when
# the private torch._dynamo API is unavailable (e.g. unusual torch builds).
try:
    from torch._dynamo import disable as _dynamo_disable  # type: ignore[attr-defined]
except Exception:  # pragma: no cover

    def _dynamo_disable(fn: Callable):  # type: ignore[misc]
        # Fallback: return the function unchanged (no tracing guard).
        return fn
|
||||
|
||||
|
||||
def _has_oink_op(op_name: str) -> bool:
|
||||
"""Check if a specific oink op is registered."""
|
||||
return hasattr(torch.ops, "oink") and hasattr(torch.ops.oink, op_name)
|
||||
|
||||
|
||||
@_dynamo_disable
def is_oink_available_for_device(device_index: int) -> bool:
    """Return True if Oink ops are registered and the device is SM100+.

    Intended to be called once during module construction (e.g. in
    RMSNorm.__init__), never on the forward/hot path — the dynamo-disable
    decorator makes this unsafe to call inside a compiled region.

    External plugins are expected to gate registration on SM100+ and
    VLLM_USE_OINK_OPS=1, so ops that are present are assumed usable.
    """
    if not torch.cuda.is_available():
        return False

    try:
        major, minor = torch.cuda.get_device_capability(device_index)
    except Exception:
        # No usable device, driver issues, etc.: treat as unavailable.
        return False
    # SM100 == compute capability (10, 0).
    if major * 10 + minor < 100:
        return False

    return _has_oink_op("rmsnorm")
|
||||
|
||||
|
||||
def has_fused_add_rms_norm() -> bool:
    """Report whether the in-place `oink::fused_add_rms_norm` op is registered."""
    return _has_oink_op("fused_add_rms_norm")
|
||||
|
||||
|
||||
def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Dispatch to the externally registered `torch.ops.oink.rmsnorm` op.

    Being a plain custom-op call, this wrapper is safe to use inside
    torch.compile regions (no graph breaks).
    """
    return torch.ops.oink.rmsnorm(x, weight, eps)
|
||||
|
||||
|
||||
def fused_add_rms_norm_(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> None:
    """In-place fused residual-add + RMSNorm via `torch.ops.oink`.

    Mutates both `x` and `residual`; returns nothing (trailing underscore
    follows the torch in-place naming convention).
    """
    torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps)
|
||||
|
||||
|
||||
def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the in-place fused op, then return the mutated `(x, residual)` pair."""
    fused_add_rms_norm_(x, residual, weight, eps)
    return x, residual
|
||||
@@ -97,6 +97,7 @@ if TYPE_CHECKING:
|
||||
VLLM_SKIP_P2P_CHECK: bool = False
|
||||
VLLM_DISABLED_KERNELS: list[str] = []
|
||||
VLLM_DISABLE_PYNCCL: bool = False
|
||||
VLLM_USE_OINK_OPS: bool = False
|
||||
VLLM_ROCM_USE_AITER: bool = False
|
||||
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
|
||||
VLLM_ROCM_USE_AITER_LINEAR: bool = True
|
||||
@@ -896,6 +897,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_DISABLE_PYNCCL": lambda: (
|
||||
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")
|
||||
),
|
||||
# Optional: enable external Oink custom ops (e.g., Blackwell RMSNorm).
|
||||
# Disabled by default.
|
||||
"VLLM_USE_OINK_OPS": lambda: (
|
||||
os.getenv("VLLM_USE_OINK_OPS", "False").lower() in ("true", "1")
|
||||
),
|
||||
# Disable aiter ops unless specifically enabled.
|
||||
# Acts as a parent switch to enable the rest of the other operations.
|
||||
"VLLM_ROCM_USE_AITER": lambda: (
|
||||
|
||||
@@ -6,7 +6,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _oink_ops, envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
rms_norm_batch_invariant,
|
||||
@@ -14,6 +16,41 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _can_view_as_2d(x: torch.Tensor) -> bool:
|
||||
"""Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
|
||||
if x.dim() < 2:
|
||||
return False
|
||||
if x.dim() == 2:
|
||||
return True
|
||||
# For a view(-1, N) to be valid, all leading dims must be contiguous with
|
||||
# respect to each other (size-1 dims are ignored).
|
||||
for dim in range(x.dim() - 1):
|
||||
# Strides for size-1 dims are irrelevant and can be arbitrary.
|
||||
if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
|
||||
dim + 1
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
|
||||
"""Return True if x_2d meets Oink's pointer-path stride constraints."""
|
||||
if x_2d.dim() != 2:
|
||||
return False
|
||||
if x_2d.stride(1) != 1:
|
||||
return False
|
||||
# Match Oink's vectorization constraint: stride(0) divisible by 256b.
|
||||
if x_2d.dtype in (torch.float16, torch.bfloat16):
|
||||
divby = 16
|
||||
elif x_2d.dtype == torch.float32:
|
||||
divby = 8
|
||||
else:
|
||||
return False
|
||||
return (x_2d.stride(0) % divby) == 0
|
||||
|
||||
|
||||
def rms_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
@@ -131,6 +168,57 @@ class RMSNorm(CustomOp):
|
||||
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
|
||||
)
|
||||
|
||||
# Optional: enable Oink Blackwell RMSNorm custom-op fast path on
|
||||
# compatible CUDA devices (e.g., SM100) when the external Oink
|
||||
# package is available. This is detected once at construction time
|
||||
# to avoid per-call device queries in the hot path.
|
||||
self._use_oink_rmsnorm = False
|
||||
self._use_oink_fused_add_rmsnorm = False
|
||||
if (
|
||||
not current_platform.is_rocm()
|
||||
and torch.cuda.is_available()
|
||||
and bool(getattr(envs, "VLLM_USE_OINK_OPS", False))
|
||||
):
|
||||
# NOTE: vLLM disables custom ops by default when using Inductor.
|
||||
# If this op is disabled, CustomOp will dispatch to forward_native,
|
||||
# and the Oink path in forward_cuda will never run.
|
||||
if getattr(self._forward_method, "__func__", None) is getattr(
|
||||
self.forward_native, "__func__", None
|
||||
):
|
||||
try:
|
||||
from vllm.config import get_cached_compilation_config
|
||||
|
||||
custom_ops = get_cached_compilation_config().custom_ops
|
||||
except Exception:
|
||||
custom_ops = ["<unknown>"]
|
||||
logger.warning_once(
|
||||
"VLLM_USE_OINK_OPS=1 but the `rms_norm` custom op is "
|
||||
"disabled (CompilationConfig.custom_ops=%s). Enable it via "
|
||||
"`compilation_config={'custom_ops': ['none', '+rms_norm']}` "
|
||||
"(or `['all']`) to let vLLM call into torch.ops.oink.*.",
|
||||
custom_ops,
|
||||
)
|
||||
# Custom op disabled => forward_cuda won't run. Avoid doing any
|
||||
# external Oink initialization work in this case.
|
||||
else:
|
||||
try:
|
||||
device_index = torch.cuda.current_device()
|
||||
if _oink_ops.is_oink_available_for_device(device_index):
|
||||
self._use_oink_rmsnorm = True
|
||||
self._use_oink_fused_add_rmsnorm = (
|
||||
_oink_ops.has_fused_add_rms_norm()
|
||||
)
|
||||
except Exception as e:
|
||||
# If anything goes wrong (no Oink install, CPU-only env, etc.),
|
||||
# silently fall back to the built-in RMSNorm path.
|
||||
logger.warning_once(
|
||||
"VLLM_USE_OINK_OPS=1 but failed to initialize Oink "
|
||||
"RMSNorm; falling back to vLLM RMSNorm. Error: %s",
|
||||
e,
|
||||
)
|
||||
self._use_oink_rmsnorm = False
|
||||
self._use_oink_fused_add_rmsnorm = False
|
||||
|
||||
@staticmethod
|
||||
def forward_static(
|
||||
x: torch.Tensor,
|
||||
@@ -202,6 +290,73 @@ class RMSNorm(CustomOp):
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
# Optional Oink SM100 fast path (no residual). This path is
|
||||
# torch.compile-friendly via torch.ops.oink.rmsnorm and preserves
|
||||
# 2D layouts (including padded rows) when using the Oink
|
||||
# pointer-based kernel.
|
||||
if (
|
||||
residual is None
|
||||
and getattr(self, "_use_oink_rmsnorm", False)
|
||||
and x.is_cuda
|
||||
and x.dim() >= 2
|
||||
and self.has_weight
|
||||
and not vllm_is_batch_invariant()
|
||||
and self.weight.data.dtype == x.dtype
|
||||
and self.weight.data.is_contiguous()
|
||||
):
|
||||
orig_shape = x.shape
|
||||
hidden_size = orig_shape[-1]
|
||||
if _can_view_as_2d(x):
|
||||
x_2d = x.view(-1, hidden_size)
|
||||
if _is_oink_stride_compatible_2d(x_2d):
|
||||
y_2d = _oink_ops.rmsnorm(
|
||||
x_2d,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return y_2d.view(orig_shape)
|
||||
|
||||
# Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
|
||||
# This mirrors vLLM's fused_add_rms_norm semantics by mutating both
|
||||
# `x` (normalized output) and `residual` (residual-out buffer).
|
||||
if (
|
||||
residual is not None
|
||||
and getattr(self, "_use_oink_fused_add_rmsnorm", False)
|
||||
and x.is_cuda
|
||||
and residual.is_cuda
|
||||
and x.shape == residual.shape
|
||||
and x.dtype == residual.dtype
|
||||
and x.dim() >= 2
|
||||
and self.has_weight
|
||||
and not vllm_is_batch_invariant()
|
||||
and self.weight.data.dtype == x.dtype
|
||||
and self.weight.data.is_contiguous()
|
||||
):
|
||||
orig_shape = x.shape
|
||||
hidden_size = orig_shape[-1]
|
||||
if _can_view_as_2d(x) and _can_view_as_2d(residual):
|
||||
x_2d = x.view(-1, hidden_size)
|
||||
res_2d = residual.view(-1, hidden_size)
|
||||
|
||||
# The Oink in-place pointer path supports the common vLLM
|
||||
# layout where:
|
||||
# - `x` may be strided/padded row-major (stride(1) == 1), and
|
||||
# - `residual` is contiguous row-major ([M, N] with stride(0) == N).
|
||||
# If these conditions are not met, fall back to vLLM's built-in
|
||||
# fused kernel.
|
||||
if (
|
||||
_is_oink_stride_compatible_2d(x_2d)
|
||||
and _is_oink_stride_compatible_2d(res_2d)
|
||||
and res_2d.is_contiguous()
|
||||
):
|
||||
_oink_ops.fused_add_rms_norm_(
|
||||
x_2d,
|
||||
res_2d,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
|
||||
add_residual = residual is not None
|
||||
if add_residual:
|
||||
return fused_add_rms_norm(
|
||||
|
||||
Reference in New Issue
Block a user