diff --git a/csrc/apis/layout.hpp b/csrc/apis/layout.hpp index b404241..f369e9b 100644 --- a/csrc/apis/layout.hpp +++ b/csrc/apis/layout.hpp @@ -53,7 +53,8 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, } // (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major - if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10) + // Supports gran_k=16 (NVFP4), 32 (MXFP4), 128 (FP8) + if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 16 or gran_k == 32 or gran_k == 128) and arch_major == 10) return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); DG_HOST_UNREACHABLE("Unknown SF transformation");