[Misc] Update type annotation for rotary embedding base (#18914)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-05-30 10:17:01 +08:00
committed by GitHub
parent d54af615d5
commit 1aa2f81b43
4 changed files with 23 additions and 26 deletions

View File

@@ -70,7 +70,7 @@ def test_rotary_embedding(
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
base: float = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size
@@ -135,7 +135,7 @@ def test_batched_rotary_embedding(
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
base: float = 10000,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
@@ -203,7 +203,7 @@ def test_batched_rotary_embedding_multi_lora(
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
base: float = 10000,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)