Files
nvfp4-megamoe-kernel/tests/unit/test_pv_ss.cu

107 lines
4.0 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* Minimal test: PV via SS MMA with just 1 K-tile.
* P(128,16) x V(16,16) -> O(128,16)
* A = P K-tile (128, 16) from SMEM, B = V K-tile (16, 16) from SMEM.
* Both operands use the canonical K-major layout.
* Tests whether SS MMA can handle BLOCK_MN_A=128 together with BLOCK_MN_B=16.
*/
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// Host-side float -> bf16 conversion: keeps the upper 16 bits of the IEEE-754
// float32 bit pattern (plain truncation, no rounding step).
static bf16_t f32_to_bf16_host(float f)
{
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    const uint32_t upper_half = bits >> 16;
    return static_cast<uint16_t>(upper_half);
}
// Host-side bf16 -> float conversion: places the 16 stored bits in the upper
// half of a float32 bit pattern and leaves the lower 16 bits zero.
static float bf16_to_f32_host(bf16_t h)
{
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float value;
    memcpy(&value, &bits, sizeof(value));
    return value;
}
// Tile geometry shared by the kernel below.
constexpr int HD = 16;         // N extent passed to make_idesc; also the number of O columns printed
constexpr int BLOCK_MN = 128;  // M extent passed to make_idesc and to the P (A operand) descriptor
/**
 * Single-CTA smoke test: one K-tile PV product issued through the SS
 * (SMEM x SMEM) MMA path, with the accumulator read back from TMEM and printed.
 *
 * Launch configuration: exactly one block of 128 threads with dynamic shared
 * memory large enough for (in order) a 4-byte TMEM base slot, an 8-byte
 * mbarrier, the 128x16 bf16 P tile (16-byte aligned) and the 16x16 bf16 V tile
 * (128-byte aligned). Worst-case footprint is below 4.8 KiB.
 *
 * Inputs are synthesized in-kernel: P row 0 is all 1.0 (every other row 0),
 * V is all 2.0. With K = 16 the plain dot product for O[0, n] is 32.0; the
 * original author notes 16.0 would appear if the MMA path applies a 0.5 scale.
 *
 * Requires a device that supports the tcgen05 / TMEM instructions used below.
 */
__global__ void __launch_bounds__(128)
test_pv_ss()
{
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // ---- Dynamic SMEM carve-up ------------------------------------------
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    // mbarrier that tracks completion of the asynchronous MMA (8-byte aligned).
    uint64_t* sMmaDone = (uint64_t*)(((uintptr_t)(sbuf + 4) + 7) & ~(uintptr_t)7);
    // A = P (128, 16) canonical: 1.0 across row 0, 0 elsewhere.
    bf16_t* sP = (bf16_t*)(((uintptr_t)(sMmaDone + 1) + 15) & ~(uintptr_t)15);
    // B = V (16, 16) canonical: all 2.0.
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sP + 128 * 16) + 127) & ~(uintptr_t)127);

    const uint32_t mma_done_bar = (uint32_t)__cvta_generic_to_shared(sMmaDone);
    if (tid == 0) {
        // One arrival expected: the tcgen05.commit issued after the MMA.
        asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;"
                     :: "r"(mma_done_bar), "r"(1) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }

    // ---- Fill P: 8x8 core matrices, 64 elements per core ------------------
    constexpr int CORES_MN = 128 / 8;   // 16 core matrices along M for P
    constexpr int V_CORES_MN = 16 / 8;  // 2 core matrices along N for V
    for (int i = tid; i < 128 * 16; i += 128) sP[i] = 0;
    __syncthreads();
    // Row 0 lives in core_mn = 0, local row 0 of each K core.
    for (int c = tid; c < 16; c += 128) {
        int ck = c / 8, lc = c % 8;
        int dst_idx = ck * CORES_MN * 64 + 0 * 64 + 0 * 8 + lc;
        sP[dst_idx] = f32_to_bf16(1.0f);
    }

    // ---- Fill V: (16, 16) canonical, all 2.0 ------------------------------
    for (int i = tid; i < 16 * 16; i += 128) sV[i] = 0;
    __syncthreads();
    for (int i = tid; i < 16 * 16; i += 128) {
        int r = i / 16, c = i % 16;
        int ck = c / 8, lc = c % 8;
        int tmn = r / 8, lr = r % 8;
        sV[ck * V_CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(2.0f);
    }
    // The tiles were written with ordinary (generic-proxy) stores but the MMA
    // reads SMEM through the async proxy, so every writer fences before the
    // barrier that publishes the data.
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();

    // ---- TMEM alloc: 128 columns, base address published through SMEM -----
    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // ---- PV SS MMA: A=(128,16) BLOCK_MN=128, B=(16,16) BLOCK_MN=16 --------
    // C = A x B^T = (128, 16) -> TMEM.
    printf("Before PV: tid=%d\n", tid);
    uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sP), BLOCK_MN);
    uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    uint32_t idesc = make_idesc(BLOCK_MN, HD);  // M = 128, N = 16; field encoding is done by make_idesc
    if (tid == 0) {
        umma_ss_f16(tb, dp, dv, idesc, false);
        // The MMA is asynchronous: have its completion arrive on the mbarrier.
        asm volatile(
            "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
            :: "r"(mma_done_bar) : "memory");
    }
    // Every thread waits for phase 0 of the barrier, i.e. for the MMA result
    // to actually be present in TMEM, before anyone reads it back.
    uint32_t mma_done = 0;
    while (!mma_done) {
        asm volatile(
            "{\n"
            ".reg .pred p;\n"
            "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n"
            "selp.u32 %0, 1, 0, p;\n"
            "}\n"
            : "=r"(mma_done)
            : "r"(mma_done_bar), "r"(0u)
            : "memory");
    }
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    printf("After PV: tid=%d\n", tid);

    // ---- Read O from TMEM: warp 0 loads, lane 0 keeps and prints its row --
    if (wid == 0) {
        float o_vals[HD];
        for (int n = 0; n < HD / 8; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                         : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                           "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                         : "r"(tb + n*8));
            asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
            // Raw accumulator values; no scale correction is applied here.
            if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
        }
        if (lane == 0) {
            printf("O[0,0..15]: ");
            for (int d=0;d<HD;d++) printf("%.2f ", o_vals[d]);
            printf("(raw MMA, expect 16.0 with 0.5 scale or 32.0 with 1.0 scale)\n");
        }
    }
    // NOTE(review): TMEM is allocated by warp 1 but released by warp 0 —
    // confirm tmem_dealloc permits a different warp than the allocator.
    if (wid == 0) tmem_dealloc(tb, 128);
}
/**
 * Host driver: launches test_pv_ss on a single 128-thread block and reports
 * whether the launch and execution completed without a CUDA error.
 *
 * Returns 0 on success, 1 on any CUDA error. "PASS" only means no CUDA error
 * occurred; the numerical result is printed by the kernel and is not checked
 * programmatically here.
 */
int main() {
    printf("=== PV SS MMA Test (A=128x16, B=16x16) ===\n");

    // Dynamic SMEM budget: small header + alignment pad, the 128x16 and 16x16
    // bf16 tiles, 256 bytes of slack, rounded down to a multiple of 128.
    int smem = (4+16 + 128*16*2 + 16*16*2 + 256 + 127) & ~127;

    test_pv_ss<<<1, 128, smem>>>();

    // Launch-time failures (bad configuration, no kernel image for this device)
    // are reported via cudaGetLastError(); without this check the test could
    // print PASS even though the kernel never ran.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) { printf("CUDA ERROR (launch): %s\n", cudaGetErrorString(err)); return 1; }

    // Execution-time faults surface at the next synchronizing call.
    err = cudaDeviceSynchronize();
    if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }

    printf("PASS\n");
    return 0;
}