[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int):
|
||||
device='cuda',
|
||||
)
|
||||
|
||||
expected_output = [
|
||||
expected_output: List[List[int]] = [
|
||||
[],
|
||||
]
|
||||
for i in range(proposal_token_ids.shape[0]):
|
||||
expected_output.append(proposal_token_ids[:i + 1].tolist())
|
||||
|
||||
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
|
||||
actual_output = scorer._get_token_ids_to_score(proposal_token_ids) # pylint: disable=protected-access
|
||||
actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
|
||||
|
||||
actual_output = [
|
||||
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
|
||||
|
||||
Reference in New Issue
Block a user