[Misc] Update type annotation for rotary embedding base (#18914)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-05-30 10:17:01 +08:00
committed by GitHub
parent d54af615d5
commit 1aa2f81b43
4 changed files with 23 additions and 26 deletions

View File

@@ -141,7 +141,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
head_size: int,
rotary_dim: int,
max_position: int,
base: int,
base: float,
is_neox_style: bool,
cache_dtype: torch.dtype,
) -> None:
@@ -155,10 +155,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
cache = self._compute_cos_sin_cache().to(cache_dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(
self,
base: Union[int, float],
) -> torch.Tensor:
def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))