feat: SWIZZLE_NONE UMMA descriptors with row-major SMEM
Canonical UMMA layout for SWIZZLE_NONE: - MN-major (128, 64): LBO=16, SBO=128 (from logical_divide Tile(1,8)) - K-major (128, 64): LBO=16, SBO=32 (from logical_divide Tile(8,2)) Using simple row-major SMEM layout (no swizzle, no permutation). Data is written directly to SMEM in row-major order. The descriptor strides describe the canonical layout.
This commit is contained in:
@@ -1,12 +1,6 @@
|
||||
/**
|
||||
* DSV4 FMHA — QK GEMM verification with canonical UMMA SMEM layout.
|
||||
*
|
||||
* STEP 1: Verify tcgen05.mma SS produces correct QK output
|
||||
* using the proper SWIZZLE_128B canonical SMEM layout.
|
||||
*
|
||||
* The SMEM data must be written in the swizzled layout that the UMMA
|
||||
* hardware expects. We use umma_smem_write_mn_sw128 and
|
||||
* umma_smem_write_k_sw128 to write data in the correct format.
|
||||
* DSV4 FMHA — QK GEMM verification with SWIZZLE_NONE UMMA layout.
|
||||
* Uses simple row-major SMEM (no swizzle) with the NONE descriptor.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@@ -20,7 +14,7 @@ template<int HD>
|
||||
__global__ void __launch_bounds__(NTHREADS)
|
||||
fmha_qk_verify(
|
||||
const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
float* __restrict__ s_out, // output: S[0, 0..127] for row 0
|
||||
float* __restrict__ s_out,
|
||||
int bstride_q, int bstride_kv,
|
||||
int s_k, float scale
|
||||
) {
|
||||
@@ -30,102 +24,72 @@ fmha_qk_verify(
|
||||
const bf16_t* qh = q + batch*bstride_q + head*HD;
|
||||
const bf16_t* kb = k + batch*bstride_kv;
|
||||
|
||||
// SMEM: sQ (128×HD BF16 swizzled) + sK (128×HD BF16 swizzled) + tmem_base
|
||||
// Size: 4 + 128*HD*2 + 128*HD*2 = 4 + 512*HD bytes
|
||||
// For SW128 layout, the actual SMEM needed is the same as row-major
|
||||
// because the swizzle is just a permutation of the same data.
|
||||
// SMEM: sQ (128×HD BF16 row-major) + sK (128×HD BF16 row-major) + tmem_base
|
||||
// Must be 16-byte aligned for UMMA
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
// Align to 128 bytes for UMMA descriptor
|
||||
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 127) & ~127);
|
||||
bf16_t* sK = sQ + 1024 * 8; // MN_SW128 atom = 1024*8 BF16 = 8192 BF16 per atom
|
||||
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~15);
|
||||
bf16_t* sK = sQ + 128 * HD;
|
||||
|
||||
// Load Q into swizzled SMEM: (1, HD) padded to (128, HD)
|
||||
// Write row 0 with actual Q data, rows 1-127 are zero
|
||||
for (int d = tid; d < HD; d += NTHREADS) {
|
||||
umma_smem_write_mn_sw128(sQ, 0, d, HD, qh[d]);
|
||||
}
|
||||
// Zero rows 1-127 (only need to zero the elements we'll read)
|
||||
for (int row = 1 + tid / HD; row < 128; row += NTHREADS / HD) {
|
||||
for (int d = tid % HD; d < HD; d += HD) {
|
||||
umma_smem_write_mn_sw128(sQ, row, d, HD, 0);
|
||||
}
|
||||
}
|
||||
// Load Q: (1, HD) padded to (128, HD) with zeros
|
||||
for (int i = tid; i < 128 * HD; i += NTHREADS) sQ[i] = 0;
|
||||
for (int d = tid; d < HD; d += NTHREADS) sQ[d] = qh[d];
|
||||
|
||||
// Load K into swizzled SMEM: (min(128, s_k), HD) padded to (128, HD)
|
||||
// Load K: (min(128, s_k), HD) padded to (128, HD)
|
||||
int kv_len = min(128, s_k);
|
||||
for (int r = 0; r < 128; r++) {
|
||||
for (int d = tid; d < HD; d += NTHREADS) {
|
||||
bf16_t val = (r < kv_len) ? kb[r * HD + d] : 0;
|
||||
umma_smem_write_k_sw128(sK, r, d, HD, val);
|
||||
}
|
||||
for (int i = tid; i < 128 * HD; i += NTHREADS) {
|
||||
int r = i / HD, c = i % HD;
|
||||
sK[i] = (r < kv_len) ? kb[r * HD + c] : 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// TMEM alloc for S: 128 columns (128 rows × 128 cols)
|
||||
// TMEM alloc for S: 128 columns
|
||||
if (wid == 0) {
|
||||
uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
|
||||
tmem_alloc(smem_ptr, 128);
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t tmem_base = *sTmemBase;
|
||||
uint32_t tmem_s = tmem_base;
|
||||
|
||||
// Zero TMEM S
|
||||
if (wid == 0) {
|
||||
for (int col = 0; col < 128; col++) {
|
||||
tmem_store(tmem_s + col, 0, 0, 0, 0);
|
||||
tmem_store(tmem_base + col, 0, 0, 0, 0);
|
||||
}
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// QK GEMM: S = Q @ K^T via tcgen05.mma SS
|
||||
// ================================================================
|
||||
// UMMA descriptors
|
||||
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
|
||||
uint32_t sK_smem = __cvta_generic_to_shared(sK);
|
||||
|
||||
uint64_t desc_q = make_umma_desc(sQ_smem, 128, HD, UmmaMajor::MN);
|
||||
uint64_t desc_k = make_umma_desc(sK_smem, 128, HD, UmmaMajor::K);
|
||||
uint64_t desc_q = make_umma_desc_mn_none(sQ_smem, HD);
|
||||
uint64_t desc_k = make_umma_desc_k_none(sK_smem, HD);
|
||||
|
||||
if (tid == 0) {
|
||||
printf("[qk] desc_q=0x%016llx desc_k=0x%016llx\n",
|
||||
(unsigned long long)desc_q, (unsigned long long)desc_k);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// MMA: called by ONE lane (elect_one_sync pattern)
|
||||
// QK GEMM: S = Q @ K^T (SS, both SMEM → TMEM)
|
||||
if (wid == 0 && lane == 0) {
|
||||
umma_ss_f16(tmem_s, desc_q, desc_k, /*accumulate=*/false);
|
||||
umma_ss_f16(tmem_base, desc_q, desc_k, false);
|
||||
}
|
||||
__syncwarp();
|
||||
if (wid == 0 && lane == 0) {
|
||||
tmem_fence_store();
|
||||
}
|
||||
if (wid == 0 && lane == 0) tmem_fence_store();
|
||||
__syncthreads();
|
||||
|
||||
// Read S from TMEM and output row 0
|
||||
// Read S row 0 from TMEM
|
||||
if (wid == 0) {
|
||||
for (int col = lane; col < 128; col += WARP) {
|
||||
uint32_t u0, u1, u2, u3;
|
||||
tmem_load(tmem_s + col, u0, u1, u2, u3);
|
||||
// Lane i's u0 = row i*4+0, etc.
|
||||
if (lane < 32) {
|
||||
// Write row 0 (lane 0's u0)
|
||||
if (lane * 4 + 0 < kv_len) s_out[lane * 4 + 0] = u32_to_f32(u0) * scale;
|
||||
if (lane * 4 + 1 < kv_len) s_out[lane * 4 + 1] = u32_to_f32(u1) * scale;
|
||||
if (lane * 4 + 2 < kv_len) s_out[lane * 4 + 2] = u32_to_f32(u2) * scale;
|
||||
if (lane * 4 + 3 < kv_len) s_out[lane * 4 + 3] = u32_to_f32(u3) * scale;
|
||||
}
|
||||
tmem_load(tmem_base + col, u0, u1, u2, u3);
|
||||
if (lane * 4 + 0 < kv_len) s_out[lane * 4 + 0] = u32_to_f32(u0) * scale;
|
||||
if (lane * 4 + 1 < kv_len) s_out[lane * 4 + 1] = u32_to_f32(u1) * scale;
|
||||
if (lane * 4 + 2 < kv_len) s_out[lane * 4 + 2] = u32_to_f32(u2) * scale;
|
||||
if (lane * 4 + 3 < kv_len) s_out[lane * 4 + 3] = u32_to_f32(u3) * scale;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Dealloc TMEM
|
||||
if (wid == 0) {
|
||||
tmem_dealloc(tmem_base, 128);
|
||||
}
|
||||
if (wid == 0) tmem_dealloc(tmem_base, 128);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,57 +1,40 @@
|
||||
/**
|
||||
* DSV4 FMHA — Canonical UMMA SMEM layout for tcgen05.mma.
|
||||
* DSV4 FMHA — UMMA SMEM layout and descriptor construction.
|
||||
*
|
||||
* ==================================================================
|
||||
* CANONICAL SMEM LAYOUT FOR UMMA ON BLACKWELL SM100
|
||||
* CANONICAL UMMA LAYOUTS (from cute/atom/mma_traits_sm100.hpp)
|
||||
* ==================================================================
|
||||
*
|
||||
* The tcgen05.mma instruction requires matrix operands in SMEM to be
|
||||
* laid out in a specific canonical format with swizzle. The UMMA
|
||||
* descriptor encodes the strides in this canonical layout.
|
||||
* MN-major NONE: ((1, n), (8, k/8)) strides ((X, SBO), (1, LBO))
|
||||
* MN-major SW128: ((8, n), (8, k/8)) strides ((1, SBO), (8, LBO))
|
||||
* K-major NONE: ((8, n), (2, k/2)) strides ((X, SBO), (1, LBO))
|
||||
* K-major SW128: ((8, n), (2, k/2)) strides ((8, SBO), (1, LBO))
|
||||
*
|
||||
* For BF16 operands with SWIZZLE_128B (the default for FMHA):
|
||||
*
|
||||
* MN-major A (Q): Swizzle<3,4,3> applied to (128, n_uint128) layout
|
||||
* K-major B (K): Swizzle<3,4,3> applied to (k_uint128, n_uint128) layout
|
||||
*
|
||||
* The swizzle pattern XORs address bits to avoid bank conflicts.
|
||||
* The hardware expects data in the SWIZZLED order, not row-major.
|
||||
* For SWIZZLE_NONE, the data is in row-major order (no permutation).
|
||||
* The descriptor strides come from logical_divide of the SMEM layout
|
||||
* with Tile(SwizzleAtomMNSize, SwizzleAtomKSize).
|
||||
*
|
||||
* ==================================================================
|
||||
* SWIZZLE_128B PATTERN FOR BF16
|
||||
* DESCRIPTOR FOR (128, 64) BF16 MN-major NONE
|
||||
* ==================================================================
|
||||
*
|
||||
* Swizzle<3,4,3> XORs bits [6:4] of the uint128_t offset with bits [9:7].
|
||||
* This permutes 128-bit "rows" within a 1024-bit (8 × 128-bit) block.
|
||||
*
|
||||
* For a contiguous row of n_uint128 128-bit values:
|
||||
* row_offset = row * row_stride_uint128
|
||||
* swizzled_offset = row_offset ^ ((row_offset & 0x70) >> 3)
|
||||
*
|
||||
* But actually, the swizzle in CUTLASS is applied differently:
|
||||
* base = offset & ~0x3F (clear bottom 6 bits of 128B = 8 uint128_t)
|
||||
* swizzle_bits = (offset >> 4) ^ (offset >> 7) (3,4,3 pattern)
|
||||
* swizzled = base | (swizzle_bits & 0x7)
|
||||
*
|
||||
* This is getting complex. Let me just implement the swizzle as a
|
||||
* device function and build the SMEM write/read functions.
|
||||
* SMEM: (128, 64) row-major BF16, stride (1, 128) in BF16 elements
|
||||
* uint128_t: (16, 8) stride (1, 16)
|
||||
* logical_divide Tile(1, 8): ((1,16), (8,1)) strides ((1,1), (16,128))
|
||||
* → LBO = 16, SBO = 128
|
||||
*
|
||||
* ==================================================================
|
||||
* UMMA DESCRIPTOR FOR SWIZZLE_128B BF16
|
||||
* DESCRIPTOR FOR (128, 64) BF16 K-major NONE
|
||||
* ==================================================================
|
||||
*
|
||||
* For MN-major A with SWIZZLE_128B:
|
||||
* layout_type = 2 (SWIZZLE_128B)
|
||||
* start_address = smem_ptr >> 4
|
||||
* leading_byte_offset = (row_stride_bytes) >> 4 (in uint128_t units)
|
||||
* stride_byte_offset = same as leading for simple layout
|
||||
* version = 1
|
||||
*
|
||||
* The descriptor's stride fields are in uint128_t units (16B granularity),
|
||||
* describing the canonical layout strides AFTER swizzle is applied.
|
||||
*
|
||||
* For K-major B, the layout is the same but the K-dimension is the
|
||||
* "leading" dimension and the MN-dimension is the "stride" dimension.
|
||||
* SMEM: same data, K is the "leading" dim for K-major descriptor
|
||||
* For K-major, stride_00 must be 8 (SwizzleAtomMNSize).
|
||||
* Row-major stride (1, 128) gives stride_00=1 in uint128_t — WRONG.
|
||||
* K-major requires stride-8 grouping in the MN dimension.
|
||||
* The data must be (HD, 128) BF16 with stride (128, 1) — K-contiguous.
|
||||
* uint128_t: (8, 16) stride (16, 1) — K=8 uint128_t contiguous
|
||||
* logical_divide Tile(8, 2): ((8,1), (2,8)) strides ((1,8), (16,128))
|
||||
* But stride_00=1, not 8. Still wrong for SW128.
|
||||
* For NONE, stride_00 can be anything — the assertion is relaxed.
|
||||
* → LBO = stride_01 = 16, SBO = stride_11 = 128
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@@ -59,280 +42,44 @@
|
||||
|
||||
namespace dsv4::kernels::attention {
|
||||
|
||||
// ==================================================================
|
||||
// Swizzle<3,4,3> for 128B-swizzled SMEM layout
|
||||
// ==================================================================
|
||||
// The swizzle XORs bits to permute 128-bit rows within 1024-bit blocks.
|
||||
// This avoids bank conflicts when the MMA engine reads SMEM.
|
||||
//
|
||||
// Swizzle<3,4,3> means: XOR 3 bits starting at bit 4 with 3 bits
|
||||
// starting at bit 7. The result is a permutation of 8 consecutive
|
||||
// 128-bit rows within each 1024-bit (128-byte = 8 × 16-byte) block.
|
||||
//
|
||||
// Input: 128-bit row index within a 128B block (0-7)
|
||||
// Output: permuted 128-bit row index within the block
|
||||
// ==================================================================
|
||||
|
||||
__device__ __forceinline__ int swizzle_128b(int offset_128b) {
|
||||
// Swizzle<3,4,3>: XOR bits [6:4] with bits [9:7]
|
||||
// offset_128b is the uint128_t index
|
||||
// Within a 128B block (8 uint128_t), the swizzle permutes the 8 rows
|
||||
int block_base = offset_128b & ~7; // Clear bottom 3 bits (8-aligned)
|
||||
int row_in_block = offset_128b & 7; // 0-7 within the block
|
||||
|
||||
// Swizzle<3,4,3> on the row index:
|
||||
// XOR the 3 bits of row_in_block with the 3 bits of the block index
|
||||
// But wait — the swizzle is on the BYTE address, not the row index.
|
||||
// Let me be more precise.
|
||||
//
|
||||
// The Swizzle<3,4,3> operates on the 128-bit vector index.
|
||||
// For a contiguous layout, vector index = row * stride + col.
|
||||
// The swizzle XORs bits [6:4] (3 bits at position 4) with bits [9:7] (3 bits at position 7).
|
||||
// Since each vector is 16 bytes, bit 4 corresponds to 16*2^4 = 256 bytes offset,
|
||||
// and bit 7 corresponds to 16*2^7 = 2048 bytes offset.
|
||||
//
|
||||
// Actually, the swizzle is applied to the LAYOUT, not the address.
|
||||
// The CuTe Swizzle<3,4,3> with a contiguous layout means:
|
||||
// For each element at offset i:
|
||||
// swizzled_i = i ^ ((i >> 3) & 0x7) << 4)
|
||||
//
|
||||
// Hmm, this isn't right either. Let me look at the CuTe implementation.
|
||||
|
||||
// Simple approach: the swizzle permutes the 8 rows within each
|
||||
// 128B block. The permutation is:
|
||||
// row 0 → row 0
|
||||
// row 1 → row 1
|
||||
// row 2 → row 2
|
||||
// row 3 → row 3
|
||||
// row 4 → row 4
|
||||
// row 5 → row 5
|
||||
// row 6 → row 6
|
||||
// row 7 → row 7
|
||||
// Wait, Swizzle<3,4,3> with a stride-1 layout on 8 elements:
|
||||
// The XOR is: (i & 0x70) ^ ((i & 0x0E) << 3) ... no.
|
||||
//
|
||||
// Let me just compute it from the CuTe definition.
|
||||
// Swizzle<3,4,3> means: baz=3, b4=4, b3=3
|
||||
// The swizzle XORs (base >> b3) with (base >> (b3+b4)), taking baz bits.
|
||||
// So: ((offset >> 3) ^ (offset >> 7)) & 0x7
|
||||
// This is applied to the uint128_t offset within the full layout.
|
||||
|
||||
int swizzled_row = row_in_block ^ (((offset_128b >> 3) ^ (offset_128b >> 7)) & 0x7);
|
||||
// Actually, the swizzle is on the ELEMENT offset, not the row-in-block.
|
||||
// Let me redo this properly.
|
||||
|
||||
// For Swizzle<3,4,3> applied to offset i (in uint128_t units):
|
||||
// swizzled = i ^ (((i >> 3) ^ (i >> 7)) & 0x7) << 4)
|
||||
// Wait, that shifts by 4 bits, but we're in uint128_t units.
|
||||
|
||||
// The CuTe Swizzle<3,4,3> definition:
|
||||
// template <int baz, int b4, int b3>
|
||||
// struct Swizzle {
|
||||
// CUTE_HOST_DEVICE constexpr
|
||||
// uint64_t operator()(uint64_t offset) const {
|
||||
// return offset ^ ((offset >> b3) ^ (offset >> (b3+b4))) & ((1 << baz) - 1)) << b3;
|
||||
// }
|
||||
// };
|
||||
// For baz=3, b4=4, b3=3:
|
||||
// swizzled = offset ^ (((offset >> 3) ^ (offset >> 7)) & 0x7) << 3
|
||||
|
||||
// This is in ELEMENT units (BF16), not uint128_t units.
|
||||
// But we need to apply it to the uint128_t offset.
|
||||
// The atom layout for BF16 MN_SW128 has shape (1024, 8) with stride (1, 1024).
|
||||
// In uint128_t: (128, 8) with stride (1, 128).
|
||||
// The swizzle is applied to the 1D offset = row + col * 128 (in uint128_t units).
|
||||
|
||||
// Actually, I realize the swizzle operates on the 1D address (in the element
|
||||
// space), not on the 2D coordinates. The CuTe layout maps (row, col) to a 1D
|
||||
// offset, then the swizzle permutes that offset.
|
||||
|
||||
// For MN_SW128 BF16:
|
||||
// Layout: (1024, 8) with stride (1, 1024)
|
||||
// 1D offset = m + n * 1024 (in BF16 elements)
|
||||
// Swizzle: offset ^ (((offset >> 3) ^ (offset >> 7)) & 0x7) << 3
|
||||
|
||||
// To write BF16 element at (row, col) in the swizzled SMEM:
|
||||
// 1. Compute 1D offset = row + col * 1024 (BF16 element index)
|
||||
// 2. Apply swizzle: swizzled = offset ^ (((offset >> 3) ^ (offset >> 7)) & 0x7) << 3
|
||||
// 3. SMEM byte address = smem_base + swizzled * 2 (2 bytes per BF16)
|
||||
|
||||
// To convert to uint128_t units:
|
||||
// uint128_offset = swizzled / 8
|
||||
// But the swizzle operates on BF16 elements, not uint128_t.
|
||||
|
||||
// This is the key: the swizzle is on the ELEMENT offset, not the vector offset.
|
||||
// So I need to compute the element-level swizzled offset, then convert to bytes.
|
||||
|
||||
// For our (128, HD) matrix with MN_SW128:
|
||||
// HD=64 BF16 → n = 8 uint128_t per row
|
||||
// But the layout atom has n = 8 elements in the N dimension (in uint128_t: 1)
|
||||
// Wait, the atom shape is (1024, 8) in BF16 = (128, 1) in uint128_t.
|
||||
// But HD=64 means the K dimension is 64 BF16 = 4 uint128_t.
|
||||
|
||||
// I think the confusion is that the atom shape (1024, 8) means 1024 rows × 8 columns
|
||||
// in BF16. For our matrix (128, 64), we tile the atom:
|
||||
// M: 128 rows, atom covers 1024 → 1 tile (atom is bigger than matrix)
|
||||
// N: 64 columns, atom covers 8 → 8 tiles
|
||||
// So the full layout is: tile_to_shape(atom, (128, 64))
|
||||
// = repeat the 8-column atom 8 times along the N dimension
|
||||
|
||||
// OK let me just compute this numerically for a few elements to verify.
|
||||
|
||||
return block_base + (row_in_block & 7); // placeholder
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// UMMA SMEM write: write a BF16 element to swizzled SMEM
|
||||
// ==================================================================
|
||||
// Given a (row, col) position in the logical matrix, compute the
|
||||
// swizzled SMEM address and write the BF16 value.
|
||||
// ==================================================================
|
||||
|
||||
__device__ __forceinline__ void umma_smem_write_mn_sw128(
|
||||
bf16_t* smem_base, int row, int col, int hd, bf16_t val
|
||||
) {
|
||||
// MN_SW128 BF16 layout: atom shape (1024, 8), stride (1, 1024)
|
||||
// For (128, HD) matrix:
|
||||
// 1D element offset = row + col * 1024 (strided by 1024 in the N dimension)
|
||||
// Wait, that's wrong. The N-dim stride is 1024 because the atom has 1024 rows.
|
||||
// But our matrix only has 128 rows. The extra 896 rows are padding.
|
||||
|
||||
// Actually, the full tiled layout for (128, HD) with atom (1024, 8) stride (1, 1024):
|
||||
// logical_offset(m, n) = (m % 1024) + (m / 1024) * 1024 * 8 + (n % 8) * 1024 + (n / 8) * 8
|
||||
// Wait, tile_to_shape repeats the atom pattern.
|
||||
|
||||
// For MN_SW128, the K-dim stride is 1024 BF16 elements per atom column.
|
||||
// With our 128-row matrix (which fits in 1 atom column of 1024 rows):
|
||||
// logical_offset(m, n) = m + (n / 8) * 1024 * 8 + (n % 8) * 1024
|
||||
// = m + (n / 8) * 8192 + (n % 8) * 1024
|
||||
|
||||
// Then swizzle: swizzled_offset = logical_offset ^ (((logical_offset >> 3) ^ (logical_offset >> 7)) & 0x7) << 3
|
||||
|
||||
int logical = row + (col / 8) * 8192 + (col % 8) * 1024;
|
||||
int swizzled = logical ^ ((((logical >> 3) ^ (logical >> 7)) & 0x7) << 3);
|
||||
|
||||
smem_base[swizzled] = val;
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// UMMA SMEM read: read a BF16 element from swizzled SMEM
|
||||
// ==================================================================
|
||||
|
||||
__device__ __forceinline__ bf16_t umma_smem_read_mn_sw128(
|
||||
const bf16_t* smem_base, int row, int col, int hd
|
||||
) {
|
||||
int logical = row + (col / 8) * 8192 + (col % 8) * 1024;
|
||||
int swizzled = logical ^ ((((logical >> 3) ^ (logical >> 7)) & 0x7) << 3);
|
||||
return smem_base[swizzled];
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// K-major SMEM write/read for SW128
|
||||
// ==================================================================
|
||||
// For K-major B (K^T): the matrix is stored with K as the major dimension.
|
||||
// The atom for K_SW128 is the same shape but with K as the leading dim.
|
||||
// Layout: (k_elements, n_atoms) with K-major strides.
|
||||
// ==================================================================
|
||||
|
||||
__device__ __forceinline__ void umma_smem_write_k_sw128(
|
||||
bf16_t* smem_base, int row, int col, int hd, bf16_t val
|
||||
) {
|
||||
// K_SW128 BF16: atom shape (8, 1024), stride (1, 8)
|
||||
// For (128, HD) matrix in K-major:
|
||||
// row = K index (0..127), col = MN index (0..HD-1)
|
||||
// logical_offset = (k % 8) + (k / 8) * 8 * 128 + (mn % 8) * 8 + (mn / 8) * 1024
|
||||
// Hmm, this needs to be worked out properly.
|
||||
|
||||
// For K_SW128 BF16 atom: Shape<(8, 1024)>, Stride<(1, 8)>
|
||||
// In uint128_t: Shape<(1, 128)>, Stride<(1, 1)>
|
||||
// tile_to_shape for (128, HD):
|
||||
// K: 128 BF16 = 16 uint128_t → 16 tiles of the K atom (1 uint128_t each)
|
||||
// MN: HD BF16 = 8 uint128_t → 8 tiles of the MN atom (128 uint128_t each)
|
||||
|
||||
// Actually for K_SW128:
|
||||
// atom: (8, 1024) BF16, stride (1, 8)
|
||||
// For (128, 64) matrix:
|
||||
// logical_offset(k, mn) = (k % 8) + (mn % 1024) * 8 + (k / 8) * 8 * 1024 + (mn / 1024) * 8
|
||||
// = (k % 8) + mn * 8 + (k / 8) * 8192
|
||||
|
||||
// Wait, this doesn't seem right. The K_SW128 atom for BF16 is:
|
||||
// Shape<(8, 1024)> in elements, Stride<(1, 8)>
|
||||
// This means: 8 contiguous elements along K, then stride 8 to the next group
|
||||
// 1024 groups along MN
|
||||
|
||||
// For our (128, 64) matrix in K-major:
|
||||
// k ranges 0..127, mn ranges 0..63
|
||||
// Atom covers 8 K and 1024 MN. Need 16 K-tiles and 1 MN-tile.
|
||||
// logical_offset(k, mn) = (k % 8) + (mn % 1024) * 8 + (k / 8) * 8 * 1024
|
||||
// = (k % 8) + mn * 8 + (k / 8) * 8192
|
||||
|
||||
int logical = (row % 8) + col * 8 + (row / 8) * 8192;
|
||||
int swizzled = logical ^ ((((logical >> 3) ^ (logical >> 7)) & 0x7) << 3);
|
||||
smem_base[swizzled] = val;
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// UMMA SMEM descriptor construction (correct format)
|
||||
// ==================================================================
|
||||
|
||||
enum class UmmaMajor { MN, K };
|
||||
enum class UmmaLayout { NONE = 0, MN_SW128 = 2, K_SW128 = 2 };
|
||||
|
||||
__device__ __forceinline__ uint64_t make_umma_desc(
|
||||
uint32_t smem_ptr,
|
||||
int dim_m, int dim_n,
|
||||
UmmaMajor major, UmmaLayout layout = UmmaLayout::MN_SW128
|
||||
/**
|
||||
* MN-major SWIZZLE_NONE descriptor for (128, HD) BF16 row-major.
|
||||
*/
|
||||
__device__ __forceinline__ uint64_t make_umma_desc_mn_none(
|
||||
uint32_t smem_ptr, int hd
|
||||
) {
|
||||
uint64_t desc = 0;
|
||||
|
||||
// start_address: bits [0,14) = smem_ptr >> 4 (in 16B units)
|
||||
desc |= (static_cast<uint64_t>(smem_ptr >> 4) & 0x3FFF);
|
||||
|
||||
// For SW128 layout:
|
||||
// The leading_byte_offset and stride_byte_offset describe the canonical
|
||||
// layout in uint128_t units.
|
||||
//
|
||||
// MN_SW128: atom (1024, 8) stride (1, 1024) in BF16 elements
|
||||
// = (128, 1) stride (1, 128) in uint128_t units
|
||||
// leading_byte_offset (K-dim stride) = 1024 BF16 * 2 / 16 = 128 uint128_t
|
||||
// stride_byte_offset (inter-tile MN stride) = same for 1 MN tile
|
||||
//
|
||||
// K_SW128: atom (8, 1024) stride (1, 8) in BF16 elements
|
||||
// = (1, 128) stride (1, 1) in uint128_t units
|
||||
// leading_byte_offset (MN-dim stride) = 8 BF16 * 2 / 16 = 1 uint128_t
|
||||
// stride_byte_offset (inter-tile K stride) = 8192 BF16 * 2 / 16 = 1024 uint128_t
|
||||
|
||||
if (major == UmmaMajor::MN) {
|
||||
// MN_SW128:
|
||||
// leading_byte_offset = 128 (uint128_t units) = 128 * 16 = 2048 bytes
|
||||
// stride_byte_offset = 128 (uint128_t units) — same for simple case
|
||||
int lbo = 128; // 1024 BF16 * 2 bytes / 16 = 128 uint128_t
|
||||
int sbo = 128;
|
||||
desc |= (static_cast<uint64_t>(lbo) & 0x3FFF) << 16;
|
||||
desc |= (static_cast<uint64_t>(sbo) & 0x3FFF) << 32;
|
||||
} else {
|
||||
// K_SW128:
|
||||
// leading_byte_offset = 1 (uint128_t units) = 8 BF16 per MN group
|
||||
// stride_byte_offset = 1024 (uint128_t units)
|
||||
int lbo = 1;
|
||||
int sbo = 1024;
|
||||
desc |= (static_cast<uint64_t>(lbo) & 0x3FFF) << 16;
|
||||
desc |= (static_cast<uint64_t>(sbo) & 0x3FFF) << 32;
|
||||
}
|
||||
|
||||
// version: bits [46,48) = 1
|
||||
desc |= (static_cast<uint64_t>(1) << 46);
|
||||
|
||||
// layout_type: bits [61,64) = 2 (SWIZZLE_128B)
|
||||
desc |= (static_cast<uint64_t>(2) << 61);
|
||||
|
||||
// LBO = 16 (row stride in uint128_t for 128 rows)
|
||||
desc |= (static_cast<uint64_t>(16) & 0x3FFF) << 16;
|
||||
// SBO = 128 (8-row group stride in uint128_t)
|
||||
desc |= (static_cast<uint64_t>(128) & 0x3FFF) << 32;
|
||||
desc |= (static_cast<uint64_t>(1) << 46); // version
|
||||
// layout_type = 0 (NONE)
|
||||
return desc;
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// tcgen05.mma PTX wrappers
|
||||
// ==================================================================
|
||||
/**
|
||||
* K-major SWIZZLE_NONE descriptor for (128, HD) BF16 row-major.
|
||||
* The same SMEM data is used, but the descriptor tells the MMA to
|
||||
* interpret it as K-major (transposed view).
|
||||
* For K-major with stride (1, 128) BF16:
|
||||
* uint128_t: (16, 8) stride (1, 16)
|
||||
* logical_divide Tile(8, 2): ((8,2), (2,4)) strides ((1,8), (16,32))
|
||||
* LBO = stride_01 = 16, SBO = stride_11 = 32
|
||||
*/
|
||||
__device__ __forceinline__ uint64_t make_umma_desc_k_none(
|
||||
uint32_t smem_ptr, int hd
|
||||
) {
|
||||
uint64_t desc = 0;
|
||||
desc |= (static_cast<uint64_t>(smem_ptr >> 4) & 0x3FFF);
|
||||
desc |= (static_cast<uint64_t>(16) & 0x3FFF) << 16; // LBO
|
||||
desc |= (static_cast<uint64_t>(32) & 0x3FFF) << 32; // SBO
|
||||
desc |= (static_cast<uint64_t>(1) << 46); // version
|
||||
return desc;
|
||||
}
|
||||
|
||||
// tcgen05.mma SS: S←S (both SMEM → TMEM). One lane per warp.
|
||||
__device__ void umma_ss_f16(
|
||||
uint32_t tmem_c, uint64_t desc_a, uint64_t desc_b,
|
||||
bool accumulate = false
|
||||
@@ -340,21 +87,19 @@ __device__ void umma_ss_f16(
|
||||
uint32_t idescE = static_cast<uint32_t>(desc_a >> 32);
|
||||
uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u;
|
||||
uint32_t mask0 = 0, mask1 = 0, mask2 = 0, mask3 = 0;
|
||||
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p;\n\t"
|
||||
"}"
|
||||
:: "r"(tmem_c),
|
||||
"l"(desc_a), "l"(desc_b),
|
||||
"r"(idescE),
|
||||
"r"(scaleC_bits),
|
||||
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b),
|
||||
"r"(idescE), "r"(scaleC_bits),
|
||||
"r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)
|
||||
);
|
||||
}
|
||||
|
||||
// tcgen05.mma TS: T←S (TMEM A + SMEM B → TMEM C). One lane per warp.
|
||||
__device__ void umma_ts_f16(
|
||||
uint32_t tmem_c, uint32_t tmem_a, uint64_t desc_b,
|
||||
bool accumulate = true
|
||||
@@ -362,17 +107,14 @@ __device__ void umma_ts_f16(
|
||||
uint32_t idescE = static_cast<uint32_t>(desc_b >> 32);
|
||||
uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u;
|
||||
uint32_t mask0 = 0, mask1 = 0, mask2 = 0, mask3 = 0;
|
||||
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p;\n\t"
|
||||
:: "r"(tmem_c),
|
||||
"r"(tmem_a),
|
||||
"l"(desc_b),
|
||||
"r"(idescE),
|
||||
"r"(scaleC_bits),
|
||||
"}"
|
||||
:: "r"(tmem_c), "r"(tmem_a), "l"(desc_b),
|
||||
"r"(idescE), "r"(scaleC_bits),
|
||||
"r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3)
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user