[Bugfix][CPU] Fix RotaryEmbedding fallback causing gibberish with --enforce-eager (#31643)

Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>
This commit is contained in:
RickyChen / 陳昭儒
2026-01-06 01:25:38 +08:00
committed by GitHub
parent eefa713a66
commit c455b771fd
2 changed files with 25 additions and 2 deletions

View File

@@ -67,8 +67,9 @@ class CustomOp(nn.Module):
return self.forward_native(*args, **kwargs)
def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)
# By default, we assume that CPU ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def forward_tpu(self, *args, **kwargs):
# By default, we assume that TPU ops are compatible with the

View File

@@ -250,6 +250,28 @@ class RotaryEmbedding(RotaryEmbeddingBase):
)
return query, key
def forward_cpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Apply rotary position embedding on CPU via vLLM's custom op.

    Overrides the generic CPU fallback so the CPU backend uses the
    dedicated `_custom_ops.rotary_embedding` kernel instead of the
    PyTorch-native path.

    Args:
        positions: Position indices used to look up `self.cos_sin_cache`.
            (Assumed to index rows of the cache — confirm expected shape
            against the caller.)
        query: Query tensor; rotated IN PLACE by the custom op.
        key: Optional key tensor; also rotated in place when provided.

    Returns:
        The same `(query, key)` tensors that were passed in (mutated
        in place), matching the interface of the other `forward_*`
        variants.
    """
    # Imported lazily here rather than at module level — presumably to
    # avoid an import cycle or to defer loading the compiled extension.
    from vllm import _custom_ops as ops

    # Align the cached cos/sin table's dtype with the query's dtype
    # before invoking the kernel.
    self._match_cos_sin_cache_dtype(query)
    # ops.rotary_embedding() is an in-place operation
    # that updates the query and key tensors.
    ops.rotary_embedding(
        positions,
        query,
        key,
        self.head_size,
        self.cos_sin_cache,
        self.is_neox_style,
    )
    return query, key
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"