fix: SMEM calc in decode test
This commit is contained in:
@@ -197,6 +197,8 @@ int main() {
|
||||
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
constexpr int TILE_SZ = 128 * MMA_K_BF16;
|
||||
constexpr int V_SUB_SZ = 16 * MMA_K_BF16;
|
||||
size_t smem_off = 8 + 128*4 + 128*4;
|
||||
smem_off = ((smem_off + 127) & ~(size_t)127);
|
||||
smem_off += TILE_SZ + TILE_SZ;
|
||||
|
||||
Reference in New Issue
Block a user