test: B=2.0 to understand TS MMA scale factor
This commit is contained in:
@@ -38,7 +38,7 @@ test_mma_ts(float* o_out)
|
||||
int r = i / 16, c = i % 16;
|
||||
int ck = c / 8, lc = c % 8;
|
||||
int tmn = r / 8, lr = r % 8;
|
||||
sV[ck * 2 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(1.0f);
|
||||
sV[ck * 2 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(2.0f); // Use 2.0 to distinguish from A=1.0
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user