diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index 26be5fd..9a7e2c8 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -228,7 +228,17 @@ static MegaMoEConfig get_mega_moe_config( num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads }; - // Print configs for the first time + // Always print block_m for NVFP4 debugging + { + const auto key = fmt::format( + "MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})", + num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk); + static std::unordered_set printed; + if (printed.count(key) == 0) { + std::cout << "[DIAG-HOST] " << key << ": " << config << std::endl; + printed.insert(key); + } + } if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { const auto key = fmt::format( "MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})", diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 0c2e1a6..dccf16f 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -848,6 +848,24 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Dynamic update of UMMA N based on effective M mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m()); + // DIAGNOSTIC: Force-override instr_desc bitfields + // Test 1: Force a_format/b_format to 5 (MXF8F6F4Format::E2M1 encoding) + // MXF4Format::E2M1=1 but MXF8F6F4Format::E2M1=5 — hardware may expect 5 + // Test 2: Force scale_format to 1 (E8M0) to see if bit 23 matters + // Test 3: a_sf_id/b_sf_id already set by make_runtime_instr_desc_with_sf_id + { + uint32_t raw = static_cast(instr_desc); + // Clear a_format [7,10) and b_format [10,13), then OR in 5 for both + raw = (raw & ~((0x7u << 7) | (0x7u << 10))) | (5u << 7) | (5u << 10); + // Force scale_format bit [23] to 1 (E8M0) + // raw |= (1u << 23); // uncomment to test scale_fmt=1 + instr_desc = *reinterpret_cast(&raw); + if (lane_idx == 0) { + printf("[DIAG-FORCE] after override: raw=0x%08x a_fmt=%u b_fmt=%u scale_fmt=%u\n", + raw, (raw >> 7) & 7, (raw >> 10) & 7, (raw >> 23) & 1); + } + } + // Wait tensor memory empty barrier arrival const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; @@ -916,7 +934,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(b_desc_base_lo, 0, k * (UMMA_K / 2)); // DIAGNOSTIC: dump raw descriptor values on first MMA of first block - if (k_block_idx == 0 && k == 0 && lane_idx == 0) { + if (k_block_idx == 0 && lane_idx == 0) { uint32_t instr_desc_raw = static_cast(instr_desc); uint32_t runtime_instr_desc_hi = static_cast(runtime_instr_desc >> 32); uint32_t runtime_instr_desc_lo = static_cast(runtime_instr_desc & 0xFFFFFFFF); @@ -943,7 +961,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // DIAGNOSTIC: stub the MMA to test if fault is MMA vs upstream // #if 0 to skip the FMA, #if 1 to enable it -#if 0 +#if 1 // NVFP4: use mxf4nvf4 instruction with UE4M3 scales ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma( b_desc, a_desc, accum_stage_idx * UMMA_N,