fix: add gran_k=16 (NVFP4) support to transform_sf_into_required_layout
Previously, the C++ function's SM100 INT scale-factor branch accepted only gran_k=32 (MXFP4) and gran_k=128 (FP8). This commit also accepts gran_k=16, so NVFP4 scale factors (group_size=16) are supported.
This commit is contained in:
@@ -53,7 +53,8 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
|
||||
}
|
||||
|
||||
// (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major
|
||||
- if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10)
|
||||
+ // Supports gran_k=16 (NVFP4), 32 (MXFP4), 128 (FP8)
|
||||
+ if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 16 or gran_k == 32 or gran_k == 128) and arch_major == 10)
|
||||
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);
|
||||
|
||||
DG_HOST_UNREACHABLE("Unknown SF transformation");
|
||||
|
||||
Reference in New Issue
Block a user