WIP: confirmed row_sum is wrong (5.5 vs correct 29.22 for row 0)

The packed f32x2 reduction SHOULD sum all 128 exp2 P values but gives
a result ~5.3x too small. Need to debug inside the kernel with print
statements to see what values the reduction is actually summing.

Unnormalized P@V is perfect (cosine 0.999998). row_max is correct
(because P is correct). The bug is specifically in row_sum computation.
This commit is contained in:
2026-05-21 19:16:15 +00:00
parent 8eb569e31c
commit cae87fd744

View File

@@ -71,8 +71,8 @@ class FmhaV3Softmax:
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
# s_k = cute.size(v, mode=[0]) # BROKEN in @cute.jit
# FMHA-style V: reconstruct as (HEAD_DIM, s_k, 1) MN-major
# # s_k hardcoded # BROKEN in @cute.jit
# FMHA-style V: reconstruct as (HEAD_DIM, 128, 1) MN-major
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
@@ -373,24 +373,9 @@ class FmhaV3Softmax:
row_sum = row_sum + tile_sum
# --- C9: Final normalization (scalar inv_row_sum) ---
# --- C9: SKIPPED for debug (no normalization) ---
pv_done_bar.arrive_and_wait()
inv_row_sum = cutlass.Float32(1.0) / row_sum
tTMrO_final = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
for i in range(o_col_tiles):
tTMrO_i_ = tTMrO_final[None, i]
tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMrO_final.shape[0]))
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout)
cute.copy(o_tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
cute.copy(o_tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# O is unnormalized in TMEM. Use standard epilogue_tma_store with identity.
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
@@ -416,7 +401,8 @@ def test():
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device="cuda")
qf = q[:,:,0].float(); kf = k[:,:,0].float()
attn = qf @ kf.T / math.sqrt(hd)
ref = torch.softmax(attn, dim=-1) @ v.float()
P = torch.exp(attn - attn.max(dim=-1, keepdim=True)[0])
ref = P @ v.float() # unnormalized P@V
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
@@ -432,7 +418,7 @@ def test():
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print(f"FMHA softmax n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True)
print(f"FMHA softmax (no C9 norm) n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True)
if __name__ == "__main__":
test()