diag: force a_format/b_format=5 (MXF8F6F4Format::E2M1), re-enable MMA, dump k=0+k=1
This commit is contained in:
@@ -228,7 +228,17 @@ static MegaMoEConfig get_mega_moe_config(
|
||||
num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads
|
||||
};
|
||||
|
||||
// Print configs for the first time
|
||||
// DIAGNOSTIC (NVFP4 debugging): print the config (for block_m) once per unique key, even when DG_JIT_DEBUG / DG_PRINT_CONFIGS are unset
|
||||
{
|
||||
const auto key = fmt::format(
|
||||
"MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})",
|
||||
num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk);
|
||||
static std::unordered_set<std::string> printed;
|
||||
if (printed.count(key) == 0) {
|
||||
std::cout << "[DIAG-HOST] " << key << ": " << config << std::endl;
|
||||
printed.insert(key);
|
||||
}
|
||||
}
|
||||
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
|
||||
const auto key = fmt::format(
|
||||
"MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})",
|
||||
|
||||
@@ -848,6 +848,24 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
// Dynamic update of UMMA N based on effective M
|
||||
mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m<true>());
|
||||
|
||||
// DIAGNOSTIC: Force-override instr_desc bitfields
|
||||
// Test 1: Force a_format/b_format to 5 (MXF8F6F4Format::E2M1 encoding)
|
||||
// MXF4Format::E2M1=1 but MXF8F6F4Format::E2M1=5 — hardware may expect 5
|
||||
// Test 2: Force scale_format to 1 (E8M0) to see if bit 23 matters (currently disabled: the `raw |= (1u << 23)` line below is commented out)
|
||||
// Test 3: a_sf_id/b_sf_id already set by make_runtime_instr_desc_with_sf_id
|
||||
{
|
||||
uint32_t raw = static_cast<uint32_t>(instr_desc);
|
||||
// Clear a_format [7,10) and b_format [10,13), then OR in 5 for both
|
||||
raw = (raw & ~((0x7u << 7) | (0x7u << 10))) | (5u << 7) | (5u << 10);
|
||||
// Optional: force scale_format bit [23] to 1 (E8M0); left disabled unless the next line is uncommented
|
||||
// raw |= (1u << 23); // uncomment to test scale_fmt=1
|
||||
instr_desc = *reinterpret_cast<cute::UMMA::InstrDescriptorBlockScaled*>(&raw);
|
||||
if (lane_idx == 0) {
|
||||
printf("[DIAG-FORCE] after override: raw=0x%08x a_fmt=%u b_fmt=%u scale_fmt=%u\n",
|
||||
raw, (raw >> 7) & 7, (raw >> 10) & 7, (raw >> 23) & 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait tensor memory empty barrier arrival
|
||||
const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages;
|
||||
const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1;
|
||||
@@ -916,7 +934,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(b_desc_base_lo, 0, k * (UMMA_K / 2));
|
||||
|
||||
// DIAGNOSTIC: dump raw descriptor values for every k step of the first k block (lane 0 only)
|
||||
if (k_block_idx == 0 && k == 0 && lane_idx == 0) {
|
||||
if (k_block_idx == 0 && lane_idx == 0) {
|
||||
uint32_t instr_desc_raw = static_cast<uint32_t>(instr_desc);
|
||||
uint32_t runtime_instr_desc_hi = static_cast<uint32_t>(runtime_instr_desc >> 32);
|
||||
uint32_t runtime_instr_desc_lo = static_cast<uint32_t>(runtime_instr_desc & 0xFFFFFFFF);
|
||||
@@ -943,7 +961,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
|
||||
// DIAGNOSTIC: stub the MMA to test if fault is MMA vs upstream
|
||||
// #if 0 to skip the FMA, #if 1 to enable it
|
||||
#if 0
|
||||
#if 1
|
||||
// NVFP4: use mxf4nvf4 instruction with UE4M3 scales
|
||||
ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma(
|
||||
b_desc, a_desc, accum_stage_idx * UMMA_N,
|
||||
|
||||
Reference in New Issue
Block a user