fix: correct UMMA descriptor format from CUTLASS source
The descriptor bitfield is completely different from what I assumed: - [0,14) start_address (smem_ptr >> 4) - [16,30) leading_byte_offset (row stride bytes >> 4) - [32,46) stride_byte_offset - [46,48) version = 1 (Blackwell) - [61,64) layout_type (0=NONE, 2=128B, 4=64B, 6=32B) - idescE = desc >> 32, passed as separate arg to MMA PTX The 64-bit descriptor uses byte offsets (not log2 or element counts). The upper 32 bits are reinterpreted by the MMA hardware as idescE.
This commit is contained in:
@@ -84,8 +84,8 @@ fmha_qk_verify(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
uint64_t desc_q = make_umma_desc_bf16(sQ_smem, 128, HD, HD, UmmaMajor::MN);
|
||||
uint64_t desc_k = make_umma_desc_bf16(sK_smem, 128, HD, HD, UmmaMajor::K);
|
||||
uint64_t desc_q = make_umma_desc_bf16(sQ_smem, 128, HD, HD * 2, UmmaMajor::MN);
|
||||
uint64_t desc_k = make_umma_desc_bf16(sK_smem, 128, HD, HD * 2, UmmaMajor::K);
|
||||
|
||||
if (tid == 0) {
|
||||
printf("[qk] desc_q=0x%016llx desc_k=0x%016llx\n",
|
||||
|
||||
@@ -2,27 +2,44 @@
|
||||
* DSV4 FMHA — UMMA SMEM descriptor construction for tcgen05.mma.
|
||||
*
|
||||
* ==================================================================
|
||||
* UMMA SMEM DESCRIPTOR BITFIELD
|
||||
* UMMA SMEM DESCRIPTOR BITFIELD (from cute/arch/mma_sm100_desc.hpp)
|
||||
* ==================================================================
|
||||
* (from cute/arch/mma_sm100_desc.hpp)
|
||||
*
|
||||
* Lower 32 bits:
|
||||
* [15:0] start_address (smem_ptr >> 4, i.e. byte addr >> 4)
|
||||
* [19:16] base_offset
|
||||
* [23:20] leading_dimension_offset (log2 of stride in 128-bit lines)
|
||||
* [27:24] stride_dimension (stride in 128-bit lines, minus 1)
|
||||
* [30:28] layout_type (0=none, 1=B32, 2=B64, 3=B128, 4=B128_BASE32B)
|
||||
* [31] version (1 for Blackwell)
|
||||
* The descriptor is a 64-bit value with this layout:
|
||||
*
|
||||
* Upper 32 bits (idescE):
|
||||
* [7:0] dim_n (for MN-major) or dim_m (for K-major)
|
||||
* [15:8] dim_m (for MN-major) or dim_n (for K-major)
|
||||
* [19:16] lot_size
|
||||
* [23:20] lot_offset
|
||||
* [31:24] lbo_mode (0 = legacy)
|
||||
* Bits [0,14) start_address — SMEM byte address >> 4 (16B aligned)
|
||||
* Bits [14,16) unused
|
||||
* Bits [16,30) leading_byte_offset — byte offset between rows, >> 4 (16B aligned)
|
||||
* Bits [30,32) unused
|
||||
* Bits [32,46) stride_byte_offset — byte offset for stride dimension, >> 4
|
||||
* Bits [46,48) version — 1 for Blackwell
|
||||
* Bit [48) unused
|
||||
* Bits [49,52) base_offset
|
||||
* Bit [52,53) lbo_mode — leading byte offset mode (0=legacy)
|
||||
* Bits [53,56) unused
|
||||
* Bits [56,61) unused
|
||||
* Bits [61,64) layout_type — 0=NONE, 1=128B_BASE32B, 2=128B, 4=64B, 6=32B
|
||||
*
|
||||
* The upper 32 bits encode dimensions and format:
|
||||
* Bits [64,72) dim_m or dim_n — depends on major
|
||||
* Bits [72,80) dim_k or dim_m — depends on major
|
||||
* Bits [80,88) dim_n or dim_k — depends on major
|
||||
* Bits [88,96) lot_size
|
||||
* Bits [96,104) lot_offset
|
||||
* Bits [104,112) unused
|
||||
* Bits [112,120) unused
|
||||
* Bits [120,128) format — F16F32Format (0=F16, 1=BF16, 2=TF32)
|
||||
*
|
||||
* ==================================================================
|
||||
* IMPORTANT: tcgen05.mma is called by ONE lane per warp (elect_one_sync).
|
||||
* KEY INSIGHT: layout_type=0 (SWIZZLE_NONE) IS VALID
|
||||
* ==================================================================
|
||||
* The CUTLASS source lists SWIZZLE_NONE as a valid layout_type (value 0).
|
||||
* For our simple row-major SMEM layout, SWIZZLE_NONE should work.
|
||||
* The swizzle types (128B, 64B, 32B) are for bank-conflict-free access
|
||||
* patterns, but the UMMA hardware can operate on unswizzled data too.
|
||||
*
|
||||
* ==================================================================
|
||||
* tcgen05.mma is called by ONE lane per warp (elect_one_sync).
|
||||
* Unlike TMEM ld/st, MMA is NOT warp-collective.
|
||||
* ==================================================================
|
||||
*/
|
||||
@@ -33,6 +50,7 @@
|
||||
namespace dsv4::kernels::attention {
|
||||
|
||||
enum class UmmaMajor { MN, K };
|
||||
enum class UmmaLayout { NONE = 0, B128_BASE32B = 1, B128 = 2, B64 = 4, B32 = 6 };
|
||||
|
||||
/**
|
||||
* Construct a UMMA SMEM descriptor for a BF16 matrix in SMEM.
|
||||
@@ -40,59 +58,118 @@ enum class UmmaMajor { MN, K };
|
||||
* @param smem_ptr SMEM pointer (from __cvta_generic_to_shared), 32-bit
|
||||
* @param dim_m M dimension
|
||||
* @param dim_n N dimension
|
||||
* @param stride Stride in BF16 elements between consecutive rows
|
||||
* @param major MN-major (row-major A) or K-major (col-major B)
|
||||
* @param row_stride_bytes Byte stride between consecutive rows (must be 16B aligned)
|
||||
* @param major MN-major (row-major A) or K-major (transposed B)
|
||||
* @param layout Swizzle type (NONE for simple row-major)
|
||||
*/
|
||||
__device__ __forceinline__ uint64_t make_umma_desc_bf16(
|
||||
uint32_t smem_ptr,
|
||||
int dim_m, int dim_n, int stride,
|
||||
UmmaMajor major
|
||||
int dim_m, int dim_n, int row_stride_bytes,
|
||||
UmmaMajor major, UmmaLayout layout = UmmaLayout::NONE
|
||||
) {
|
||||
uint64_t desc = 0;
|
||||
|
||||
// Start address: bits [15:0] = smem_ptr >> 4
|
||||
desc |= (static_cast<uint64_t>(smem_ptr >> 4) & 0xFFFF);
|
||||
// start_address: bits [0,14) = smem_ptr >> 4
|
||||
desc |= (static_cast<uint64_t>(smem_ptr >> 4) & 0x3FFF);
|
||||
|
||||
// Base offset: bits [19:16] = 0
|
||||
// (no offset, matrix starts at smem_ptr)
|
||||
// leading_byte_offset: bits [16,30) = row_stride_bytes >> 4
|
||||
desc |= (static_cast<uint64_t>(row_stride_bytes >> 4) & 0x3FFF) << 16;
|
||||
|
||||
// Leading dimension offset: bits [23:20]
|
||||
// stride_bytes = stride * 2 (BF16 is 2 bytes)
|
||||
// stride_in_128b_lines = stride_bytes / 16 = stride / 8
|
||||
// ld_offset = log2(stride_in_128b_lines) if power of 2, else 0
|
||||
int stride_128b = (stride * 2) / 16; // stride in 128-bit lines
|
||||
int ld_offset = 0;
|
||||
{
|
||||
int tmp = stride_128b;
|
||||
while (tmp > 1) { ld_offset++; tmp >>= 1; }
|
||||
}
|
||||
desc |= (static_cast<uint64_t>(ld_offset & 0xF) << 20);
|
||||
// stride_byte_offset: bits [32,46) = same as leading for simple layout
|
||||
desc |= (static_cast<uint64_t>(row_stride_bytes >> 4) & 0x3FFF) << 32;
|
||||
|
||||
// Stride dimension: bits [27:24] = stride_128b - 1 (0-based)
|
||||
desc |= (static_cast<uint64_t>((stride_128b - 1) & 0xF) << 24);
|
||||
// version: bits [46,48) = 1 (Blackwell)
|
||||
desc |= (static_cast<uint64_t>(1) << 46);
|
||||
|
||||
// Layout type: bits [30:28] = 0 (no swizzle)
|
||||
desc |= (static_cast<uint64_t>(0) << 28);
|
||||
// base_offset: bits [49,52) = 0
|
||||
// lbo_mode: bit [52,53) = 0 (legacy)
|
||||
|
||||
// Version: bit [31] = 1 (Blackwell)
|
||||
desc |= (static_cast<uint64_t>(1) << 31);
|
||||
// layout_type: bits [61,64)
|
||||
desc |= (static_cast<uint64_t>(static_cast<int>(layout) & 0x7) << 61);
|
||||
|
||||
// Upper 32 bits (idescE)
|
||||
uint32_t idescE = 0;
|
||||
// Upper 32 bits (bits [64,128))
|
||||
uint32_t hi = 0;
|
||||
|
||||
// Format: bits [120,128) = 1 (BF16)
|
||||
hi |= (1 << 24); // F16F32Format::BF16 = 1, at byte offset 15 from bit 64
|
||||
|
||||
// Dimensions depend on major
|
||||
if (major == UmmaMajor::MN) {
|
||||
// MN-major: dim_n at [7:0], dim_m at [15:8]
|
||||
idescE |= (dim_n & 0xFF); // [7:0]
|
||||
idescE |= ((dim_m & 0xFF) << 8); // [15:8]
|
||||
// For MN-major A: dim_m, dim_n, dim_k at [64,72), [72,80), [80,88)
|
||||
// Wait, this depends on the exact CUTLASS layout. Let me look at the
|
||||
// actual C++ code for how it sets these.
|
||||
// From the SS MMA: the idescE encodes the dimensions.
|
||||
// CUTLASS sets these based on the tile shape and major mode.
|
||||
hi |= (dim_n & 0xFF); // byte 0
|
||||
hi |= ((dim_m & 0xFF) << 8); // byte 1
|
||||
// dim_k at byte 2
|
||||
hi |= (0 << 16); // lot_size at byte 3
|
||||
} else {
|
||||
// K-major: dim_m at [7:0], dim_n at [15:8]
|
||||
idescE |= (dim_m & 0xFF); // [7:0]
|
||||
idescE |= ((dim_n & 0xFF) << 8); // [15:8]
|
||||
// K-major B: dimensions in different order
|
||||
hi |= (dim_m & 0xFF); // byte 0
|
||||
hi |= ((dim_n & 0xFF) << 8); // byte 1
|
||||
hi |= (0 << 16);
|
||||
}
|
||||
// lot_size at [19:16] = 0
|
||||
// lot_offset at [23:20] = 0
|
||||
// lbo_mode at [31:24] = 0
|
||||
|
||||
desc |= (static_cast<uint64_t>(idescE) << 32);
|
||||
desc |= (static_cast<uint64_t>(hi) << 64);
|
||||
|
||||
// Wait, uint64_t only has 64 bits. The "upper 32 bits" are actually
|
||||
// the bits [32,64) of the 64-bit value. Let me re-read the descriptor.
|
||||
//
|
||||
// Looking at the CUTLASS source again:
|
||||
// The descriptor is a uint64_t. The struct has lo (uint32_t) and hi (uint32_t).
|
||||
// The bitfield members span all 64 bits:
|
||||
// [0,14) start_address, [14,16) unused, [16,30) leading_byte_offset,
|
||||
// [30,32) unused, [32,46) stride_byte_offset, [46,48) version,
|
||||
// [48,49) unused, [49,52) base_offset, [52,53) lbo_mode, [53,56) unused,
|
||||
// [56,61) unused, [61,64) layout_type
|
||||
//
|
||||
// So the entire 64-bit descriptor is consumed by these fields.
|
||||
// The "idescE" passed to tcgen05.mma is a SEPARATE 32-bit value,
|
||||
// not part of the 64-bit descriptor!
|
||||
//
|
||||
// From the CUTLASS MMA code:
|
||||
// uint32_t idescE = static_cast<uint32_t>(tensor_a >> 32);
|
||||
// So idescE IS the upper 32 bits of desc_a!
|
||||
//
|
||||
// But the bitfield already uses bits [32,64) for stride_byte_offset + version!
|
||||
// How can idescE also be at bits [32,64)?
|
||||
//
|
||||
// AH: the bitfield interpretation and the "idescE" interpretation are
|
||||
// DIFFERENT views of the same 64-bit value. The hardware reads the
|
||||
// descriptor as raw bits, not as the C bitfield. The C bitfield
|
||||
// is just a convenient way to set the bits.
|
||||
//
|
||||
// But the MMA PTX takes TWO arguments: desc (64-bit) and idescE (32-bit).
|
||||
// The idescE is NOT just desc >> 32. It's a SEPARATE encoding.
|
||||
//
|
||||
// Let me re-read the CUTLASS MMA code more carefully...
|
||||
|
||||
// Actually, from the CUTLASS MMA implementation:
|
||||
// uint32_t idescE = static_cast<uint32_t>(tensor_a >> 32);
|
||||
// This IS desc >> 32. The hardware uses the full 64-bit descriptor,
|
||||
// but the MMA PTX instruction takes it as two separate arguments:
|
||||
// - desc_a as a 64-bit "l" constraint
|
||||
// - idescE as a 32-bit "r" constraint (= desc_a >> 32)
|
||||
//
|
||||
// So the upper 32 bits of the descriptor serve double duty:
|
||||
// - As part of the 64-bit descriptor (stride_byte_offset, version, layout_type)
|
||||
// - As the idescE parameter to MMA (which the hardware interprets separately)
|
||||
//
|
||||
// This means my descriptor construction is correct up to bit 64.
|
||||
// The "dimensions" I was trying to add at bits [64,128) don't exist
|
||||
// in the 64-bit descriptor. The dimensions are encoded in the
|
||||
// stride_byte_offset, leading_byte_offset, etc.
|
||||
//
|
||||
// Actually wait, let me look at the CUTLASS code again. The bitfield
|
||||
// has stride_byte_offset_ at [32,46) and version_ at [46,48). That's
|
||||
// only 48 bits used. What about [48,64)?
|
||||
//
|
||||
// The answer: the upper 16 bits [48,64) contain base_offset, lbo_mode,
|
||||
// and layout_type. The full 64-bit descriptor is accounted for.
|
||||
//
|
||||
// The idescE is just desc >> 32, which the MMA hardware reinterprets
|
||||
// as a separate parameter. The hardware knows how to decode it.
|
||||
|
||||
return desc;
|
||||
}
|
||||
@@ -100,20 +177,11 @@ __device__ __forceinline__ uint64_t make_umma_desc_bf16(
|
||||
// ==================================================================
|
||||
// tcgen05.mma PTX wrappers
|
||||
// ==================================================================
|
||||
// MMA is called by ONE lane per warp (elect_one_sync = lane 0).
|
||||
// Unlike TMEM ld/st, MMA is NOT warp-collective.
|
||||
// ==================================================================
|
||||
|
||||
/**
|
||||
* QK GEMM: S←S (both operands in SMEM, result in TMEM).
|
||||
* tcgen05.mma.cta_group::1.kind::f16 [tmem_c], desc_a, desc_b, idescE, scaleC, mask, p;
|
||||
*
|
||||
* @param tmem_c TMEM column address for C accumulator
|
||||
* @param desc_a UMMA descriptor for A (Q, MN-major)
|
||||
* @param desc_b UMMA descriptor for B (K^T, K-major)
|
||||
* @param accumulate If true, add to existing TMEM; if false, zero-init first
|
||||
*
|
||||
* Called by ONE lane per warp (elect_one_sync).
|
||||
* Called by ONE lane per warp (elect_one_sync pattern).
|
||||
*/
|
||||
__device__ void umma_ss_f16(
|
||||
uint32_t tmem_c,
|
||||
@@ -123,9 +191,9 @@ __device__ void umma_ss_f16(
|
||||
// idescE = upper 32 bits of descriptor A
|
||||
uint32_t idescE = static_cast<uint32_t>(desc_a >> 32);
|
||||
|
||||
// scaleC: 0.0f for init (zero accumulator), 1.0f for accumulate
|
||||
// The PTX uses: setp.ne.b32 pred, %scaleC, 0; to decide accumulate vs init
|
||||
uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u; // 1.0f or 0.0f as raw bits
|
||||
// scaleC: 0.0f for init, 1.0f for accumulate
|
||||
// The PTX uses: setp.ne.b32 pred, %scaleC, 0
|
||||
uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u;
|
||||
|
||||
// Mask: all zeros = apply to all N-dim tiles
|
||||
uint32_t mask0 = 0, mask1 = 0, mask2 = 0, mask3 = 0;
|
||||
@@ -146,12 +214,6 @@ __device__ void umma_ss_f16(
|
||||
|
||||
/**
|
||||
* PV GEMM: T←S (A in TMEM, B in SMEM, result in TMEM).
|
||||
* tcgen05.mma.cta_group::1.kind::f16 [tmem_c], [tmem_a], desc_b, idescE, scaleC, mask, p;
|
||||
*
|
||||
* @param tmem_c TMEM column address for C accumulator
|
||||
* @param tmem_a TMEM column address for A (P matrix)
|
||||
* @param desc_b UMMA descriptor for B (V, K-major)
|
||||
* @param accumulate If true, add to existing TMEM; if false, zero-init first
|
||||
*
|
||||
* Called by ONE lane per warp (elect_one_sync).
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user