diag: stub MMA + dump descriptors for ILLEGAL_INSTRUCTION debug
This commit is contained in:
@@ -914,11 +914,42 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(a_desc_base_lo, 0, k * (UMMA_K / 2));
|
||||
b_desc.lo = mma::sm100::advance_umma_desc_lo<
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(b_desc_base_lo, 0, k * (UMMA_K / 2));
|
||||
|
||||
// DIAGNOSTIC: dump raw descriptor values on first MMA of first block
|
||||
if (k_block_idx == 0 && k == 0 && lane_idx == 0) {
|
||||
uint32_t instr_desc_raw = static_cast<uint32_t>(instr_desc);
|
||||
uint32_t runtime_instr_desc_hi = static_cast<uint32_t>(runtime_instr_desc >> 32);
|
||||
uint32_t runtime_instr_desc_lo = static_cast<uint32_t>(runtime_instr_desc & 0xFFFFFFFF);
|
||||
// Decode key bitfields from instr_desc_raw
|
||||
uint32_t a_sf_id = (instr_desc_raw >> 29) & 0x3;
|
||||
uint32_t b_sf_id = (instr_desc_raw >> 4) & 0x3;
|
||||
uint32_t scale_fmt = (instr_desc_raw >> 23) & 0x1;
|
||||
uint32_t a_format = (instr_desc_raw >> 7) & 0x7;
|
||||
uint32_t b_format = (instr_desc_raw >> 10) & 0x7;
|
||||
uint32_t k_size = (instr_desc_raw >> 31) & 0x1;
|
||||
uint32_t m_dim = (instr_desc_raw >> 24) & 0x1F;
|
||||
uint32_t n_dim = (instr_desc_raw >> 17) & 0x3F;
|
||||
uint32_t a_major = (instr_desc_raw >> 15) & 0x1;
|
||||
uint32_t b_major = (instr_desc_raw >> 16) & 0x1;
|
||||
printf("[DIAG] a_desc: lo=0x%08x hi=0x%08x\n", a_desc.lo, a_desc.hi);
|
||||
printf("[DIAG] b_desc: lo=0x%08x hi=0x%08x\n", b_desc.lo, b_desc.hi);
|
||||
printf("[DIAG] instr_desc_raw: 0x%08x runtime_hi: 0x%08x runtime_lo: 0x%08x\n",
|
||||
instr_desc_raw, runtime_instr_desc_hi, runtime_instr_desc_lo);
|
||||
printf("[DIAG] a_format=%u b_format=%u scale_fmt=%u a_sf_id=%u b_sf_id=%u k_size=%u m_dim=%u n_dim=%u a_major=%u b_major=%u\n",
|
||||
a_format, b_format, scale_fmt, a_sf_id, b_sf_id, k_size, m_dim, n_dim, a_major, b_major);
|
||||
printf("[DIAG] k=%u sf_ids(a=%u,b=%u) tmem_sfa=%u tmem_sfb=%u accum_stage=%u\n",
|
||||
k, k, k, kTmemStartColOfSFA, kTmemStartColOfSFB, accum_stage_idx);
|
||||
}
|
||||
|
||||
// DIAGNOSTIC: stub the MMA to test if fault is MMA vs upstream
|
||||
// #if 0 to skip the FMA, #if 1 to enable it
|
||||
#if 0
|
||||
// NVFP4: use mxf4nvf4 instruction with UE4M3 scales
|
||||
ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma(
|
||||
b_desc, a_desc, accum_stage_idx * UMMA_N,
|
||||
k_block_idx > 0 or k > 0, runtime_instr_desc,
|
||||
kTmemStartColOfSFB, kTmemStartColOfSFA);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
Reference in New Issue
Block a user