[Perf] Add opt-in SM100 Oink RMSNorm custom-op path (#31828)

Signed-off-by: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
Laura Wang
2026-02-24 23:01:53 -08:00
committed by GitHub
parent cd43673668
commit 2465071510
4 changed files with 331 additions and 0 deletions

View File

@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
import pytest
import torch
def _load_oink_ops_module():
    """Import and return ``vllm._oink_ops``.

    vLLM is installed as an editable package in CI, so a plain import
    is sufficient — no path manipulation needed.
    """
    from vllm import _oink_ops as oink_ops

    return oink_ops
def test_oink_availability_checks(monkeypatch: pytest.MonkeyPatch):
    """Exercise the availability probes across CUDA/SM/op-registration states."""
    oink_mod = _load_oink_ops_module()

    # Install a minimal `torch.ops.oink` namespace exposing only `rmsnorm`,
    # so the presence probes have a mutable object to inspect.
    monkeypatch.setattr(
        torch.ops,
        "oink",
        types.SimpleNamespace(rmsnorm=lambda x, w, eps: x),
        raising=False,
    )

    # Without CUDA the device gate must reject immediately.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    assert oink_mod.is_oink_available_for_device(0) is False

    # CUDA present but compute capability below SM100: still unavailable.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(torch.cuda, "get_device_capability", lambda idx: (9, 0))
    assert oink_mod.is_oink_available_for_device(0) is False

    # SM100 device with the `rmsnorm` op registered: available.
    monkeypatch.setattr(torch.cuda, "get_device_capability", lambda idx: (10, 0))
    assert oink_mod.is_oink_available_for_device(0) is True

    # Fused-op probe: absent from the stub namespace until we add it.
    assert oink_mod.has_fused_add_rms_norm() is False
    monkeypatch.setattr(
        torch.ops,
        "oink",
        types.SimpleNamespace(
            rmsnorm=lambda x, w, eps: x,
            fused_add_rms_norm=lambda x, residual, w, eps: None,
        ),
        raising=False,
    )
    assert oink_mod.has_fused_add_rms_norm() is True
def test_can_view_as_2d_stride_guard():
    """_can_view_as_2d must mirror whether ``view(-1, H)`` succeeds."""
    from vllm.model_executor.layers.layernorm import _can_view_as_2d

    dense = torch.zeros((2, 3, 4))
    assert _can_view_as_2d(dense) is True

    # A size-1 middle dim carries an arbitrary stride (here inherited from a
    # padded base), yet view(-1, H) still works because size-1 dims never
    # constrain flattening — the guard must agree.
    padded_base = torch.zeros((2, 10, 4))
    singleton = padded_base[:, :1, :]
    singleton.view(-1, singleton.shape[-1])
    assert _can_view_as_2d(singleton) is True

    # Striding the middle dim breaks flattenability: view(-1, H) raises,
    # and the guard must report False.
    strided = dense[:, ::2, :]
    with pytest.raises(RuntimeError):
        strided.view(-1, strided.shape[-1])
    assert _can_view_as_2d(strided) is False

96
vllm/_oink_ops.py Normal file
View File

@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Small helper wrappers for external Oink Blackwell custom ops.
vLLM does not depend on the external Oink repository/package. When an external
plugin registers torch.library.custom_op entrypoints under the `oink::`
namespace (e.g. via vLLM's general_plugins mechanism) and
`VLLM_USE_OINK_OPS=1` is set, vLLM can route eligible calls to those ops.
This module provides:
- A single place to probe Oink op availability at module init time
(outside torch.compile tracing), and
- Thin wrappers around the torch.ops entrypoints for use in CUDA fast paths,
without introducing graph breaks.
Important:
Do not call the availability helpers in a compiled region. They may call
functions decorated with `torch._dynamo.disable` to safely check
conditions that should not be traced.
"""
from __future__ import annotations
from collections.abc import Callable
import torch
# Prefer Dynamo's `disable` so the availability probes below are never traced
# by torch.compile; torch builds without Dynamo may not expose it.
try:
    from torch._dynamo import disable as _dynamo_disable  # type: ignore[attr-defined]
except Exception:  # pragma: no cover
    # Fallback: identity decorator so the decorated helpers still work
    # unchanged when Dynamo is unavailable.
    def _dynamo_disable(fn: Callable):  # type: ignore[misc]
        return fn
def _has_oink_op(op_name: str) -> bool:
"""Check if a specific oink op is registered."""
return hasattr(torch.ops, "oink") and hasattr(torch.ops.oink, op_name)
@_dynamo_disable
def is_oink_available_for_device(device_index: int) -> bool:
    """Return True if Oink ops are registered and the device is SM100+.

    Intended to run once at module construction time (e.g. in
    RMSNorm.__init__), never inside the forward/compiled path.
    External plugins gate registration on SM100+ and VLLM_USE_OINK_OPS=1,
    so a registered op is expected to be usable.
    """
    if not torch.cuda.is_available():
        return False
    # Any device-query failure (bad index, driver issues, ...) means
    # "not available" rather than an exception.
    try:
        major, minor = torch.cuda.get_device_capability(device_index)
    except Exception:
        return False
    if 10 * major + minor < 100:
        return False
    return _has_oink_op("rmsnorm")
def has_fused_add_rms_norm() -> bool:
    """Probe whether the in-place fused add+RMSNorm Oink op is registered."""
    return _has_oink_op("fused_add_rms_norm")
def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Invoke the externally registered ``torch.ops.oink.rmsnorm``.

    A plain custom-op call, so it is safe inside torch.compile regions
    (no graph break).
    """
    oink = torch.ops.oink
    return oink.rmsnorm(x, weight, eps)
def fused_add_rms_norm_(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> None:
    """Invoke ``torch.ops.oink.fused_add_rms_norm`` in place.

    Mutates both ``x`` (normalized output) and ``residual``
    (residual-out buffer); returns nothing.
    """
    oink = torch.ops.oink
    oink.fused_add_rms_norm(x, residual, weight, eps)
def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the in-place fused op and hand back the mutated ``(x, residual)``."""
    fused_add_rms_norm_(x, residual, weight, eps)
    return x, residual

View File

@@ -97,6 +97,7 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_DISABLE_PYNCCL: bool = False
VLLM_USE_OINK_OPS: bool = False
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
@@ -896,6 +897,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLE_PYNCCL": lambda: (
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")
),
# Optional: enable external Oink custom ops (e.g., Blackwell RMSNorm).
# Disabled by default.
"VLLM_USE_OINK_OPS": lambda: (
os.getenv("VLLM_USE_OINK_OPS", "False").lower() in ("true", "1")
),
# Disable aiter ops unless specifically enabled.
# Acts as a parent switch to enable the rest of the other operations.
"VLLM_ROCM_USE_AITER": lambda: (

View File

@@ -6,7 +6,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm import _oink_ops, envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
@@ -14,6 +16,41 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
def _can_view_as_2d(x: torch.Tensor) -> bool:
"""Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
if x.dim() < 2:
return False
if x.dim() == 2:
return True
# For a view(-1, N) to be valid, all leading dims must be contiguous with
# respect to each other (size-1 dims are ignored).
for dim in range(x.dim() - 1):
# Strides for size-1 dims are irrelevant and can be arbitrary.
if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
dim + 1
):
return False
return True
def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
"""Return True if x_2d meets Oink's pointer-path stride constraints."""
if x_2d.dim() != 2:
return False
if x_2d.stride(1) != 1:
return False
# Match Oink's vectorization constraint: stride(0) divisible by 256b.
if x_2d.dtype in (torch.float16, torch.bfloat16):
divby = 16
elif x_2d.dtype == torch.float32:
divby = 8
else:
return False
return (x_2d.stride(0) % divby) == 0
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
@@ -131,6 +168,57 @@ class RMSNorm(CustomOp):
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
)
# Optional: enable Oink Blackwell RMSNorm custom-op fast path on
# compatible CUDA devices (e.g., SM100) when the external Oink
# package is available. This is detected once at construction time
# to avoid per-call device queries in the hot path.
self._use_oink_rmsnorm = False
self._use_oink_fused_add_rmsnorm = False
if (
not current_platform.is_rocm()
and torch.cuda.is_available()
and bool(getattr(envs, "VLLM_USE_OINK_OPS", False))
):
# NOTE: vLLM disables custom ops by default when using Inductor.
# If this op is disabled, CustomOp will dispatch to forward_native,
# and the Oink path in forward_cuda will never run.
if getattr(self._forward_method, "__func__", None) is getattr(
self.forward_native, "__func__", None
):
try:
from vllm.config import get_cached_compilation_config
custom_ops = get_cached_compilation_config().custom_ops
except Exception:
custom_ops = ["<unknown>"]
logger.warning_once(
"VLLM_USE_OINK_OPS=1 but the `rms_norm` custom op is "
"disabled (CompilationConfig.custom_ops=%s). Enable it via "
"`compilation_config={'custom_ops': ['none', '+rms_norm']}` "
"(or `['all']`) to let vLLM call into torch.ops.oink.*.",
custom_ops,
)
# Custom op disabled => forward_cuda won't run. Avoid doing any
# external Oink initialization work in this case.
else:
try:
device_index = torch.cuda.current_device()
if _oink_ops.is_oink_available_for_device(device_index):
self._use_oink_rmsnorm = True
self._use_oink_fused_add_rmsnorm = (
_oink_ops.has_fused_add_rms_norm()
)
except Exception as e:
# If anything goes wrong (no Oink install, CPU-only env, etc.),
# silently fall back to the built-in RMSNorm path.
logger.warning_once(
"VLLM_USE_OINK_OPS=1 but failed to initialize Oink "
"RMSNorm; falling back to vLLM RMSNorm. Error: %s",
e,
)
self._use_oink_rmsnorm = False
self._use_oink_fused_add_rmsnorm = False
@staticmethod
def forward_static(
x: torch.Tensor,
@@ -202,6 +290,73 @@ class RMSNorm(CustomOp):
if self.variance_size_override is not None:
return self.forward_native(x, residual)
# Optional Oink SM100 fast path (no residual). This path is
# torch.compile-friendly via torch.ops.oink.rmsnorm and preserves
# 2D layouts (including padded rows) when using the Oink
# pointer-based kernel.
if (
residual is None
and getattr(self, "_use_oink_rmsnorm", False)
and x.is_cuda
and x.dim() >= 2
and self.has_weight
and not vllm_is_batch_invariant()
and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous()
):
orig_shape = x.shape
hidden_size = orig_shape[-1]
if _can_view_as_2d(x):
x_2d = x.view(-1, hidden_size)
if _is_oink_stride_compatible_2d(x_2d):
y_2d = _oink_ops.rmsnorm(
x_2d,
self.weight.data,
self.variance_epsilon,
)
return y_2d.view(orig_shape)
# Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
# This mirrors vLLM's fused_add_rms_norm semantics by mutating both
# `x` (normalized output) and `residual` (residual-out buffer).
if (
residual is not None
and getattr(self, "_use_oink_fused_add_rmsnorm", False)
and x.is_cuda
and residual.is_cuda
and x.shape == residual.shape
and x.dtype == residual.dtype
and x.dim() >= 2
and self.has_weight
and not vllm_is_batch_invariant()
and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous()
):
orig_shape = x.shape
hidden_size = orig_shape[-1]
if _can_view_as_2d(x) and _can_view_as_2d(residual):
x_2d = x.view(-1, hidden_size)
res_2d = residual.view(-1, hidden_size)
# The Oink in-place pointer path supports the common vLLM
# layout where:
# - `x` may be strided/padded row-major (stride(1) == 1), and
# - `residual` is contiguous row-major ([M, N] with stride(0) == N).
# If these conditions are not met, fall back to vLLM's built-in
# fused kernel.
if (
_is_oink_stride_compatible_2d(x_2d)
and _is_oink_stride_compatible_2d(res_2d)
and res_2d.is_contiguous()
):
_oink_ops.fused_add_rms_norm_(
x_2d,
res_2d,
self.weight.data,
self.variance_epsilon,
)
return x, residual
add_residual = residual is not None
if add_residual:
return fused_add_rms_norm(