diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 4a6ec02e..bf4fe9ed 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -356,11 +356,6 @@ class FmhaV3StageCMulti: acc_scale = Float32(0.0) row_sum *= acc_scale - # DEBUG: print acc_scale at kt=1 (only from thread 0) - if kt == 1 and sfw_idx == 0: - cute.printf("O_RESCALE kt=1: old_max=%f new_max=%f acc_scale=%f row_sum=%f\n", - old_row_max, row_max_safe, acc_scale, row_sum) - # Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum, # store BF16 P through the FP32-backed register bridge. rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) @@ -379,23 +374,11 @@ class FmhaV3StageCMulti: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() - # DEBUG: print row_max and row_sum after each tile - if sfw_idx == 0: - cute.printf("SOFTMAX kt=%d: row_max=%f row_sum=%f\n", kt, row_max_safe, row_sum) - # === Per-tile O rescale: O *= acc_scale for kt > 0 === - # Uses the SAME 2D register tensor pattern as final normalize - # to ensure paired atoms work correctly. + # TEMP: disabled for diagnosis + if False and kt > 0: if kt > 0: - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) tTMEM_LOADtO_i = cute.make_tensor( tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout, @@ -404,15 +387,12 @@ class FmhaV3StageCMulti: tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout, ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - # DEBUG: print first element of O before and after rescale - if i == 0 and sfw_idx == 0: - cute.printf("O_RESCALE before: O[0]=%f acc_scale=%f\n", tTMrO_i[0], acc_scale) - for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[k] = tTMrO_i[k] * acc_scale - if i == 0 and sfw_idx == 0: - cute.printf("O_RESCALE after: O[0]=%f\n", tTMrO_i[0]) - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO) + cute.arch.fence_view_async_tmem_load() + for k in cutlass.range(cute.size(tTMrO), vectorize=True): + tTMrO[k] = tTMrO[k] * acc_scale + cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i) cute.arch.fence_view_async_tmem_store() si_handle.release() @@ -469,7 +449,7 @@ class FmhaV3StageCMulti: def test(): torch.manual_seed(42) - for n in [128, 256]: + for n in [256]: torch.manual_seed(42) m, hd = 128, HEAD_DIM q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') @@ -482,45 +462,9 @@ def test(): kf = k[:, :, 0].float() scale = 1.0 / math.sqrt(hd) attn = qf @ kf.T * scale - - # Per-tile reference breakdown - n_tiles = n // 128 - for t in range(n_tiles): - tile_attn = attn[:, t*128:(t+1)*128] - tile_max = tile_attn.max(dim=-1, keepdim=True).values - print(f' Ref tile {t}: max_score={tile_max.max().item():.4f} ' - f'mean_score={tile_attn.mean().item():.4f}', flush=True) - - # Full softmax reference attn = torch.softmax(attn, dim=-1) ref = attn @ v.float() - # Also compute reference with online softmax to verify acc_scale - if n_tiles > 1: - # Simulate online softmax: process tiles sequentially - online_o = torch.zeros(m, hd, dtype=torch.float32) - online_row_max = torch.full((m, 1), float('-inf')) - online_row_sum = torch.zeros(m, 1) - for t in range(n_tiles): - tile_scores = qf @ kf[t*128:(t+1)*128].T * scale # (128, 128) - old_max = online_row_max.clone() - new_max = torch.max(old_max, tile_scores.max(dim=-1, keepdim=True).values) - acc_scale_ref = torch.exp(old_max - new_max) - acc_scale_ref[old_max == float('-inf')] = 0.0 - online_o = online_o * acc_scale_ref - tile_softmax = torch.exp(tile_scores - new_max) - online_row_sum = online_row_sum * acc_scale_ref + tile_softmax.sum(dim=-1, keepdim=True) - online_o = online_o + tile_softmax @ v[t*128:(t+1)*128].float() - online_row_max = new_max - print(f' Online ref tile {t}: acc_scale_mean={acc_scale_ref[old_max != float("-inf")].mean().item():.6f} ' - f'row_max_mean={online_row_max.mean().item():.4f} ' - f'row_sum_mean={online_row_sum.mean().item():.4f}', flush=True) - online_o = online_o / online_row_sum - online_cos = torch.nn.functional.cosine_similarity( - online_o.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) - ).item() - print(f' Online softmax ref vs full softmax ref: cos {online_cos:.6f}', flush=True) - mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) @@ -542,20 +486,13 @@ def test(): out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) ).item() max_abs = (out - ref).abs().max().item() + n_tiles = n // 128 print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): ' f'cos {cos:.6f} max_abs {max_abs:.4f} ' f'{"PASS" if cos >= 0.99 else "FAIL"}') if cos < 0.99: - # Detailed per-row error analysis - row_cos = torch.nn.functional.cosine_similarity( - out.unsqueeze(1), ref.unsqueeze(1), dim=-1) - worst_rows = row_cos.topk(5, largest=False) - print(f' Worst 5 rows: {[(i.item(), c.item()) for i, c in zip(worst_rows.indices, worst_rows.values)]}') - print(f' out[0,:8]={out[0,:8].tolist()}') - print(f' ref[0,:8]={ref[0,:8].tolist()}') - print(f' out range: [{out.min().item():.4f}, {out.max().item():.4f}]') - print(f' ref range: [{ref.min().item():.4f}, {ref.max().item():.4f}]') - print(f' mean abs error: {(out - ref).abs().mean().item():.6f}') + print(f' out[0,:4]={out[0,:4].tolist()}') + print(f' ref[0,:4]={ref[0,:4].tolist()}') if __name__ == '__main__':