diag: inv_row_sum=1.0 to test raw PV, n=128 only
This commit is contained in:
@@ -400,7 +400,8 @@ class FmhaV3StageCMulti:
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# === Final O normalization: O *= 1/row_sum ===
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
# DIAG: use 1.0 instead of 1/row_sum to test raw PV output
|
||||
inv_row_sum = Float32(1.0)
|
||||
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
@@ -447,7 +448,7 @@ class FmhaV3StageCMulti:
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
for n in [128, 256]:
|
||||
for n in [128]:
|
||||
torch.manual_seed(42)
|
||||
m, hd = 128, HEAD_DIM
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
Reference in New Issue
Block a user