[Bugfix] Fix MRoPE dispatch on XPU (#24724)
Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
@@ -300,6 +300,15 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user