From 4292c90a2a188121ccbfd132def62031283d9d8a Mon Sep 17 00:00:00 2001 From: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Date: Wed, 4 Feb 2026 21:17:41 +0100 Subject: [PATCH] [Bugfix] Support `RotaryEmbedding` CustomOp for gpt-oss (#33800) Signed-off-by: simondanielsson --- .../compile/test_rotary_embedding_compile.py | 68 +++++++++++++++++++ .../layers/rotary_embedding/base.py | 35 ++++++---- .../rotary_embedding/deepseek_scaling_rope.py | 4 +- .../layers/rotary_embedding/mrope.py | 8 +-- 4 files changed, 97 insertions(+), 18 deletions(-) create mode 100644 tests/compile/test_rotary_embedding_compile.py diff --git a/tests/compile/test_rotary_embedding_compile.py b/tests/compile/test_rotary_embedding_compile.py new file mode 100644 index 000000000..76f538253 --- /dev/null +++ b/tests/compile/test_rotary_embedding_compile.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm.envs as envs +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CompilationConfig, + ModelConfig, + VllmConfig, + set_current_vllm_config, +) +from vllm.config.compilation import CompilationMode, CUDAGraphMode +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.platforms import current_platform + + +@support_torch_compile +class RotaryEmbeddingCompileModule(torch.nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + self.rotary_emb = get_rope( + head_size=32, + max_position=128, + dtype=torch.float32, + rope_parameters={"rope_type": "default", "rope_theta": 10000}, + is_neox_style=True, + ) + + def forward( + self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor + ) -> torch.Tensor: + q_rot, k_rot = self.rotary_emb(positions, query, key) + return q_rot + k_rot + + +@pytest.mark.skipif(current_platform.is_cpu(), reason="Requires GPU for torch.compile") +def test_rotary_embedding_torch_compile_with_custom_op(monkeypatch): + # Ensure env toggles take effect for this test only. + # The bytecode hook is required to detect buffer mutation in compiled code, + # and AOT compile bypasses that hook entirely. + envs.disable_envs_cache() + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1") + monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "0") + + device = "cuda" + positions = torch.arange(16, device=device) + query = torch.randn(16, 32, device=device, dtype=torch.bfloat16) + key = torch.randn(16, 32, device=device, dtype=torch.bfloat16) + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=torch.bfloat16), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + backend="inductor", + custom_ops=["+rotary_embedding"], + cudagraph_mode=CUDAGraphMode.NONE, + cudagraph_num_of_warmups=0, + ), + ) + + with set_current_vllm_config(vllm_config): + model = RotaryEmbeddingCompileModule(vllm_config=vllm_config) + model(positions, query, key) + assert model._compiled_bytecode is not None + assert "update" not in model._compiled_bytecode.co_names diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 2147e00d2..1e3063392 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -86,14 +86,23 @@ class RotaryEmbeddingBase(CustomOp): cache = torch.cat((cos, sin), dim=-1) return cache - def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: + def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> torch.Tensor: # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) # is expensive, so avoid calling it if possible + cos_sin_cache = self.cos_sin_cache if ( - self.cos_sin_cache.device != query.device - or self.cos_sin_cache.dtype != query.dtype + cos_sin_cache.device == query.device + and self.cos_sin_cache.dtype == query.dtype ): - self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + return cos_sin_cache + + cos_sin_cache = cos_sin_cache.to(query.device, dtype=query.dtype) + # Avoid mutating buffers during torch.compile (cudagraph) tracing. + if torch.compiler.is_compiling(): + return cos_sin_cache + + self.cos_sin_cache = cos_sin_cache + return cos_sin_cache def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]: cos_sin = self.cos_sin_cache[:seqlen] @@ -172,13 +181,14 @@ class RotaryEmbedding(RotaryEmbeddingBase): key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """A PyTorch-native implementation of forward().""" + cos_sin_cache = self._match_cos_sin_cache_dtype(query) return self.forward_static( positions, query, key, self.head_size, self.rotary_dim, - self.cos_sin_cache, + cos_sin_cache, self.is_neox_style, ) @@ -201,7 +211,7 @@ class RotaryEmbedding(RotaryEmbeddingBase): from vllm import _custom_ops as ops - self._match_cos_sin_cache_dtype(query) + cos_sin_cache = self._match_cos_sin_cache_dtype(query) # ops.rotary_embedding() is an in-place operation # that updates the query and key tensors. @@ -210,7 +220,7 @@ class RotaryEmbedding(RotaryEmbeddingBase): query, key, self.head_size, - self.cos_sin_cache, + cos_sin_cache, self.is_neox_style, ) return query, key @@ -222,12 +232,12 @@ class RotaryEmbedding(RotaryEmbeddingBase): key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: if self.is_rocm_triton_rotary_embed_enabled: - self._match_cos_sin_cache_dtype(query) + cos_sin_cache = self._match_cos_sin_cache_dtype(query) rocm_aiter_ops.triton_rotary_embed( positions, query, key, - self.cos_sin_cache, + cos_sin_cache, self.head_size, self.rotary_dim, self.is_neox_style, @@ -249,12 +259,13 @@ class RotaryEmbedding(RotaryEmbeddingBase): else: from vllm import _custom_ops as ops + cos_sin_cache = self._match_cos_sin_cache_dtype(query) ops.rotary_embedding( positions, query, key, self.head_size, - self.cos_sin_cache, + cos_sin_cache, self.is_neox_style, ) return query, key @@ -267,7 +278,7 @@ class RotaryEmbedding(RotaryEmbeddingBase): ) -> tuple[torch.Tensor, torch.Tensor | None]: from vllm import _custom_ops as ops - self._match_cos_sin_cache_dtype(query) + cos_sin_cache = self._match_cos_sin_cache_dtype(query) # ops.rotary_embedding() is an in-place operation # that updates the query and key tensors. @@ -276,7 +287,7 @@ class RotaryEmbedding(RotaryEmbeddingBase): query, key, self.head_size, - self.cos_sin_cache, + cos_sin_cache, self.is_neox_style, ) return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index 9be9caacb..c3abdc156 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -120,14 +120,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase): ) -> tuple[torch.Tensor, torch.Tensor | None]: """PyTorch-native implementation equivalent to forward().""" assert key is not None - self._match_cos_sin_cache_dtype(query) + cos_sin_cache = self._match_cos_sin_cache_dtype(query) query_rot = query[..., : self.rotary_dim] key_rot = key[..., : self.rotary_dim] if self.rotary_dim < self.head_size: query_pass = query[..., self.rotary_dim :] key_pass = key[..., self.rotary_dim :] - cos_sin = self.cos_sin_cache[ + cos_sin = cos_sin_cache[ torch.add(positions, offsets) if offsets is not None else positions ] cos, sin = cos_sin.chunk(2, dim=-1) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index a74bf092b..52f3c333d 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -277,9 +277,9 @@ class MRotaryEmbedding(RotaryEmbeddingBase): assert positions.ndim == 1 or positions.ndim == 2 assert key is not None - self._match_cos_sin_cache_dtype(query) + cos_sin_cache = self._match_cos_sin_cache_dtype(query) num_tokens = positions.shape[-1] - cos_sin = self.cos_sin_cache[positions] + cos_sin = cos_sin_cache[positions] cos, sin = cos_sin.chunk(2, dim=-1) if positions.ndim == 2: assert self.mrope_section @@ -329,9 +329,9 @@ class MRotaryEmbedding(RotaryEmbeddingBase): assert positions.ndim == 1 or positions.ndim == 2 assert key is not None - self._match_cos_sin_cache_dtype(query) + cos_sin_cache = self._match_cos_sin_cache_dtype(query) num_tokens = positions.shape[-1] - cos_sin = self.cos_sin_cache[positions] + cos_sin = cos_sin_cache[positions] cos, sin = cos_sin.chunk(2, dim=-1) query_shape = query.shape key_shape = key.shape