From 3b8aa5fd4d3c1fcc0cb6f7591522e3606ee6ab4d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 18:37:59 +0000 Subject: [PATCH] diag: stub MMA + dump descriptors for ILLEGAL_INSTRUCTION debug --- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 49aaf35..0c2e1a6 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -914,11 +914,42 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(a_desc_base_lo, 0, k * (UMMA_K / 2)); b_desc.lo = mma::sm100::advance_umma_desc_lo< cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(b_desc_base_lo, 0, k * (UMMA_K / 2)); + + // DIAGNOSTIC: dump raw descriptor values on first MMA of first block + if (k_block_idx == 0 && k == 0 && lane_idx == 0) { + uint32_t instr_desc_raw = static_cast(instr_desc); + uint32_t runtime_instr_desc_hi = static_cast(runtime_instr_desc >> 32); + uint32_t runtime_instr_desc_lo = static_cast(runtime_instr_desc & 0xFFFFFFFF); + // Decode key bitfields from instr_desc_raw + uint32_t a_sf_id = (instr_desc_raw >> 29) & 0x3; + uint32_t b_sf_id = (instr_desc_raw >> 4) & 0x3; + uint32_t scale_fmt = (instr_desc_raw >> 23) & 0x1; + uint32_t a_format = (instr_desc_raw >> 7) & 0x7; + uint32_t b_format = (instr_desc_raw >> 10) & 0x7; + uint32_t k_size = (instr_desc_raw >> 31) & 0x1; + uint32_t m_dim = (instr_desc_raw >> 24) & 0x1F; + uint32_t n_dim = (instr_desc_raw >> 17) & 0x3F; + uint32_t a_major = (instr_desc_raw >> 15) & 0x1; + uint32_t b_major = (instr_desc_raw >> 16) & 0x1; + printf("[DIAG] a_desc: lo=0x%08x hi=0x%08x\n", a_desc.lo, a_desc.hi); + printf("[DIAG] b_desc: lo=0x%08x hi=0x%08x\n", b_desc.lo, b_desc.hi); + printf("[DIAG] instr_desc_raw: 0x%08x runtime_hi: 0x%08x runtime_lo: 0x%08x\n", + instr_desc_raw, runtime_instr_desc_hi, runtime_instr_desc_lo); + printf("[DIAG] a_format=%u b_format=%u scale_fmt=%u a_sf_id=%u b_sf_id=%u k_size=%u m_dim=%u n_dim=%u a_major=%u b_major=%u\n", + a_format, b_format, scale_fmt, a_sf_id, b_sf_id, k_size, m_dim, n_dim, a_major, b_major); + printf("[DIAG] k=%u sf_ids(a=%u,b=%u) tmem_sfa=%u tmem_sfb=%u accum_stage=%u\n", + k, k, k, kTmemStartColOfSFA, kTmemStartColOfSFB, accum_stage_idx); + } + + // DIAGNOSTIC: stub the MMA to test if fault is MMA vs upstream + // #if 0 to skip the FMA, #if 1 to enable it +#if 0 // NVFP4: use mxf4nvf4 instruction with UE4M3 scales ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma( b_desc, a_desc, accum_stage_idx * UMMA_N, k_block_idx > 0 or k > 0, runtime_instr_desc, kTmemStartColOfSFB, kTmemStartColOfSFA); +#endif } } __syncwarp();