[Bugfix][CPU] Fix RotaryEmbedding fallback causing gibberish with --enforce-eager (#31643)
Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>
This commit is contained in:
@@ -67,8 +67,9 @@ class CustomOp(nn.Module):
|
||||
return self.forward_native(*args, **kwargs)
|
||||
|
||||
def forward_cpu(self, *args, **kwargs):
|
||||
# By default, we assume that CPU ops are compatible with CUDA ops.
|
||||
return self.forward_cuda(*args, **kwargs)
|
||||
# By default, we assume that CPU ops are compatible with the
|
||||
# PyTorch-native implementation.
|
||||
return self.forward_native(*args, **kwargs)
|
||||
|
||||
def forward_tpu(self, *args, **kwargs):
|
||||
# By default, we assume that TPU ops are compatible with the
|
||||
|
||||
@@ -250,6 +250,28 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
)
|
||||
return query, key
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
|
||||
# ops.rotary_embedding() is an in-place operation
|
||||
# that updates the query and key tensors.
|
||||
ops.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
||||
s += f", max_position_embeddings={self.max_position_embeddings}"
|
||||
|
||||
Reference in New Issue
Block a user