Fix nvfp4 swizzling (#23140)

Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Yi Liu
2025-08-22 00:54:50 +08:00
committed by GitHub
parent a482e4e769
commit 0278f1ac3a
2 changed files with 5 additions and 26 deletions

View File

@@ -552,8 +552,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
if scale_ndim == 2:
return swizzled.reshape(M, K)
return swizzled.reshape(B, M, K)
return swizzled.reshape(M_padded, K_padded)
return swizzled.reshape(B, M_padded, K_padded)
def cutlass_fp4_supported() -> bool: