Fix NaN from stale FP4 scale padding in create_fp4_scale_tensor (#38148)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -56,11 +56,11 @@ def create_fp4_scale_tensor(
         rounded_m = round_up(m, 128)
         scale_n = n // block_size
         rounded_n = round_up(scale_n, 4)
-        return torch.empty(
+        return torch.zeros(
             (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
         )
     else:
-        return torch.empty((m, n // block_size), device=device, dtype=torch.uint8)
+        return torch.zeros((m, n // block_size), device=device, dtype=torch.uint8)
 def create_fp4_output_tensors(
Reference in New Issue
Block a user