Fix nvfp4 swizzling (#23140)
Signed-off-by: yiliu30 <yi4.liu@intel.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -552,8 +552,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
|
||||
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
|
||||
|
||||
if scale_ndim == 2:
|
||||
return swizzled.reshape(M, K)
|
||||
return swizzled.reshape(B, M, K)
|
||||
return swizzled.reshape(M_padded, K_padded)
|
||||
return swizzled.reshape(B, M_padded, K_padded)
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
|
||||
Reference in New Issue
Block a user