[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -17,13 +17,13 @@ capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.tensor):
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.tensor):
|
||||
def to_int8(tensor: torch.Tensor):
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user