[Misc] Update type annotation for rotary embedding base (#18914)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -141,7 +141,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position: int,
|
||||
base: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
cache_dtype: torch.dtype,
|
||||
) -> None:
|
||||
@@ -155,10 +155,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
|
||||
cache = self._compute_cos_sin_cache().to(cache_dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(
|
||||
self,
|
||||
base: Union[int, float],
|
||||
) -> torch.Tensor:
|
||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
inv_freq = 1.0 / (base**(torch.arange(
|
||||
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
||||
|
||||
Reference in New Issue
Block a user