[Bugfix] Support RotaryEmbedding CustomOp for gpt-oss (#33800)

Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
This commit is contained in:
Simon Danielsson
2026-02-04 21:17:41 +01:00
committed by GitHub
parent 6e98f6d8b6
commit 4292c90a2a
4 changed files with 97 additions and 18 deletions

View File

@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import vllm.envs as envs
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
ModelConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
@support_torch_compile
class RotaryEmbeddingCompileModule(torch.nn.Module):
    """Tiny module that applies a rotary embedding and sums the results.

    It exists so the test in this file can run a rotary embedding layer
    through the ``support_torch_compile`` decorator.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the rotary embedding layer.

        Args:
            vllm_config: Accepted but not read in this constructor
                (presumably consumed by ``support_torch_compile`` -- verify
                against the decorator).
            prefix: Accepted but not read in this constructor.
        """
        super().__init__()
        rope_parameters = {"rope_type": "default", "rope_theta": 10000}
        self.rotary_emb = get_rope(
            head_size=32,
            max_position=128,
            dtype=torch.float32,
            rope_parameters=rope_parameters,
            is_neox_style=True,
        )

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the rotary embedding and return rotated query + rotated key."""
        rotated_query, rotated_key = self.rotary_emb(positions, query, key)
        combined = rotated_query + rotated_key
        return combined
# The test allocates its tensors on "cuda", so skip on every platform that is
# not CUDA-like rather than only on CPU.
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(),
    reason="Requires GPU for torch.compile",
)
def test_rotary_embedding_torch_compile_with_custom_op(monkeypatch):
    """Compile a module using the rotary-embedding custom op and inspect it.

    The module is built with ``custom_ops=["+rotary_embedding"]`` and called
    once with bfloat16 query/key tensors. The test then checks that compiled
    bytecode was recorded on the model and that its ``co_names`` does not
    contain ``"update"``.
    """
    # Ensure env toggles take effect for this test only.
    # The bytecode hook is required to detect buffer mutation in compiled code,
    # and AOT compile bypasses that hook entirely.
    envs.disable_envs_cache()
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1")
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "0")

    device = "cuda"
    positions = torch.arange(16, device=device)
    # Query/key are bfloat16; the module's rope is created with
    # dtype=torch.float32.
    query = torch.randn(16, 32, device=device, dtype=torch.bfloat16)
    key = torch.randn(16, 32, device=device, dtype=torch.bfloat16)

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=torch.bfloat16),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            backend="inductor",
            custom_ops=["+rotary_embedding"],
            cudagraph_mode=CUDAGraphMode.NONE,
            cudagraph_num_of_warmups=0,
        ),
    )

    with set_current_vllm_config(vllm_config):
        model = RotaryEmbeddingCompileModule(vllm_config=vllm_config)
        model(positions, query, key)

        assert model._compiled_bytecode is not None
        assert "update" not in model._compiled_bytecode.co_names

View File

@@ -86,14 +86,23 @@ class RotaryEmbeddingBase(CustomOp):
cache = torch.cat((cos, sin), dim=-1)
return cache
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> torch.Tensor:
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
# is expensive, so avoid calling it if possible
cos_sin_cache = self.cos_sin_cache
if (
self.cos_sin_cache.device != query.device
or self.cos_sin_cache.dtype != query.dtype
cos_sin_cache.device == query.device
and self.cos_sin_cache.dtype == query.dtype
):
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
return cos_sin_cache
cos_sin_cache = cos_sin_cache.to(query.device, dtype=query.dtype)
# Avoid mutating buffers during torch.compile (cudagraph) tracing.
if torch.compiler.is_compiling():
return cos_sin_cache
self.cos_sin_cache = cos_sin_cache
return cos_sin_cache
def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
cos_sin = self.cos_sin_cache[:seqlen]
@@ -172,13 +181,14 @@ class RotaryEmbedding(RotaryEmbeddingBase):
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""A PyTorch-native implementation of forward()."""
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
return self.forward_static(
positions,
query,
key,
self.head_size,
self.rotary_dim,
self.cos_sin_cache,
cos_sin_cache,
self.is_neox_style,
)
@@ -201,7 +211,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
from vllm import _custom_ops as ops
self._match_cos_sin_cache_dtype(query)
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
# ops.rotary_embedding() is an in-place operation
# that updates the query and key tensors.
@@ -210,7 +220,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query,
key,
self.head_size,
self.cos_sin_cache,
cos_sin_cache,
self.is_neox_style,
)
return query, key
@@ -222,12 +232,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.is_rocm_triton_rotary_embed_enabled:
self._match_cos_sin_cache_dtype(query)
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
rocm_aiter_ops.triton_rotary_embed(
positions,
query,
key,
self.cos_sin_cache,
cos_sin_cache,
self.head_size,
self.rotary_dim,
self.is_neox_style,
@@ -249,12 +259,13 @@ class RotaryEmbedding(RotaryEmbeddingBase):
else:
from vllm import _custom_ops as ops
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
cos_sin_cache,
self.is_neox_style,
)
return query, key
@@ -267,7 +278,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
) -> tuple[torch.Tensor, torch.Tensor | None]:
from vllm import _custom_ops as ops
self._match_cos_sin_cache_dtype(query)
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
# ops.rotary_embedding() is an in-place operation
# that updates the query and key tensors.
@@ -276,7 +287,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query,
key,
self.head_size,
self.cos_sin_cache,
cos_sin_cache,
self.is_neox_style,
)
return query, key

View File

@@ -120,14 +120,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
self._match_cos_sin_cache_dtype(query)
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
query_rot = query[..., : self.rotary_dim]
key_rot = key[..., : self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim :]
key_pass = key[..., self.rotary_dim :]
cos_sin = self.cos_sin_cache[
cos_sin = cos_sin_cache[
torch.add(positions, offsets) if offsets is not None else positions
]
cos, sin = cos_sin.chunk(2, dim=-1)

View File

@@ -277,9 +277,9 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
self._match_cos_sin_cache_dtype(query)
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
@@ -329,9 +329,9 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
self._match_cos_sin_cache_dtype(query)
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape