Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -113,9 +112,7 @@ def test_prepare_next_token_ids():
|
||||
sampled_token_ids_tensor = torch.tensor(
|
||||
sampled_token_ids, dtype=torch.int32, device=device
|
||||
)
|
||||
sampled_token_ids_cpu = [
|
||||
np.array([i for i in seq if i != -1]) for seq in sampled_token_ids
|
||||
]
|
||||
sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
|
||||
|
||||
expected_next_token_ids_cpu = [1, 4, 30, 40]
|
||||
expected_next_token_ids_tensor = torch.tensor(
|
||||
|
||||
Reference in New Issue
Block a user