fix: gran_k=16 in transform_sf + sm_100a arch for NVFP4 mega_moe
- transform_sf_into_required_layout: add a gran_k=16 branch for NVFP4 UE4M3 scales (4 per int32, group_size=16). Previously only gran_k=32 and gran_k=128 were handled.
- get_arch: always return '100a' for SM100, never '100f'. The family variant lacks mxf4nvf4 (NVFP4 block-scaled MMA) support, causing 'scale_vec::4X not supported on sm_100f' errors.
- transform_nvfp4_weights_for_mega_moe: fold weight_scale_2 into the block scales, pack UE4M3 → int32, transpose to MN-major, and call transform_sf_into_required_layout with gran_k=16.
This commit is contained in:
@@ -53,7 +53,10 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
|
||||
}
|
||||
|
||||
// (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10)
|
||||
// gran_k=16: NVFP4 UE4M3 (4 per int32, group_size=16)
|
||||
// gran_k=32: MXFP4 UE8M0 (4 per int32, group_size=32)
|
||||
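// gran_k=128: per-128 block scale (NOTE(review): this branch only accepts torch::kInt tensors, so "FP32 block scale" looks inaccurate here — confirm the element format)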
|
||||
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 16 or gran_k == 32 or gran_k == 128) and arch_major == 10)
|
||||
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);
|
||||
|
||||
DG_HOST_UNREACHABLE("Unknown SF transformation");
|
||||
|
||||
@@ -88,10 +88,15 @@ public:
|
||||
std::string get_arch(const bool& number_only = false,
|
||||
const bool& support_arch_family = false) {
|
||||
const auto [major, minor] = get_arch_pair();
|
||||
if (major == 10 and minor != 1) {
|
||||
if (major == 10) {
|
||||
if (number_only)
|
||||
return "100";
|
||||
return support_arch_family ? "100f" : "100a";
|
||||
// Always target 100a for SM100 — the 'f' family variant
|
||||
// lacks mxf4nvf4 (NVFP4 block-scaled MMA) support, which
|
||||
// causes 'scale_vec::4X not supported on sm_100f' errors.
|
||||
// Since DeepGEMM is JIT-compiled for the exact GPU, there's
|
||||
// no benefit to targeting the restricted family subset.
|
||||
return "100a";
|
||||
}
|
||||
return std::to_string(major * 10 + minor) + (number_only ? "" : "a");
|
||||
}
|
||||
|
||||
@@ -138,7 +138,9 @@ def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
l2_weights: Tuple[torch.Tensor, torch.Tensor]
|
||||
l2_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
l1_weight_scale_2: Optional[torch.Tensor] = None,
|
||||
l2_weight_scale_2: Optional[torch.Tensor] = None
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Transform NVFP4 expert weights for the mega_moe kernel.
|
||||
|
||||
@@ -146,14 +148,67 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
- weight: uint8 E2M1 packed, shape (num_experts, N, K//2)
|
||||
- scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16)
|
||||
|
||||
The kernel expects (weight, packed_sf) where packed_sf is int32 UTCCP layout.
|
||||
The kernel expects (weight, packed_sf) where packed_sf is int32 TMA-aligned
|
||||
UTCCP layout with gran_k=16.
|
||||
|
||||
If weight_scale_2 (float32 global scale) is provided, it is folded into the
|
||||
block scales: effective_scale = block_scale * global_scale → re-quantized to UE4M3.
|
||||
"""
|
||||
# L1: interleave gate/up, then pack + transpose SF for UTCCP
|
||||
l1_interleaved = _interleave_l1_weights(l1_weights)
|
||||
l1_weights = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))
|
||||
# L2: only pack + transpose SF for UTCCP
|
||||
l2_weights = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_weights[1]))
|
||||
return l1_weights, l2_weights
|
||||
from deep_gemm import transform_sf_into_required_layout
|
||||
|
||||
def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
"""Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3"""
|
||||
if scale_2 is None:
|
||||
return sf
|
||||
sf_f32 = sf.to(torch.float32)
|
||||
if scale_2.dim() == 1:
|
||||
scale_2 = scale_2.view(-1, 1, 1)
|
||||
sf_f32 = sf_f32 * scale_2
|
||||
sf_f32 = sf_f32.clamp(0.0, 448.0)
|
||||
return sf_f32.to(torch.float8_e4m3fn)
|
||||
|
||||
# Fold global scales into block scales
|
||||
l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2)
|
||||
l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2)
|
||||
|
||||
num_experts = l1_weights[0].shape[0]
|
||||
l1_n = l1_weights[0].shape[1]
|
||||
l1_k = l1_weights[0].shape[2] * 2 # K (weight is K//2 uint8)
|
||||
l2_n = l2_weights[0].shape[1]
|
||||
l2_k = l2_weights[0].shape[2] * 2
|
||||
|
||||
# Pack UE4M3 (float8_e4m3fn) into int32 for DeepGEMM TMA consumption
|
||||
# 4 UE4M3 bytes → 1 int32
|
||||
def pack_ue4m3_to_int32(sf):
|
||||
sf_u8 = sf.view(torch.uint8)
|
||||
assert sf_u8.shape[-1] % 4 == 0
|
||||
packed = (sf_u8[..., 0::4].to(torch.int32) |
|
||||
(sf_u8[..., 1::4].to(torch.int32) << 8) |
|
||||
(sf_u8[..., 2::4].to(torch.int32) << 16) |
|
||||
(sf_u8[..., 3::4].to(torch.int32) << 24))
|
||||
return packed.contiguous()
|
||||
|
||||
l1_sf_packed = pack_ue4m3_to_int32(l1_sf)
|
||||
l2_sf_packed = pack_ue4m3_to_int32(l2_sf)
|
||||
|
||||
# Re-lay out as MN-major (stride(-2) == 1) while keeping the logical shape: transpose, materialize with .contiguous(), then transpose back
|
||||
# transform_sf_into_required_layout expects MN-major input for TMA stride checks
|
||||
l1_sf_mn = l1_sf_packed.transpose(-2, -1).contiguous().transpose(-2, -1)
|
||||
l2_sf_mn = l2_sf_packed.transpose(-2, -1).contiguous().transpose(-2, -1)
|
||||
|
||||
# Transform SF into TMA-aligned UTCCP layout using DeepGEMM's C++ function
|
||||
# recipe (1, 16): gran_mn=1, gran_k=16 (NVFP4 native block16)
|
||||
l1_sf_transformed = transform_sf_into_required_layout(
|
||||
l1_sf_mn, l1_n, l1_k, (1, 16), num_experts)
|
||||
l2_sf_transformed = transform_sf_into_required_layout(
|
||||
l2_sf_mn, l2_n, l2_k, (1, 16), num_experts)
|
||||
|
||||
# L1: interleave gate/up
|
||||
l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_transformed))
|
||||
# DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
|
||||
l1_out = (l1_interleaved[0].view(torch.int8), l1_interleaved[1])
|
||||
l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed)
|
||||
return l1_out, l2_out
|
||||
|
||||
|
||||
def fp8_fp4_mega_moe(y: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user