fix: add gran_k=16 (NVFP4) support to transform_sf_into_required_layout

The C++ function only handled gran_k=32 and 128 (MXFP4/FP8).
Added gran_k=16 for NVFP4 group_size=16 support.
This commit is contained in:
2026-05-11 07:13:00 +00:00
parent 388fd8dcfd
commit f98c1f7fd5

View File

@@ -53,7 +53,8 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
}
// (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10)
// Supports gran_k=16 (NVFP4), 32 (MXFP4), 128 (FP8)
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 16 or gran_k == 32 or gran_k == 128) and arch_major == 10)
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);
DG_HOST_UNREACHABLE("Unknown SF transformation");