tweax n shit
This commit is contained in:
@@ -1095,7 +1095,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
cute::abs(swiglu_values[i * 2 + 0].y)),
|
||||
cute::max(cute::abs(swiglu_values[i * 2 + 1].x),
|
||||
cute::abs(swiglu_values[i * 2 + 1].y)));
|
||||
amax_values[i] = math::warp_reduce<4, true>(lane_amax, math::ReduceMax<float>());
|
||||
amax_values[i] = math::warp_reduce<4, false>(lane_amax, math::ReduceMax<float>());
|
||||
}
|
||||
|
||||
// Wait shared memory release from previous TMA store
|
||||
@@ -1158,7 +1158,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// SF store — NVFP4 group_size=16: all 4 warps warps write, one K position each
|
||||
// k_idx = n_block_idx * 4 + warp_idx_in_wg → 4 K positions per atom
|
||||
if (lane_idx < 4) {
|
||||
// One lane per row: lane_idx%4==0 selects lane 0,4,8,...,28 → rows 0–7
|
||||
if ((lane_idx & 3) == 0) {
|
||||
const uint32_t k_idx = n_block_idx * 4 + warp_idx_in_wg;
|
||||
const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4;
|
||||
const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t);
|
||||
@@ -1166,7 +1167,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
|
||||
__builtin_assume(token_base_idx < BLOCK_M);
|
||||
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
|
||||
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
|
||||
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx / 4) * 4;
|
||||
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * uint32_t(sizeof(uint32_t)) + byte_idx;
|
||||
auto to_ue4m3 = [](float v) -> uint8_t {
|
||||
v = fmaxf(0.0f, fminf(v, 448.0f));
|
||||
@@ -1174,7 +1175,6 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
return reinterpret_cast<uint8_t&>(e) & 0x7F;
|
||||
};
|
||||
sf_base_ptr[sf_addr] = to_ue4m3(sf_val);
|
||||
sf_base_ptr[sf_addr + 4 * uint32_t(sizeof(uint32_t))] = to_ue4m3(sf_val);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user