diag: force a_format/b_format=5 (MXF8F6F4Format::E2M1), re-enable MMA, dump k=0+k=1

This commit is contained in:
2026-05-12 19:06:28 +00:00
parent 3b8aa5fd4d
commit c1cbe488f3
2 changed files with 31 additions and 3 deletions

View File

@@ -228,7 +228,17 @@ static MegaMoEConfig get_mega_moe_config(
num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads
};
// Print configs for the first time
// Always print block_m for NVFP4 debugging
{
const auto key = fmt::format(
"MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})",
num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk);
static std::unordered_set<std::string> printed;
if (printed.count(key) == 0) {
std::cout << "[DIAG-HOST] " << key << ": " << config << std::endl;
printed.insert(key);
}
}
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
const auto key = fmt::format(
"MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})",

View File

@@ -848,6 +848,24 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Dynamic update of UMMA N based on effective M
mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m<true>());
// DIAGNOSTIC: Force-override instr_desc bitfields
// Test 1: Force a_format/b_format to 5 (MXF8F6F4Format::E2M1 encoding)
// MXF4Format::E2M1=1 but MXF8F6F4Format::E2M1=5 — hardware may expect 5
// Test 2: Force scale_format to 1 (E8M0) to see if bit 23 matters
// Test 3: a_sf_id/b_sf_id already set by make_runtime_instr_desc_with_sf_id
{
uint32_t raw = static_cast<uint32_t>(instr_desc);
// Clear a_format [7,10) and b_format [10,13), then OR in 5 for both
raw = (raw & ~((0x7u << 7) | (0x7u << 10))) | (5u << 7) | (5u << 10);
// Force scale_format bit [23] to 1 (E8M0)
// raw |= (1u << 23); // uncomment to test scale_fmt=1
instr_desc = *reinterpret_cast<cute::UMMA::InstrDescriptorBlockScaled*>(&raw);
if (lane_idx == 0) {
printf("[DIAG-FORCE] after override: raw=0x%08x a_fmt=%u b_fmt=%u scale_fmt=%u\n",
raw, (raw >> 7) & 7, (raw >> 10) & 7, (raw >> 23) & 1);
}
}
// Wait tensor memory empty barrier arrival
const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages;
const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1;
@@ -916,7 +934,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(b_desc_base_lo, 0, k * (UMMA_K / 2));
// DIAGNOSTIC: dump raw descriptor values on first MMA of first block
if (k_block_idx == 0 && k == 0 && lane_idx == 0) {
if (k_block_idx == 0 && lane_idx == 0) {
uint32_t instr_desc_raw = static_cast<uint32_t>(instr_desc);
uint32_t runtime_instr_desc_hi = static_cast<uint32_t>(runtime_instr_desc >> 32);
uint32_t runtime_instr_desc_lo = static_cast<uint32_t>(runtime_instr_desc & 0xFFFFFFFF);
@@ -943,7 +961,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// DIAGNOSTIC: stub the MMA to test if fault is MMA vs upstream
// #if 0 to skip the FMA, #if 1 to enable it
#if 0
#if 1
// NVFP4: use mxf4nvf4 instruction with UE4M3 scales
ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma(
b_desc, a_desc, accum_stage_idx * UMMA_N,