From d3510980e404f3d8026a207ed7273ca5f2115995 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 08:35:30 +0000 Subject: [PATCH] feat: SWIZZLE_NONE UMMA descriptors with row-major SMEM Canonical UMMA layout for SWIZZLE_NONE: - MN-major (128, 64): LBO=16, SBO=128 (from logical_divide Tile(1,8)) - K-major (128, 64): LBO=16, SBO=32 (from logical_divide Tile(8,2)) Using simple row-major SMEM layout (no swizzle, no permutation). Data is written directly to SMEM in row-major order. The descriptor strides describe the canonical layout. --- dsv4/kernels/attention/fmha_qk_verify.cuh | 94 ++---- dsv4/kernels/attention/fmha_umma_desc.cuh | 382 ++++------------------ 2 files changed, 91 insertions(+), 385 deletions(-) diff --git a/dsv4/kernels/attention/fmha_qk_verify.cuh b/dsv4/kernels/attention/fmha_qk_verify.cuh index d754969f..52217a09 100644 --- a/dsv4/kernels/attention/fmha_qk_verify.cuh +++ b/dsv4/kernels/attention/fmha_qk_verify.cuh @@ -1,12 +1,6 @@ /** - * DSV4 FMHA — QK GEMM verification with canonical UMMA SMEM layout. - * - * STEP 1: Verify tcgen05.mma SS produces correct QK output - * using the proper SWIZZLE_128B canonical SMEM layout. - * - * The SMEM data must be written in the swizzled layout that the UMMA - * hardware expects. We use umma_smem_write_mn_sw128 and - * umma_smem_write_k_sw128 to write data in the correct format. + * DSV4 FMHA — QK GEMM verification with SWIZZLE_NONE UMMA layout. + * Uses simple row-major SMEM (no swizzle) with the NONE descriptor. */ #pragma once @@ -20,7 +14,7 @@ template __global__ void __launch_bounds__(NTHREADS) fmha_qk_verify( const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, - float* __restrict__ s_out, // output: S[0, 0..127] for row 0 + float* __restrict__ s_out, int bstride_q, int bstride_kv, int s_k, float scale ) { @@ -30,102 +24,72 @@ fmha_qk_verify( const bf16_t* qh = q + batch*bstride_q + head*HD; const bf16_t* kb = k + batch*bstride_kv; - // SMEM: sQ (128×HD BF16 swizzled) + sK (128×HD BF16 swizzled) + tmem_base - // Size: 4 + 128*HD*2 + 128*HD*2 = 4 + 512*HD bytes - // For SW128 layout, the actual SMEM needed is the same as row-major - // because the swizzle is just a permutation of the same data. + // SMEM: sQ (128×HD BF16 row-major) + sK (128×HD BF16 row-major) + tmem_base + // Must be 16-byte aligned for UMMA extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf; - // Align to 128 bytes for UMMA descriptor - bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 127) & ~127); - bf16_t* sK = sQ + 1024 * 8; // MN_SW128 atom = 1024*8 BF16 = 8192 BF16 per atom + bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~15); + bf16_t* sK = sQ + 128 * HD; - // Load Q into swizzled SMEM: (1, HD) padded to (128, HD) - // Write row 0 with actual Q data, rows 1-127 are zero - for (int d = tid; d < HD; d += NTHREADS) { - umma_smem_write_mn_sw128(sQ, 0, d, HD, qh[d]); - } - // Zero rows 1-127 (only need to zero the elements we'll read) - for (int row = 1 + tid / HD; row < 128; row += NTHREADS / HD) { - for (int d = tid % HD; d < HD; d += HD) { - umma_smem_write_mn_sw128(sQ, row, d, HD, 0); - } - } + // Load Q: (1, HD) padded to (128, HD) with zeros + for (int i = tid; i < 128 * HD; i += NTHREADS) sQ[i] = 0; + for (int d = tid; d < HD; d += NTHREADS) sQ[d] = qh[d]; - // Load K into swizzled SMEM: (min(128, s_k), HD) padded to (128, HD) + // Load K: (min(128, s_k), HD) padded to (128, HD) int kv_len = min(128, s_k); - for (int r = 0; r < 128; r++) { - for (int d = tid; d < HD; d += NTHREADS) { - bf16_t val = (r < kv_len) ? kb[r * HD + d] : 0; - umma_smem_write_k_sw128(sK, r, d, HD, val); - } + for (int i = tid; i < 128 * HD; i += NTHREADS) { + int r = i / HD, c = i % HD; + sK[i] = (r < kv_len) ? kb[r * HD + c] : 0; } __syncthreads(); - // TMEM alloc for S: 128 columns (128 rows × 128 cols) + // TMEM alloc for S: 128 columns if (wid == 0) { uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase); tmem_alloc(smem_ptr, 128); } __syncthreads(); uint32_t tmem_base = *sTmemBase; - uint32_t tmem_s = tmem_base; // Zero TMEM S if (wid == 0) { for (int col = 0; col < 128; col++) { - tmem_store(tmem_s + col, 0, 0, 0, 0); + tmem_store(tmem_base + col, 0, 0, 0, 0); } tmem_fence_store(); } __syncthreads(); - // ================================================================ - // QK GEMM: S = Q @ K^T via tcgen05.mma SS - // ================================================================ + // UMMA descriptors uint32_t sQ_smem = __cvta_generic_to_shared(sQ); uint32_t sK_smem = __cvta_generic_to_shared(sK); - uint64_t desc_q = make_umma_desc(sQ_smem, 128, HD, UmmaMajor::MN); - uint64_t desc_k = make_umma_desc(sK_smem, 128, HD, UmmaMajor::K); + uint64_t desc_q = make_umma_desc_mn_none(sQ_smem, HD); + uint64_t desc_k = make_umma_desc_k_none(sK_smem, HD); - if (tid == 0) { - printf("[qk] desc_q=0x%016llx desc_k=0x%016llx\n", - (unsigned long long)desc_q, (unsigned long long)desc_k); - } - __syncthreads(); - - // MMA: called by ONE lane (elect_one_sync pattern) + // QK GEMM: S = Q @ K^T (SS, both SMEM → TMEM) if (wid == 0 && lane == 0) { - umma_ss_f16(tmem_s, desc_q, desc_k, /*accumulate=*/false); + umma_ss_f16(tmem_base, desc_q, desc_k, false); } __syncwarp(); - if (wid == 0 && lane == 0) { - tmem_fence_store(); - } + if (wid == 0 && lane == 0) tmem_fence_store(); __syncthreads(); - // Read S from TMEM and output row 0 + // Read S row 0 from TMEM if (wid == 0) { for (int col = lane; col < 128; col += WARP) { uint32_t u0, u1, u2, u3; - tmem_load(tmem_s + col, u0, u1, u2, u3); - // Lane i's u0 = row i*4+0, etc. - if (lane < 32) { - // Write row 0 (lane 0's u0) - if (lane * 4 + 0 < kv_len) s_out[lane * 4 + 0] = u32_to_f32(u0) * scale; - if (lane * 4 + 1 < kv_len) s_out[lane * 4 + 1] = u32_to_f32(u1) * scale; - if (lane * 4 + 2 < kv_len) s_out[lane * 4 + 2] = u32_to_f32(u2) * scale; - if (lane * 4 + 3 < kv_len) s_out[lane * 4 + 3] = u32_to_f32(u3) * scale; - } + tmem_load(tmem_base + col, u0, u1, u2, u3); + if (lane * 4 + 0 < kv_len) s_out[lane * 4 + 0] = u32_to_f32(u0) * scale; + if (lane * 4 + 1 < kv_len) s_out[lane * 4 + 1] = u32_to_f32(u1) * scale; + if (lane * 4 + 2 < kv_len) s_out[lane * 4 + 2] = u32_to_f32(u2) * scale; + if (lane * 4 + 3 < kv_len) s_out[lane * 4 + 3] = u32_to_f32(u3) * scale; } } __syncthreads(); // Dealloc TMEM - if (wid == 0) { - tmem_dealloc(tmem_base, 128); - } + if (wid == 0) tmem_dealloc(tmem_base, 128); } } // namespace diff --git a/dsv4/kernels/attention/fmha_umma_desc.cuh b/dsv4/kernels/attention/fmha_umma_desc.cuh index 774dbde5..cf2f9cc7 100644 --- a/dsv4/kernels/attention/fmha_umma_desc.cuh +++ b/dsv4/kernels/attention/fmha_umma_desc.cuh @@ -1,57 +1,40 @@ /** - * DSV4 FMHA — Canonical UMMA SMEM layout for tcgen05.mma. + * DSV4 FMHA — UMMA SMEM layout and descriptor construction. * * ================================================================== - * CANONICAL SMEM LAYOUT FOR UMMA ON BLACKWELL SM100 + * CANONICAL UMMA LAYOUTS (from cute/atom/mma_traits_sm100.hpp) * ================================================================== * - * The tcgen05.mma instruction requires matrix operands in SMEM to be - * laid out in a specific canonical format with swizzle. The UMMA - * descriptor encodes the strides in this canonical layout. + * MN-major NONE: ((1, n), (8, k/8)) strides ((X, SBO), (1, LBO)) + * MN-major SW128: ((8, n), (8, k/8)) strides ((1, SBO), (8, LBO)) + * K-major NONE: ((8, n), (2, k/2)) strides ((X, SBO), (1, LBO)) + * K-major SW128: ((8, n), (2, k/2)) strides ((8, SBO), (1, LBO)) * - * For BF16 operands with SWIZZLE_128B (the default for FMHA): - * - * MN-major A (Q): Swizzle<3,4,3> applied to (128, n_uint128) layout - * K-major B (K): Swizzle<3,4,3> applied to (k_uint128, n_uint128) layout - * - * The swizzle pattern XORs address bits to avoid bank conflicts. - * The hardware expects data in the SWIZZLED order, not row-major. + * For SWIZZLE_NONE, the data is in row-major order (no permutation). + * The descriptor strides come from logical_divide of the SMEM layout + * with Tile(SwizzleAtomMNSize, SwizzleAtomKSize). * * ================================================================== - * SWIZZLE_128B PATTERN FOR BF16 + * DESCRIPTOR FOR (128, 64) BF16 MN-major NONE * ================================================================== - * - * Swizzle<3,4,3> XORs bits [6:4] of the uint128_t offset with bits [9:7]. - * This permutes 128-bit "rows" within a 1024-bit (8 × 128-bit) block. - * - * For a contiguous row of n_uint128 128-bit values: - * row_offset = row * row_stride_uint128 - * swizzled_offset = row_offset ^ ((row_offset & 0x70) >> 3) - * - * But actually, the swizzle in CUTLASS is applied differently: - * base = offset & ~0x3F (clear bottom 6 bits of 128B = 8 uint128_t) - * swizzle_bits = (offset >> 4) ^ (offset >> 7) (3,4,3 pattern) - * swizzled = base | (swizzle_bits & 0x7) - * - * This is getting complex. Let me just implement the swizzle as a - * device function and build the SMEM write/read functions. + * SMEM: (128, 64) row-major BF16, stride (1, 128) in BF16 elements + * uint128_t: (16, 8) stride (1, 16) + * logical_divide Tile(1, 8): ((1,16), (8,1)) strides ((1,1), (16,128)) + * → LBO = 16, SBO = 128 * * ================================================================== - * UMMA DESCRIPTOR FOR SWIZZLE_128B BF16 + * DESCRIPTOR FOR (128, 64) BF16 K-major NONE * ================================================================== - * - * For MN-major A with SWIZZLE_128B: - * layout_type = 2 (SWIZZLE_128B) - * start_address = smem_ptr >> 4 - * leading_byte_offset = (row_stride_bytes) >> 4 (in uint128_t units) - * stride_byte_offset = same as leading for simple layout - * version = 1 - * - * The descriptor's stride fields are in uint128_t units (16B granularity), - * describing the canonical layout strides AFTER swizzle is applied. - * - * For K-major B, the layout is the same but the K-dimension is the - * "leading" dimension and the MN-dimension is the "stride" dimension. + * SMEM: same data, K is the "leading" dim for K-major descriptor + * For K-major, stride_00 must be 8 (SwizzleAtomMNSize). + * Row-major stride (1, 128) gives stride_00=1 in uint128_t — WRONG. + * K-major requires stride-8 grouping in the MN dimension. + * The data must be (HD, 128) BF16 with stride (128, 1) — K-contiguous. + * uint128_t: (8, 16) stride (16, 1) — K=8 uint128_t contiguous + * logical_divide Tile(8, 2): ((8,1), (2,8)) strides ((1,8), (16,128)) + * But stride_00=1, not 8. Still wrong for SW128. + * For NONE, stride_00 can be anything — the assertion is relaxed. + * → LBO = stride_01 = 16, SBO = stride_11 = 128 */ #pragma once @@ -59,280 +42,44 @@ namespace dsv4::kernels::attention { -// ================================================================== -// Swizzle<3,4,3> for 128B-swizzled SMEM layout -// ================================================================== -// The swizzle XORs bits to permute 128-bit rows within 1024-bit blocks. -// This avoids bank conflicts when the MMA engine reads SMEM. -// -// Swizzle<3,4,3> means: XOR 3 bits starting at bit 4 with 3 bits -// starting at bit 7. The result is a permutation of 8 consecutive -// 128-bit rows within each 1024-bit (128-byte = 8 × 16-byte) block. -// -// Input: 128-bit row index within a 128B block (0-7) -// Output: permuted 128-bit row index within the block -// ================================================================== - -__device__ __forceinline__ int swizzle_128b(int offset_128b) { - // Swizzle<3,4,3>: XOR bits [6:4] with bits [9:7] - // offset_128b is the uint128_t index - // Within a 128B block (8 uint128_t), the swizzle permutes the 8 rows - int block_base = offset_128b & ~7; // Clear bottom 3 bits (8-aligned) - int row_in_block = offset_128b & 7; // 0-7 within the block - - // Swizzle<3,4,3> on the row index: - // XOR the 3 bits of row_in_block with the 3 bits of the block index - // But wait — the swizzle is on the BYTE address, not the row index. - // Let me be more precise. - // - // The Swizzle<3,4,3> operates on the 128-bit vector index. - // For a contiguous layout, vector index = row * stride + col. - // The swizzle XORs bits [6:4] (3 bits at position 4) with bits [9:7] (3 bits at position 7). - // Since each vector is 16 bytes, bit 4 corresponds to 16*2^4 = 256 bytes offset, - // and bit 7 corresponds to 16*2^7 = 2048 bytes offset. - // - // Actually, the swizzle is applied to the LAYOUT, not the address. - // The CuTe Swizzle<3,4,3> with a contiguous layout means: - // For each element at offset i: - // swizzled_i = i ^ ((i >> 3) & 0x7) << 4) - // - // Hmm, this isn't right either. Let me look at the CuTe implementation. - - // Simple approach: the swizzle permutes the 8 rows within each - // 128B block. The permutation is: - // row 0 → row 0 - // row 1 → row 1 - // row 2 → row 2 - // row 3 → row 3 - // row 4 → row 4 - // row 5 → row 5 - // row 6 → row 6 - // row 7 → row 7 - // Wait, Swizzle<3,4,3> with a stride-1 layout on 8 elements: - // The XOR is: (i & 0x70) ^ ((i & 0x0E) << 3) ... no. - // - // Let me just compute it from the CuTe definition. - // Swizzle<3,4,3> means: baz=3, b4=4, b3=3 - // The swizzle XORs (base >> b3) with (base >> (b3+b4)), taking baz bits. - // So: ((offset >> 3) ^ (offset >> 7)) & 0x7 - // This is applied to the uint128_t offset within the full layout. - - int swizzled_row = row_in_block ^ (((offset_128b >> 3) ^ (offset_128b >> 7)) & 0x7); - // Actually, the swizzle is on the ELEMENT offset, not the row-in-block. - // Let me redo this properly. - - // For Swizzle<3,4,3> applied to offset i (in uint128_t units): - // swizzled = i ^ (((i >> 3) ^ (i >> 7)) & 0x7) << 4) - // Wait, that shifts by 4 bits, but we're in uint128_t units. - - // The CuTe Swizzle<3,4,3> definition: - // template - // struct Swizzle { - // CUTE_HOST_DEVICE constexpr - // uint64_t operator()(uint64_t offset) const { - // return offset ^ ((offset >> b3) ^ (offset >> (b3+b4))) & ((1 << baz) - 1)) << b3; - // } - // }; - // For baz=3, b4=4, b3=3: - // swizzled = offset ^ (((offset >> 3) ^ (offset >> 7)) & 0x7) << 3 - - // This is in ELEMENT units (BF16), not uint128_t units. - // But we need to apply it to the uint128_t offset. - // The atom layout for BF16 MN_SW128 has shape (1024, 8) with stride (1, 1024). - // In uint128_t: (128, 8) with stride (1, 128). - // The swizzle is applied to the 1D offset = row + col * 128 (in uint128_t units). - - // Actually, I realize the swizzle operates on the 1D address (in the element - // space), not on the 2D coordinates. The CuTe layout maps (row, col) to a 1D - // offset, then the swizzle permutes that offset. - - // For MN_SW128 BF16: - // Layout: (1024, 8) with stride (1, 1024) - // 1D offset = m + n * 1024 (in BF16 elements) - // Swizzle: offset ^ (((offset >> 3) ^ (offset >> 7)) & 0x7) << 3 - - // To write BF16 element at (row, col) in the swizzled SMEM: - // 1. Compute 1D offset = row + col * 1024 (BF16 element index) - // 2. Apply swizzle: swizzled = offset ^ (((offset >> 3) ^ (offset >> 7)) & 0x7) << 3 - // 3. SMEM byte address = smem_base + swizzled * 2 (2 bytes per BF16) - - // To convert to uint128_t units: - // uint128_offset = swizzled / 8 - // But the swizzle operates on BF16 elements, not uint128_t. - - // This is the key: the swizzle is on the ELEMENT offset, not the vector offset. - // So I need to compute the element-level swizzled offset, then convert to bytes. - - // For our (128, HD) matrix with MN_SW128: - // HD=64 BF16 → n = 8 uint128_t per row - // But the layout atom has n = 8 elements in the N dimension (in uint128_t: 1) - // Wait, the atom shape is (1024, 8) in BF16 = (128, 1) in uint128_t. - // But HD=64 means the K dimension is 64 BF16 = 4 uint128_t. - - // I think the confusion is that the atom shape (1024, 8) means 1024 rows × 8 columns - // in BF16. For our matrix (128, 64), we tile the atom: - // M: 128 rows, atom covers 1024 → 1 tile (atom is bigger than matrix) - // N: 64 columns, atom covers 8 → 8 tiles - // So the full layout is: tile_to_shape(atom, (128, 64)) - // = repeat the 8-column atom 8 times along the N dimension - - // OK let me just compute this numerically for a few elements to verify. - - return block_base + (row_in_block & 7); // placeholder -} - -// ================================================================== -// UMMA SMEM write: write a BF16 element to swizzled SMEM -// ================================================================== -// Given a (row, col) position in the logical matrix, compute the -// swizzled SMEM address and write the BF16 value. -// ================================================================== - -__device__ __forceinline__ void umma_smem_write_mn_sw128( - bf16_t* smem_base, int row, int col, int hd, bf16_t val -) { - // MN_SW128 BF16 layout: atom shape (1024, 8), stride (1, 1024) - // For (128, HD) matrix: - // 1D element offset = row + col * 1024 (strided by 1024 in the N dimension) - // Wait, that's wrong. The N-dim stride is 1024 because the atom has 1024 rows. - // But our matrix only has 128 rows. The extra 896 rows are padding. - - // Actually, the full tiled layout for (128, HD) with atom (1024, 8) stride (1, 1024): - // logical_offset(m, n) = (m % 1024) + (m / 1024) * 1024 * 8 + (n % 8) * 1024 + (n / 8) * 8 - // Wait, tile_to_shape repeats the atom pattern. - - // For MN_SW128, the K-dim stride is 1024 BF16 elements per atom column. - // With our 128-row matrix (which fits in 1 atom column of 1024 rows): - // logical_offset(m, n) = m + (n / 8) * 1024 * 8 + (n % 8) * 1024 - // = m + (n / 8) * 8192 + (n % 8) * 1024 - - // Then swizzle: swizzled_offset = logical_offset ^ (((logical_offset >> 3) ^ (logical_offset >> 7)) & 0x7) << 3 - - int logical = row + (col / 8) * 8192 + (col % 8) * 1024; - int swizzled = logical ^ ((((logical >> 3) ^ (logical >> 7)) & 0x7) << 3); - - smem_base[swizzled] = val; -} - -// ================================================================== -// UMMA SMEM read: read a BF16 element from swizzled SMEM -// ================================================================== - -__device__ __forceinline__ bf16_t umma_smem_read_mn_sw128( - const bf16_t* smem_base, int row, int col, int hd -) { - int logical = row + (col / 8) * 8192 + (col % 8) * 1024; - int swizzled = logical ^ ((((logical >> 3) ^ (logical >> 7)) & 0x7) << 3); - return smem_base[swizzled]; -} - -// ================================================================== -// K-major SMEM write/read for SW128 -// ================================================================== -// For K-major B (K^T): the matrix is stored with K as the major dimension. -// The atom for K_SW128 is the same shape but with K as the leading dim. -// Layout: (k_elements, n_atoms) with K-major strides. -// ================================================================== - -__device__ __forceinline__ void umma_smem_write_k_sw128( - bf16_t* smem_base, int row, int col, int hd, bf16_t val -) { - // K_SW128 BF16: atom shape (8, 1024), stride (1, 8) - // For (128, HD) matrix in K-major: - // row = K index (0..127), col = MN index (0..HD-1) - // logical_offset = (k % 8) + (k / 8) * 8 * 128 + (mn % 8) * 8 + (mn / 8) * 1024 - // Hmm, this needs to be worked out properly. - - // For K_SW128 BF16 atom: Shape<(8, 1024)>, Stride<(1, 8)> - // In uint128_t: Shape<(1, 128)>, Stride<(1, 1)> - // tile_to_shape for (128, HD): - // K: 128 BF16 = 16 uint128_t → 16 tiles of the K atom (1 uint128_t each) - // MN: HD BF16 = 8 uint128_t → 8 tiles of the MN atom (128 uint128_t each) - - // Actually for K_SW128: - // atom: (8, 1024) BF16, stride (1, 8) - // For (128, 64) matrix: - // logical_offset(k, mn) = (k % 8) + (mn % 1024) * 8 + (k / 8) * 8 * 1024 + (mn / 1024) * 8 - // = (k % 8) + mn * 8 + (k / 8) * 8192 - - // Wait, this doesn't seem right. The K_SW128 atom for BF16 is: - // Shape<(8, 1024)> in elements, Stride<(1, 8)> - // This means: 8 contiguous elements along K, then stride 8 to the next group - // 1024 groups along MN - - // For our (128, 64) matrix in K-major: - // k ranges 0..127, mn ranges 0..63 - // Atom covers 8 K and 1024 MN. Need 16 K-tiles and 1 MN-tile. - // logical_offset(k, mn) = (k % 8) + (mn % 1024) * 8 + (k / 8) * 8 * 1024 - // = (k % 8) + mn * 8 + (k / 8) * 8192 - - int logical = (row % 8) + col * 8 + (row / 8) * 8192; - int swizzled = logical ^ ((((logical >> 3) ^ (logical >> 7)) & 0x7) << 3); - smem_base[swizzled] = val; -} - -// ================================================================== -// UMMA SMEM descriptor construction (correct format) -// ================================================================== - -enum class UmmaMajor { MN, K }; -enum class UmmaLayout { NONE = 0, MN_SW128 = 2, K_SW128 = 2 }; - -__device__ __forceinline__ uint64_t make_umma_desc( - uint32_t smem_ptr, - int dim_m, int dim_n, - UmmaMajor major, UmmaLayout layout = UmmaLayout::MN_SW128 +/** + * MN-major SWIZZLE_NONE descriptor for (128, HD) BF16 row-major. + */ +__device__ __forceinline__ uint64_t make_umma_desc_mn_none( + uint32_t smem_ptr, int hd ) { uint64_t desc = 0; - - // start_address: bits [0,14) = smem_ptr >> 4 (in 16B units) desc |= (static_cast(smem_ptr >> 4) & 0x3FFF); - - // For SW128 layout: - // The leading_byte_offset and stride_byte_offset describe the canonical - // layout in uint128_t units. - // - // MN_SW128: atom (1024, 8) stride (1, 1024) in BF16 elements - // = (128, 1) stride (1, 128) in uint128_t units - // leading_byte_offset (K-dim stride) = 1024 BF16 * 2 / 16 = 128 uint128_t - // stride_byte_offset (inter-tile MN stride) = same for 1 MN tile - // - // K_SW128: atom (8, 1024) stride (1, 8) in BF16 elements - // = (1, 128) stride (1, 1) in uint128_t units - // leading_byte_offset (MN-dim stride) = 8 BF16 * 2 / 16 = 1 uint128_t - // stride_byte_offset (inter-tile K stride) = 8192 BF16 * 2 / 16 = 1024 uint128_t - - if (major == UmmaMajor::MN) { - // MN_SW128: - // leading_byte_offset = 128 (uint128_t units) = 128 * 16 = 2048 bytes - // stride_byte_offset = 128 (uint128_t units) — same for simple case - int lbo = 128; // 1024 BF16 * 2 bytes / 16 = 128 uint128_t - int sbo = 128; - desc |= (static_cast(lbo) & 0x3FFF) << 16; - desc |= (static_cast(sbo) & 0x3FFF) << 32; - } else { - // K_SW128: - // leading_byte_offset = 1 (uint128_t units) = 8 BF16 per MN group - // stride_byte_offset = 1024 (uint128_t units) - int lbo = 1; - int sbo = 1024; - desc |= (static_cast(lbo) & 0x3FFF) << 16; - desc |= (static_cast(sbo) & 0x3FFF) << 32; - } - - // version: bits [46,48) = 1 - desc |= (static_cast(1) << 46); - - // layout_type: bits [61,64) = 2 (SWIZZLE_128B) - desc |= (static_cast(2) << 61); - + // LBO = 16 (row stride in uint128_t for 128 rows) + desc |= (static_cast(16) & 0x3FFF) << 16; + // SBO = 128 (8-row group stride in uint128_t) + desc |= (static_cast(128) & 0x3FFF) << 32; + desc |= (static_cast(1) << 46); // version + // layout_type = 0 (NONE) return desc; } -// ================================================================== -// tcgen05.mma PTX wrappers -// ================================================================== +/** + * K-major SWIZZLE_NONE descriptor for (128, HD) BF16 row-major. + * The same SMEM data is used, but the descriptor tells the MMA to + * interpret it as K-major (transposed view). + * For K-major with stride (1, 128) BF16: + * uint128_t: (16, 8) stride (1, 16) + * logical_divide Tile(8, 2): ((8,2), (2,4)) strides ((1,8), (16,32)) + * LBO = stride_01 = 16, SBO = stride_11 = 32 + */ +__device__ __forceinline__ uint64_t make_umma_desc_k_none( + uint32_t smem_ptr, int hd +) { + uint64_t desc = 0; + desc |= (static_cast(smem_ptr >> 4) & 0x3FFF); + desc |= (static_cast(16) & 0x3FFF) << 16; // LBO + desc |= (static_cast(32) & 0x3FFF) << 32; // SBO + desc |= (static_cast(1) << 46); // version + return desc; +} +// tcgen05.mma SS: S←S (both SMEM → TMEM). One lane per warp. __device__ void umma_ss_f16( uint32_t tmem_c, uint64_t desc_a, uint64_t desc_b, bool accumulate = false @@ -340,21 +87,19 @@ __device__ void umma_ss_f16( uint32_t idescE = static_cast(desc_a >> 32); uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u; uint32_t mask0 = 0, mask1 = 0, mask2 = 0, mask3 = 0; - asm volatile( "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p;\n\t" "}" - :: "r"(tmem_c), - "l"(desc_a), "l"(desc_b), - "r"(idescE), - "r"(scaleC_bits), + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), + "r"(idescE), "r"(scaleC_bits), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3) ); } +// tcgen05.mma TS: T←S (TMEM A + SMEM B → TMEM C). One lane per warp. __device__ void umma_ts_f16( uint32_t tmem_c, uint32_t tmem_a, uint64_t desc_b, bool accumulate = true @@ -362,17 +107,14 @@ __device__ void umma_ts_f16( uint32_t idescE = static_cast(desc_b >> 32); uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u; uint32_t mask0 = 0, mask1 = 0, mask2 = 0, mask3 = 0; - asm volatile( "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p;\n\t" - :: "r"(tmem_c), - "r"(tmem_a), - "l"(desc_b), - "r"(idescE), - "r"(scaleC_bits), + "}" + :: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), + "r"(idescE), "r"(scaleC_bits), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3) ); }