Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
Cyrus Leung
2025-11-15 14:47:41 +08:00
committed by GitHub
parent 6965ef436f
commit 98b4d389ed
15 changed files with 122 additions and 91 deletions

View File

@@ -3,6 +3,7 @@
from unittest import mock
import numpy as np
import pytest
import torch
@@ -112,7 +113,9 @@ def test_prepare_next_token_ids():
sampled_token_ids_tensor = torch.tensor(
sampled_token_ids, dtype=torch.int32, device=device
)
sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
sampled_token_ids_cpu = [
np.array([i for i in seq if i != -1]) for seq in sampled_token_ids
]
expected_next_token_ids_cpu = [1, 4, 30, 40]
expected_next_token_ids_tensor = torch.tensor(