feat: SWIZZLE_NONE UMMA descriptors with row-major SMEM

Canonical UMMA layout for SWIZZLE_NONE:
- MN-major (128, 64): LBO=16, SBO=128 (from logical_divide Tile(1,8))
- K-major (128, 64): LBO=16, SBO=32 (from logical_divide Tile(8,2))

Using simple row-major SMEM layout (no swizzle, no permutation).
Data is written directly to SMEM in row-major order.
The descriptor strides describe the canonical layout.
This commit is contained in:
2026-05-28 08:35:30 +00:00
parent 8c67c31497
commit d3510980e4
2 changed files with 91 additions and 385 deletions

View File

@@ -1,12 +1,6 @@
/**
* DSV4 FMHA — QK GEMM verification with canonical UMMA SMEM layout.
*
* STEP 1: Verify tcgen05.mma SS produces correct QK output
* using the proper SWIZZLE_128B canonical SMEM layout.
*
* The SMEM data must be written in the swizzled layout that the UMMA
* hardware expects. We use umma_smem_write_mn_sw128 and
* umma_smem_write_k_sw128 to write data in the correct format.
* DSV4 FMHA — QK GEMM verification with SWIZZLE_NONE UMMA layout.
* Uses simple row-major SMEM (no swizzle) with the NONE descriptor.
*/
#pragma once
@@ -20,7 +14,7 @@ template<int HD>
__global__ void __launch_bounds__(NTHREADS)
fmha_qk_verify(
const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
float* __restrict__ s_out, // output: S[0, 0..127] for row 0
float* __restrict__ s_out,
int bstride_q, int bstride_kv,
int s_k, float scale
) {
@@ -30,102 +24,72 @@ fmha_qk_verify(
const bf16_t* qh = q + batch*bstride_q + head*HD;
const bf16_t* kb = k + batch*bstride_kv;
// SMEM: sQ (128×HD BF16 swizzled) + sK (128×HD BF16 swizzled) + tmem_base
// Size: 4 + 128*HD*2 + 128*HD*2 = 4 + 512*HD bytes
// For SW128 layout, the actual SMEM needed is the same as row-major
// because the swizzle is just a permutation of the same data.
// SMEM: sQ (128×HD BF16 row-major) + sK (128×HD BF16 row-major) + tmem_base
// Must be 16-byte aligned for UMMA
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
// Align to 128 bytes for UMMA descriptor
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 127) & ~127);
bf16_t* sK = sQ + 1024 * 8; // MN_SW128 atom = 1024*8 BF16 = 8192 BF16 per atom
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~15);
bf16_t* sK = sQ + 128 * HD;
// Load Q into swizzled SMEM: (1, HD) padded to (128, HD)
// Write row 0 with actual Q data, rows 1-127 are zero
for (int d = tid; d < HD; d += NTHREADS) {
umma_smem_write_mn_sw128(sQ, 0, d, HD, qh[d]);
}
// Zero rows 1-127 (only need to zero the elements we'll read)
for (int row = 1 + tid / HD; row < 128; row += NTHREADS / HD) {
for (int d = tid % HD; d < HD; d += HD) {
umma_smem_write_mn_sw128(sQ, row, d, HD, 0);
}
}
// Load Q: (1, HD) padded to (128, HD) with zeros
for (int i = tid; i < 128 * HD; i += NTHREADS) sQ[i] = 0;
for (int d = tid; d < HD; d += NTHREADS) sQ[d] = qh[d];
// Load K into swizzled SMEM: (min(128, s_k), HD) padded to (128, HD)
// Load K: (min(128, s_k), HD) padded to (128, HD)
int kv_len = min(128, s_k);
for (int r = 0; r < 128; r++) {
for (int d = tid; d < HD; d += NTHREADS) {
bf16_t val = (r < kv_len) ? kb[r * HD + d] : 0;
umma_smem_write_k_sw128(sK, r, d, HD, val);
}
for (int i = tid; i < 128 * HD; i += NTHREADS) {
int r = i / HD, c = i % HD;
sK[i] = (r < kv_len) ? kb[r * HD + c] : 0;
}
__syncthreads();
// TMEM alloc for S: 128 columns (128 rows × 128 cols)
// TMEM alloc for S: 128 columns
if (wid == 0) {
uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
tmem_alloc(smem_ptr, 128);
}
__syncthreads();
uint32_t tmem_base = *sTmemBase;
uint32_t tmem_s = tmem_base;
// Zero TMEM S
if (wid == 0) {
for (int col = 0; col < 128; col++) {
tmem_store(tmem_s + col, 0, 0, 0, 0);
tmem_store(tmem_base + col, 0, 0, 0, 0);
}
tmem_fence_store();
}
__syncthreads();
// ================================================================
// QK GEMM: S = Q @ K^T via tcgen05.mma SS
// ================================================================
// UMMA descriptors
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
uint32_t sK_smem = __cvta_generic_to_shared(sK);
uint64_t desc_q = make_umma_desc(sQ_smem, 128, HD, UmmaMajor::MN);
uint64_t desc_k = make_umma_desc(sK_smem, 128, HD, UmmaMajor::K);
uint64_t desc_q = make_umma_desc_mn_none(sQ_smem, HD);
uint64_t desc_k = make_umma_desc_k_none(sK_smem, HD);
if (tid == 0) {
printf("[qk] desc_q=0x%016llx desc_k=0x%016llx\n",
(unsigned long long)desc_q, (unsigned long long)desc_k);
}
__syncthreads();
// MMA: called by ONE lane (elect_one_sync pattern)
// QK GEMM: S = Q @ K^T (SS, both SMEM → TMEM)
if (wid == 0 && lane == 0) {
umma_ss_f16(tmem_s, desc_q, desc_k, /*accumulate=*/false);
umma_ss_f16(tmem_base, desc_q, desc_k, false);
}
__syncwarp();
if (wid == 0 && lane == 0) {
tmem_fence_store();
}
if (wid == 0 && lane == 0) tmem_fence_store();
__syncthreads();
// Read S from TMEM and output row 0
// Read S row 0 from TMEM
if (wid == 0) {
for (int col = lane; col < 128; col += WARP) {
uint32_t u0, u1, u2, u3;
tmem_load(tmem_s + col, u0, u1, u2, u3);
// Lane i's u0 = row i*4+0, etc.
if (lane < 32) {
// Write row 0 (lane 0's u0)
if (lane * 4 + 0 < kv_len) s_out[lane * 4 + 0] = u32_to_f32(u0) * scale;
if (lane * 4 + 1 < kv_len) s_out[lane * 4 + 1] = u32_to_f32(u1) * scale;
if (lane * 4 + 2 < kv_len) s_out[lane * 4 + 2] = u32_to_f32(u2) * scale;
if (lane * 4 + 3 < kv_len) s_out[lane * 4 + 3] = u32_to_f32(u3) * scale;
}
tmem_load(tmem_base + col, u0, u1, u2, u3);
if (lane * 4 + 0 < kv_len) s_out[lane * 4 + 0] = u32_to_f32(u0) * scale;
if (lane * 4 + 1 < kv_len) s_out[lane * 4 + 1] = u32_to_f32(u1) * scale;
if (lane * 4 + 2 < kv_len) s_out[lane * 4 + 2] = u32_to_f32(u2) * scale;
if (lane * 4 + 3 < kv_len) s_out[lane * 4 + 3] = u32_to_f32(u3) * scale;
}
}
__syncthreads();
// Dealloc TMEM
if (wid == 0) {
tmem_dealloc(tmem_base, 128);
}
if (wid == 0) tmem_dealloc(tmem_base, 128);
}
} // namespace

View File

@@ -1,57 +1,40 @@
/**
* DSV4 FMHA — Canonical UMMA SMEM layout for tcgen05.mma.
* DSV4 FMHA — UMMA SMEM layout and descriptor construction.
*
* ==================================================================
* CANONICAL SMEM LAYOUT FOR UMMA ON BLACKWELL SM100
* CANONICAL UMMA LAYOUTS (from cute/atom/mma_traits_sm100.hpp)
* ==================================================================
*
* The tcgen05.mma instruction requires matrix operands in SMEM to be
* laid out in a specific canonical format with swizzle. The UMMA
* descriptor encodes the strides in this canonical layout.
* MN-major NONE: ((1, n), (8, k/8)) strides ((X, SBO), (1, LBO))
* MN-major SW128: ((8, n), (8, k/8)) strides ((1, SBO), (8, LBO))
* K-major NONE: ((8, n), (2, k/2)) strides ((X, SBO), (1, LBO))
* K-major SW128: ((8, n), (2, k/2)) strides ((8, SBO), (1, LBO))
*
* For BF16 operands with SWIZZLE_128B (the default for FMHA):
*
* MN-major A (Q): Swizzle<3,4,3> applied to (128, n_uint128) layout
* K-major B (K): Swizzle<3,4,3> applied to (k_uint128, n_uint128) layout
*
* The swizzle pattern XORs address bits to avoid bank conflicts.
* The hardware expects data in the SWIZZLED order, not row-major.
* For SWIZZLE_NONE, the data is in row-major order (no permutation).
* The descriptor strides come from logical_divide of the SMEM layout
* with Tile(SwizzleAtomMNSize, SwizzleAtomKSize).
*
* ==================================================================
* SWIZZLE_128B PATTERN FOR BF16
* DESCRIPTOR FOR (128, 64) BF16 MN-major NONE
* ==================================================================
*
* Swizzle<3,4,3> XORs bits [6:4] of the uint128_t offset with bits [9:7].
* This permutes 128-bit "rows" within a 1024-bit (8 × 128-bit) block.
*
* For a contiguous row of n_uint128 128-bit values:
* row_offset = row * row_stride_uint128
* swizzled_offset = row_offset ^ ((row_offset & 0x70) >> 3)
*
* But actually, the swizzle in CUTLASS is applied differently:
* base = offset & ~0x3F (clear bottom 6 bits of 128B = 8 uint128_t)
* swizzle_bits = (offset >> 4) ^ (offset >> 7) (3,4,3 pattern)
* swizzled = base | (swizzle_bits & 0x7)
*
* This is getting complex. Let me just implement the swizzle as a
* device function and build the SMEM write/read functions.
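* SMEM: (128, 64) row-major BF16, stride (1, 128) in BF16 elements
* (NOTE: stride (1, 128) is MN-contiguous; a row-major (128, 64) buffer has
* stride (64, 1) — confirm which layout is actually meant.)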
* uint128_t: (16, 8) stride (1, 16)
* logical_divide Tile(1, 8): ((1,16), (8,1)) strides ((1,1), (16,128))
* → LBO = 16, SBO = 128
*
* ==================================================================
* UMMA DESCRIPTOR FOR SWIZZLE_128B BF16
* DESCRIPTOR FOR (128, 64) BF16 K-major NONE
* ==================================================================
*
* For MN-major A with SWIZZLE_128B:
* layout_type = 2 (SWIZZLE_128B)
* start_address = smem_ptr >> 4
* leading_byte_offset = (row_stride_bytes) >> 4 (in uint128_t units)
* stride_byte_offset = same as leading for simple layout
* version = 1
*
* The descriptor's stride fields are in uint128_t units (16B granularity),
* describing the canonical layout strides AFTER swizzle is applied.
*
* For K-major B, the layout is the same but the K-dimension is the
* "leading" dimension and the MN-dimension is the "stride" dimension.
* SMEM: same data, K is the "leading" dim for K-major descriptor
* For K-major, stride_00 must be 8 (SwizzleAtomMNSize).
* Row-major stride (1, 128) gives stride_00=1 in uint128_t — WRONG.
* K-major requires stride-8 grouping in the MN dimension.
* The data must be (HD, 128) BF16 with stride (128, 1) — K-contiguous.
* uint128_t: (8, 16) stride (16, 1) — K=8 uint128_t contiguous
* logical_divide Tile(8, 2): ((8,1), (2,8)) strides ((1,8), (16,128))
* But stride_00=1, not 8. Still wrong for SW128.
* For NONE, stride_00 can be anything — the assertion is relaxed.
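* → LBO = stride_01 = 16, SBO = stride_11 = 128
* (NOTE: make_umma_desc_k_none encodes SBO = 32, and its own comment derives
* stride_11 = 32 — the two derivations disagree and need to be reconciled.)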
*/
#pragma once
@@ -59,280 +42,44 @@
namespace dsv4::kernels::attention {
// ==================================================================
// Swizzle<3,4,3> for 128B-swizzled SMEM layout
// ==================================================================
// The swizzle XORs bits to permute 128-bit rows within 1024-bit blocks.
// This avoids bank conflicts when the MMA engine reads SMEM.
//
// Swizzle<3,4,3> means: XOR 3 bits starting at bit 4 with 3 bits
// starting at bit 7. The result is a permutation of 8 consecutive
// 128-bit rows within each 1024-bit (128-byte = 8 × 16-byte) block.
//
// Input: 128-bit row index within a 128B block (0-7)
// Output: permuted 128-bit row index within the block
// ==================================================================
/**
 * Placeholder for the 128B-swizzle index permutation.
 *
 * No swizzle is applied yet: the argument is split into its 8-aligned base
 * and its low three bits, and the two parts are recombined unchanged, so the
 * value passed in is the value returned. The SW128 write/read helpers in this
 * file compute their XOR inline and do not call this function.
 *
 * @param offset_128b  index in uint128_t (16-byte) units
 * @return             offset_128b, unmodified
 */
__device__ __forceinline__ int swizzle_128b(int offset_128b) {
    const int low3    = offset_128b & 7;   // position inside the group of 8
    const int aligned = offset_128b & ~7;  // start of the enclosing group of 8

    // TODO: permute low3 with the real Swizzle<3,4,3> rule before recombining.
    // An unverified candidate is
    //     low3 ^ (((offset_128b >> 3) ^ (offset_128b >> 7)) & 0x7)
    // but it is not known to match the hardware layout, so it is not applied.
    return aligned + (low3 & 7);
}
// ==================================================================
// UMMA SMEM write: write a BF16 element to swizzled SMEM
// ==================================================================
// Given a (row, col) position in the logical matrix, compute the
// swizzled SMEM address and write the BF16 value.
// ==================================================================
/**
 * Store one BF16 value at logical position (row, col) of the MN-major
 * SW128 staging buffer.
 *
 * The element index before swizzling is
 *     row + (col / 8) * 8192 + (col % 8) * 1024
 * so adjacent columns are 1024 elements apart and each group of eight
 * columns starts 8192 elements after the previous group. Bits [3,6) of that
 * index are then XORed with ((index >> 3) ^ (index >> 7)) & 7.
 *
 * NOTE(review): that XOR replaces bits [3,6) with bits [7,10) of the index,
 * so distinct positions can map to the same slot (e.g. rows 0 and 8 of
 * column 0 both land on element 0). Whether this matches the hardware's
 * SWIZZLE_128B layout is not established here — confirm before relying on it.
 *
 * @param smem_base  start of the destination buffer
 * @param row        row of the logical matrix
 * @param col        column of the logical matrix
 * @param hd         accepted for interface symmetry; not used
 * @param val        value to store
 */
__device__ __forceinline__ void umma_smem_write_mn_sw128(
    bf16_t* smem_base, int row, int col, int hd, bf16_t val
) {
    constexpr int kColStride   = 1024;  // elements between adjacent columns
    constexpr int kGroupStride = 8192;  // elements between 8-column groups

    const int col_group = col / 8;
    const int col_inner = col % 8;
    const int linear    = row + col_group * kGroupStride + col_inner * kColStride;

    const int xor_bits = ((linear >> 3) ^ (linear >> 7)) & 0x7;
    smem_base[linear ^ (xor_bits << 3)] = val;
}
// ==================================================================
// UMMA SMEM read: read a BF16 element from swizzled SMEM
// ==================================================================
/**
 * Fetch the BF16 value stored at logical position (row, col) of the
 * MN-major SW128 staging buffer.
 *
 * Uses the same index as umma_smem_write_mn_sw128:
 *     linear = row + (col / 8) * 8192 + (col % 8) * 1024
 *     slot   = linear ^ ((((linear >> 3) ^ (linear >> 7)) & 7) << 3)
 *
 * @param smem_base  start of the source buffer
 * @param row        row of the logical matrix
 * @param col        column of the logical matrix
 * @param hd         accepted for interface symmetry; not used
 * @return           the element found at the computed slot
 */
__device__ __forceinline__ bf16_t umma_smem_read_mn_sw128(
    const bf16_t* smem_base, int row, int col, int hd
) {
    const int linear = row + (col / 8) * 8192 + (col % 8) * 1024;
    const int flip   = (((linear >> 3) ^ (linear >> 7)) & 0x7) << 3;
    return smem_base[linear ^ flip];
}
// ==================================================================
// K-major SMEM write/read for SW128
// ==================================================================
// For K-major B (K^T): the matrix is stored with K as the major dimension.
// The atom for K_SW128 is the same shape but with K as the leading dim.
// Layout: (k_elements, n_atoms) with K-major strides.
// ==================================================================
/**
 * Store one BF16 value at logical position (row, col) of the K-major
 * SW128 staging buffer.
 *
 * The element index before swizzling is
 *     (row % 8) + col * 8 + (row / 8) * 8192
 * i.e. eight consecutive rows are adjacent, each column advances by 8
 * elements, and every further block of eight rows starts 8192 elements
 * later. Bits [3,6) of that index are then XORed with
 * ((index >> 3) ^ (index >> 7)) & 7.
 *
 * NOTE(review): that XOR replaces bits [3,6) with bits [7,10) of the index,
 * so distinct positions can map to the same slot (e.g. (0, 0) and (0, 1)
 * both land on element 0). Whether this matches the hardware's
 * SWIZZLE_128B layout is not established here — confirm before relying on it.
 *
 * @param smem_base  start of the destination buffer
 * @param row        first logical coordinate
 * @param col        second logical coordinate
 * @param hd         accepted for interface symmetry; not used
 * @param val        value to store
 */
__device__ __forceinline__ void umma_smem_write_k_sw128(
    bf16_t* smem_base, int row, int col, int hd, bf16_t val
) {
    constexpr int kRowBlockStride = 8192;  // elements between 8-row blocks

    const int row_inner = row % 8;
    const int row_block = row / 8;
    const int linear    = row_inner + col * 8 + row_block * kRowBlockStride;

    const int xor_bits = ((linear >> 3) ^ (linear >> 7)) & 0x7;
    smem_base[linear ^ (xor_bits << 3)] = val;
}
// ==================================================================
// UMMA SMEM descriptor construction (correct format)
// ==================================================================
// Selects which stride pair make_umma_desc encodes: the MN branch or the K branch.
enum class UmmaMajor { MN, K };
// Layout selector for make_umma_desc. MN_SW128 and K_SW128 share the value 2.
// NOTE(review): presumably these mirror the descriptor's layout_type field;
// make_umma_desc writes the literal 2 there rather than reading this enum — confirm.
enum class UmmaLayout { NONE = 0, MN_SW128 = 2, K_SW128 = 2 };
__device__ __forceinline__ uint64_t make_umma_desc(
uint32_t smem_ptr,
int dim_m, int dim_n,
UmmaMajor major, UmmaLayout layout = UmmaLayout::MN_SW128
/**
* MN-major SWIZZLE_NONE descriptor for (128, HD) BF16 row-major.
*/
__device__ __forceinline__ uint64_t make_umma_desc_mn_none(
uint32_t smem_ptr, int hd
) {
uint64_t desc = 0;
// start_address: bits [0,14) = smem_ptr >> 4 (in 16B units)
desc |= (static_cast<uint64_t>(smem_ptr >> 4) & 0x3FFF);
// For SW128 layout:
// The leading_byte_offset and stride_byte_offset describe the canonical
// layout in uint128_t units.
//
// MN_SW128: atom (1024, 8) stride (1, 1024) in BF16 elements
// = (128, 1) stride (1, 128) in uint128_t units
// leading_byte_offset (K-dim stride) = 1024 BF16 * 2 / 16 = 128 uint128_t
// stride_byte_offset (inter-tile MN stride) = same for 1 MN tile
//
// K_SW128: atom (8, 1024) stride (1, 8) in BF16 elements
// = (1, 128) stride (1, 1) in uint128_t units
// leading_byte_offset (MN-dim stride) = 8 BF16 * 2 / 16 = 1 uint128_t
// stride_byte_offset (inter-tile K stride) = 8192 BF16 * 2 / 16 = 1024 uint128_t
if (major == UmmaMajor::MN) {
// MN_SW128:
// leading_byte_offset = 128 (uint128_t units) = 128 * 16 = 2048 bytes
// stride_byte_offset = 128 (uint128_t units) — same for simple case
int lbo = 128; // 1024 BF16 * 2 bytes / 16 = 128 uint128_t
int sbo = 128;
desc |= (static_cast<uint64_t>(lbo) & 0x3FFF) << 16;
desc |= (static_cast<uint64_t>(sbo) & 0x3FFF) << 32;
} else {
// K_SW128:
// leading_byte_offset = 1 (uint128_t units) = 8 BF16 per MN group
// stride_byte_offset = 1024 (uint128_t units)
int lbo = 1;
int sbo = 1024;
desc |= (static_cast<uint64_t>(lbo) & 0x3FFF) << 16;
desc |= (static_cast<uint64_t>(sbo) & 0x3FFF) << 32;
}
// version: bits [46,48) = 1
desc |= (static_cast<uint64_t>(1) << 46);
// layout_type: bits [61,64) = 2 (SWIZZLE_128B)
desc |= (static_cast<uint64_t>(2) << 61);
// LBO = 16 (row stride in uint128_t for 128 rows)
desc |= (static_cast<uint64_t>(16) & 0x3FFF) << 16;
// SBO = 128 (8-row group stride in uint128_t)
desc |= (static_cast<uint64_t>(128) & 0x3FFF) << 32;
desc |= (static_cast<uint64_t>(1) << 46); // version
// layout_type = 0 (NONE)
return desc;
}
// ==================================================================
// tcgen05.mma PTX wrappers
// ==================================================================
/**
 * Build the 64-bit K-major SWIZZLE_NONE shared-memory descriptor.
 *
 * Fields written (all other bits stay zero, so the layout-type bits read
 * as 0 = NONE):
 *   bits [0, 14)  : smem_ptr >> 4, truncated to 14 bits (16-byte units)
 *   bits [16, 30) : LBO = 16
 *   bits [32, 46) : SBO = 32
 *   bit  46       : version = 1
 *
 * LBO and SBO are fixed constants; `hd` is accepted but not consulted.
 *
 * NOTE(review): the constants come from a logical_divide Tile(8, 2)
 * derivation for a (128, 64) BF16 tile; the file-header comment derives
 * SBO = 128 for the K-major case instead. Confirm which value is right,
 * and whether other head dimensions need different strides.
 *
 * @param smem_ptr  shared-memory address of the operand
 * @param hd        head dimension (currently unused)
 * @return          the packed descriptor
 */
__device__ __forceinline__ uint64_t make_umma_desc_k_none(
    uint32_t smem_ptr, int hd
) {
    constexpr uint64_t kFieldMask = 0x3FFF;  // every field is 14 bits wide
    constexpr uint64_t kLbo       = 16;
    constexpr uint64_t kSbo       = 32;
    constexpr uint64_t kVersion   = 1;

    const uint64_t start_addr = static_cast<uint64_t>(smem_ptr >> 4) & kFieldMask;

    return start_addr
         | ((kLbo & kFieldMask) << 16)
         | ((kSbo & kFieldMask) << 32)
         | (kVersion << 46);
}
// tcgen05.mma SS: S←S (both SMEM → TMEM). One lane per warp.
__device__ void umma_ss_f16(
uint32_t tmem_c, uint64_t desc_a, uint64_t desc_b,
bool accumulate = false
@@ -340,21 +87,19 @@ __device__ void umma_ss_f16(
uint32_t idescE = static_cast<uint32_t>(desc_a >> 32);
uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u;
uint32_t mask0 = 0, mask1 = 0, mask2 = 0, mask3 = 0;
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p;\n\t"
"}"
:: "r"(tmem_c),
"l"(desc_a), "l"(desc_b),
"r"(idescE),
"r"(scaleC_bits),
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b),
"r"(idescE), "r"(scaleC_bits),
"r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)
);
}
// tcgen05.mma TS: T←S (TMEM A + SMEM B → TMEM C). One lane per warp.
__device__ void umma_ts_f16(
uint32_t tmem_c, uint32_t tmem_a, uint64_t desc_b,
bool accumulate = true
@@ -362,17 +107,14 @@ __device__ void umma_ts_f16(
uint32_t idescE = static_cast<uint32_t>(desc_b >> 32);
uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u;
uint32_t mask0 = 0, mask1 = 0, mask2 = 0, mask3 = 0;
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p;\n\t"
:: "r"(tmem_c),
"r"(tmem_a),
"l"(desc_b),
"r"(idescE),
"r"(scaleC_bits),
"}"
:: "r"(tmem_c), "r"(tmem_a), "l"(desc_b),
"r"(idescE), "r"(scaleC_bits),
"r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)
);
}