[Bugfix] Support RotaryEmbedding CustomOp for gpt-oss (#33800)
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
This commit is contained in:
@@ -86,14 +86,23 @@ class RotaryEmbeddingBase(CustomOp):
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
|
||||
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> torch.Tensor:
|
||||
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
|
||||
# is expensive, so avoid calling it if possible
|
||||
cos_sin_cache = self.cos_sin_cache
|
||||
if (
|
||||
self.cos_sin_cache.device != query.device
|
||||
or self.cos_sin_cache.dtype != query.dtype
|
||||
cos_sin_cache.device == query.device
|
||||
and self.cos_sin_cache.dtype == query.dtype
|
||||
):
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
return cos_sin_cache
|
||||
|
||||
cos_sin_cache = cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
# Avoid mutating buffers during torch.compile (cudagraph) tracing.
|
||||
if torch.compiler.is_compiling():
|
||||
return cos_sin_cache
|
||||
|
||||
self.cos_sin_cache = cos_sin_cache
|
||||
return cos_sin_cache
|
||||
|
||||
def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cos_sin = self.cos_sin_cache[:seqlen]
|
||||
@@ -172,13 +181,14 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
key: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""A PyTorch-native implementation of forward()."""
|
||||
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
|
||||
return self.forward_static(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
self.cos_sin_cache,
|
||||
cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
|
||||
@@ -201,7 +211,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
|
||||
|
||||
# ops.rotary_embedding() is an in-place operation
|
||||
# that updates the query and key tensors.
|
||||
@@ -210,7 +220,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
@@ -222,12 +232,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
key: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if self.is_rocm_triton_rotary_embed_enabled:
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
|
||||
rocm_aiter_ops.triton_rotary_embed(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.cos_sin_cache,
|
||||
cos_sin_cache,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
self.is_neox_style,
|
||||
@@ -249,12 +259,13 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
else:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
|
||||
ops.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
@@ -267,7 +278,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
|
||||
|
||||
# ops.rotary_embedding() is an in-place operation
|
||||
# that updates the query and key tensors.
|
||||
@@ -276,7 +287,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
|
||||
@@ -120,14 +120,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
assert key is not None
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
if self.rotary_dim < self.head_size:
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
|
||||
cos_sin = self.cos_sin_cache[
|
||||
cos_sin = cos_sin_cache[
|
||||
torch.add(positions, offsets) if offsets is not None else positions
|
||||
]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
@@ -277,9 +277,9 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
|
||||
num_tokens = positions.shape[-1]
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos_sin = cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
if positions.ndim == 2:
|
||||
assert self.mrope_section
|
||||
@@ -329,9 +329,9 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
|
||||
num_tokens = positions.shape[-1]
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos_sin = cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
query_shape = query.shape
|
||||
key_shape = key.shape
|
||||
|
||||
Reference in New Issue
Block a user