diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 6fbb4581..ad6a2828 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -360,15 +360,17 @@ class FmhaKernel: # We need to map each of these 128 values to SMEM # For testing: write a simple pattern to verify mapping works - # Use thread/warp index as mock coordinate - test_m = warp_idx % 128 - test_n = tidx % 128 + # Each thread writes to different coordinate for testing + # Use thread-relative simple coordinates + thread_offset = tidx % 16 # 0-15 + test_m = thread_offset + test_n = thread_offset test_coord = qk_to_pv_coord(test_m, test_n) - # Write test value - test_val = BFloat16(float(warp_idx) * 0.01 + float(tidx) * 0.001) + # Write constant test value (0.5) + test_val = BFloat16(0.5) sP[test_coord] = test_val - print(f"[SMEM-P CUTLASS] Thread ({warp_idx},{tidx}) wrote test to coord {test_coord}") + print(f"[SMEM-P CUTLASS] Thread wrote test to coord {test_coord}") # TODO: Implement full 128-value mapping # Need to: