SMEM-P: fix BF16 value creation (use constant)

This commit is contained in:
2026-05-23 19:33:29 +00:00
parent 58639aa634
commit e118ad967d

View File

@@ -360,15 +360,17 @@ class FmhaKernel:
# We need to map each of these 128 values to SMEM
# For testing: write a simple pattern to verify mapping works
# Use thread/warp index as mock coordinate
test_m = warp_idx % 128
test_n = tidx % 128
# Each thread writes to different coordinate for testing
# Use thread-relative simple coordinates
thread_offset = tidx % 16 # 0-15
test_m = thread_offset
test_n = thread_offset
test_coord = qk_to_pv_coord(test_m, test_n)
# Write test value
test_val = BFloat16(float(warp_idx) * 0.01 + float(tidx) * 0.001)
# Write constant test value (0.5)
test_val = BFloat16(0.5)
sP[test_coord] = test_val
print(f"[SMEM-P CUTLASS] Thread ({warp_idx},{tidx}) wrote test to coord {test_coord}")
print(f"[SMEM-P CUTLASS] Thread wrote test to coord {test_coord}")
# TODO: Implement full 128-value mapping
# Need to: