[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -86,9 +86,9 @@ def dequant_no_scale(
|
||||
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
|
||||
# the generic pytorch version.
|
||||
# Just visual comparison.
|
||||
def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:
|
||||
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
|
||||
|
||||
n = parts.sum().item()
|
||||
n = int(parts.sum().item())
|
||||
|
||||
device = torch.device('cuda:0')
|
||||
|
||||
@@ -204,7 +204,7 @@ def main():
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
|
||||
def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
|
||||
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
|
||||
methods):
|
||||
|
||||
# I didn't see visible improvements from increasing these, but feel free :)
|
||||
@@ -252,10 +252,10 @@ def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
|
||||
print('')
|
||||
|
||||
|
||||
def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
|
||||
def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
|
||||
nbooks: int, bits: int, method) -> float:
|
||||
|
||||
n = parts.sum().item()
|
||||
n = int(parts.sum().item())
|
||||
|
||||
device = torch.device('cuda:0')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user