diag: set inv_row_sum=1.0 to test raw PV output; run test with n=128 only

This commit is contained in:
2026-05-23 01:17:14 +00:00
parent 2b93b10199
commit 690fd77e6c

View File

@@ -400,7 +400,8 @@ class FmhaV3StageCMulti:
final_o_bar.arrive_and_wait()
# === Final O normalization: O *= 1/row_sum ===
inv_row_sum = Float32(1.0) / row_sum
# DIAG: use 1.0 instead of 1/row_sum to test raw PV output
inv_row_sum = Float32(1.0)
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
@@ -447,7 +448,7 @@ class FmhaV3StageCMulti:
def test():
torch.manual_seed(42)
for n in [128, 256]:
for n in [128]:
torch.manual_seed(42)
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')