test: remove sanity check (zeroing loop overwrites), fix verify offsets
This commit is contained in:
@@ -46,7 +46,6 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k,
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// Load Q and K into SMEM in canonical layout
|
||||
// Zero all first
|
||||
for (int i = tid; i < 128 * hd; i += N_WARPS * 32) {
|
||||
sQ[i] = 0;
|
||||
@@ -54,19 +53,6 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k,
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Sanity check: write to sQ[0]
|
||||
if (tid == 0) {
|
||||
uint16_t one_bf16 = f32_to_bf16(1.0f);
|
||||
sQ[0] = one_bf16;
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid == 0) {
|
||||
uint16_t val = sQ[0];
|
||||
float fval = bf16_to_f32(val);
|
||||
s_out[250] = fval; // Should be 1.0
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Write Q (1, hd) to row 0 of sQ in canonical layout
|
||||
for (int d = tid; d < hd; d += N_WARPS * 32) {
|
||||
int core_k = d / 8, local_c = d % 8;
|
||||
@@ -84,13 +70,14 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k,
|
||||
__syncthreads();
|
||||
|
||||
// Verify SMEM data for first K-tile (columns 0-15)
|
||||
// In canonical layout, Q[d] for row 0 is at core_k * 16 * 64 + local_c
|
||||
if (tid == 0) {
|
||||
// Q row 0, d=0..7: core(0,0) at offset 0, local_r=0, local_c=d
|
||||
// Q row 0, d=0..7: core_k=0, local_c=d → sQ[d]
|
||||
for (int d = 0; d < 8; d++)
|
||||
s_out[200+d] = bf16_to_f32(sQ[d]); // core(0,0), row 0, col d
|
||||
// Q row 0, d=16..23: core(0,2) at offset 2*1024 = 2048, local_r=0, local_c=d-16
|
||||
s_out[200+d] = bf16_to_f32(sQ[d]);
|
||||
// Q row 0, d=8..15: core_k=1, local_c=d-8 → sQ[1024 + d-8]
|
||||
for (int d = 0; d < 8; d++)
|
||||
s_out[208+d] = bf16_to_f32(sQ[2048 + d]); // core(0,2), row 0, col 0..7
|
||||
s_out[208+d] = bf16_to_f32(sQ[1024 + d]);
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
|
||||
|
||||
Reference in New Issue
Block a user