Files
nvfp4-megamoe-kernel/tests/unit/test_pv_ss_b64.cu

89 lines
3.1 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
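/**
* Minimal test: PV SS MMA where the B operand is a (64,16) tile whose UMMA
* descriptor is built with BLOCK_MN=64 (passed as HD), while the A operand is a
* (128,16) tile built with the file-level BLOCK_MN=128.
* Tests whether BLOCK_MN=64 is valid for the UMMA B descriptor.
*/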
#include <cuda_runtime.h>
#include <cstdio>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// Host-side float -> bf16 conversion by truncation: reinterprets the float's
// bit pattern and keeps only the upper 16 bits (no rounding is applied).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    const uint32_t upper_half = bits >> 16;
    return (uint16_t)upper_half;
}
// HD: passed as the BLOCK_MN argument of the B-operand descriptor and as the
// second argument of make_idesc in the kernel below.
constexpr int HD = 64;
// BLOCK_MN: passed as the BLOCK_MN argument of the A-operand descriptor and as
// the first argument of make_idesc in the kernel below.
constexpr int BLOCK_MN = 128;
/**
 * Single-CTA test kernel: issues one shared-memory x shared-memory UMMA
 * (A = P tile built with BLOCK_MN=128, B = V tile built with BLOCK_MN=64) and
 * prints the first 8 accumulator columns of row 0.
 *
 * Launch: 1 block of 128 threads, dynamic shared memory as sized by main().
 * Requires a GPU/toolchain that supports the tcgen05 PTX instructions.
 *
 * Inputs are synthesized in shared memory:
 *   - P: all zero except row 0, whose 16 K-elements are 1.0
 *   - V: all 64x16 elements are 2.0
 * so every C[0, j] is expected to be 16 * 1.0 * 2.0 = 32.0 (this assumes the
 * final `false` argument of umma_ss_f16 disables accumulation -- confirm
 * against the helper).
 */
__global__ void __launch_bounds__(128)
test_pv_ss_b64()
{
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;
    extern __shared__ char sbuf[];

    // Shared-memory carve-up: [tmem base slot][mbarrier][P tile][V tile].
    // The mbarrier is placed in the padding between the TMEM slot and the
    // 16-byte-aligned P tile; with a 16-byte-aligned sbuf the P/V addresses are
    // the same as without it.
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    uint64_t* sMmaBar   = (uint64_t*)(((uintptr_t)(sbuf + 4) + 7) & ~(uintptr_t)7);
    bf16_t* sP = (bf16_t*)(((uintptr_t)(sMmaBar + 1) + 15) & ~(uintptr_t)15);
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sP + 128 * 16) + 127) & ~(uintptr_t)127);
    const uint32_t bar_addr = (uint32_t)__cvta_generic_to_shared(sMmaBar);

    // Number of 8x8 cores along the MN dimension for each operand.
    constexpr int P_CORES_MN = 16;  // 128 rows / 8
    constexpr int V_CORES_MN = 8;   //  64 rows / 8

    // Zero both tiles first so only the explicitly written elements are non-zero.
    for (int i = tid; i < 128 * 16; i += 128) sP[i] = 0;
    for (int i = tid; i < 64 * 16; i += 128) sV[i] = 0;
    __syncthreads();

    // P: (128, 16) core-matrix layout, row 0 = all 1.0.
    for (int c = tid; c < 16; c += 128) {
        int ck = c / 8, lc = c % 8;
        sP[ck * P_CORES_MN * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(1.0f);
    }

    // V: (64, 16) core-matrix layout with BLOCK_MN=64, every element = 2.0.
    // B[d, r]: g_mn=d/8, g_k=r/8, llr=d%8, lc=r%8
    for (int d = tid; d < HD; d += 128) {
        for (int r = 0; r < 16; r++) {
            int g_mn = d / 8, g_k = r / 8;
            int llr = d % 8, lc = r % 8;
            sV[g_k * V_CORES_MN * 64 + g_mn * 64 + llr * 8 + lc] = f32_to_bf16(2.0f);
        }
    }

    // One-shot mbarrier (single arrival) used to observe MMA completion.
    if (tid == 0) {
        asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;"
                     :: "r"(bar_addr), "r"(1u) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }

    // The operands were written with ordinary (generic-proxy) stores but the
    // MMA reads shared memory through the async proxy; this fence makes the
    // writes visible to it.
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();

    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // PV SS MMA: A=(128,16) BLOCK_MN=128, B=(64,16) BLOCK_MN=64
    // C = A x B^T = (128, 64).
    uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sP), BLOCK_MN);
    uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), HD);
    uint32_t idesc = make_idesc(BLOCK_MN, HD);
    if (tid == 0) {
        umma_ss_f16(tb, dp, dv, idesc, false);
        // tcgen05.mma is asynchronous: have its completion arrive on the
        // mbarrier. NOTE(review): assumes umma_ss_f16 issues a cta_group::1
        // MMA and does not already commit internally -- confirm.
        asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                     :: "r"(bar_addr) : "memory");
    }

    // Wait for phase 0 of the mbarrier, i.e. until the MMA has completed.
    const uint32_t phase = 0;
    uint32_t done = 0;
    while (!done) {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
            "selp.b32 %0, 1, 0, p;\n\t"
            "}"
            : "=r"(done)
            : "r"(bar_addr), "r"(phase)
            : "memory");
    }
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    // Read first 8 cols of C
    if (wid == 0) {
        float tmp[8];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                     : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                       "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                     : "r"(tb));
        asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
        if (lane == 0) {
            bool ok = true;
            printf("C[0,0..7]: ");
            for (int c = 0; c < 8; c++) {
                printf("%.1f ", tmp[c]);
                ok = ok && (tmp[c] == 32.0f);
            }
            printf("\n");
            printf("%s (expected 32.0 per element)\n", ok ? "PASS" : "FAIL");
        }
    }
    if (wid == 0) tmem_dealloc(tb, 128);
}
/**
 * Host driver: launches the test kernel on one block of 128 threads and
 * reports any CUDA error.
 *
 * Returns 0 when the launch and execution both succeed, 1 otherwise.
 */
int main() {
    printf("=== PV SS MMA B=(64,16) BLOCK_MN=64 ===\n");
    // Dynamic shared-memory budget, rounded up to a multiple of 128 bytes.
    // The two large terms are the P tile (128*16 two-byte elements) and the
    // V tile (64*16 two-byte elements); the remaining terms presumably cover
    // the TMEM base slot and alignment padding used by the kernel.
    int smem = (4+16 + 128*16*2 + 64*16*2 + 256 + 127) & ~127;
    printf("SMEM: %d bytes\n", smem);

    test_pv_ss_b64<<<1, 128, smem>>>();

    // Launch-configuration failures are reported through cudaGetLastError();
    // without this check a rejected launch could go unnoticed and the test
    // would exit 0 without the kernel ever running.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) { printf("CUDA LAUNCH ERROR: %s\n", cudaGetErrorString(err)); return 1; }

    // Faults raised while the kernel executes surface at the next sync.
    err = cudaDeviceSynchronize();
    if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
    return 0;
}