From ad335c38fbfb68717ca5a080fbe81f22e23a8bed Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 23:16:44 +0000 Subject: [PATCH] tweax n shit --- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 8 ++--- deep_gemm/mega/__init__.py | 29 ------------------- 2 files changed, 4 insertions(+), 33 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 88221b0..da183c3 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -1095,7 +1095,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, cute::abs(swiglu_values[i * 2 + 0].y)), cute::max(cute::abs(swiglu_values[i * 2 + 1].x), cute::abs(swiglu_values[i * 2 + 1].y))); - amax_values[i] = math::warp_reduce<4, true>(lane_amax, math::ReduceMax()); + amax_values[i] = math::warp_reduce<4, false>(lane_amax, math::ReduceMax()); } // Wait shared memory release from previous TMA store @@ -1158,7 +1158,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // SF store — NVFP4 group_size=16: all 4 warps warps write, one K position each // k_idx = n_block_idx * 4 + warp_idx_in_wg → 4 K positions per atom - if (lane_idx < 4) { + // One lane per row: lane_idx%4==0 selects lane 0,4,8,...,28 → rows 0–7 + if ((lane_idx & 3) == 0) { const uint32_t k_idx = n_block_idx * 4 + warp_idx_in_wg; const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t); @@ -1166,7 +1167,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; __builtin_assume(token_base_idx < BLOCK_M); const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M - + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; + + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx / 4) * 4; const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * uint32_t(sizeof(uint32_t)) + byte_idx; auto to_ue4m3 = [](float v) -> uint8_t { v = fmaxf(0.0f, fminf(v, 448.0f)); @@ -1174,7 +1175,6 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, return reinterpret_cast(e) & 0x7F; }; sf_base_ptr[sf_addr] = to_ue4m3(sf_val); - sf_base_ptr[sf_addr + 4 * uint32_t(sizeof(uint32_t))] = to_ue4m3(sf_val); } __syncwarp(); } diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 716daec..e6a1ef7 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -187,35 +187,6 @@ def transform_weights_for_mega_moe( return l1_weights, l2_weights -def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor: - """Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into int32 UTCCP layout. - - NVFP4 uses UE4M3 scales with group_size=16 (scale_vec::4X). - The UTCCP layout packs 4 consecutive scale bytes into each int32, - then applies the 4x32 transpose for TMA consumption. - - Input: (num_experts, mn, K//16) float8_e4m3fn scales - Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales - """ - num_groups, mn, sf_k = sf.shape - assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}" - assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}" - - # View as uint8 and pack 4 consecutive bytes into int32 - sf_uint8 = sf.view(torch.uint8) # (num_groups, mn, sf_k) - # Pack: every 4 uint8 → 1 int32 - packed = (sf_uint8[..., 0::4].to(torch.int32) | - (sf_uint8[..., 1::4].to(torch.int32) << 8) | - (sf_uint8[..., 2::4].to(torch.int32) << 16) | - (sf_uint8[..., 3::4].to(torch.int32) << 24)) # (num_groups, mn, sf_k//4) - - # Apply UTCCP 4x32 transpose (same as MXFP4 — the transpose is determined - # by the 128-element alignment, not the scale vector size) - packed_sf_k = sf_k // 4 - result = (packed.reshape(num_groups, -1, 4, 32, packed_sf_k) - .transpose(2, 3) - .reshape(num_groups, mn, packed_sf_k)) - return torch.empty_like(packed).copy_(result) def transform_nvfp4_weights_for_mega_moe(