[Misc] Update type annotation for rotary embedding base (#18914)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -70,7 +70,7 @@ def test_rotary_embedding(
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
base: float = 10000,
|
||||
) -> None:
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
@@ -135,7 +135,7 @@ def test_batched_rotary_embedding(
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
base: float = 10000,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
@@ -203,7 +203,7 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
base: float = 10000,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
Reference in New Issue
Block a user