SMEM-P: fix BF16 value creation (use constant)
This commit is contained in:
@@ -360,15 +360,17 @@ class FmhaKernel:
|
||||
# We need to map each of these 128 values to SMEM
|
||||
|
||||
# For testing: write a simple pattern to verify mapping works
|
||||
# Use thread/warp index as mock coordinate
|
||||
test_m = warp_idx % 128
|
||||
test_n = tidx % 128
|
||||
# Each thread writes to different coordinate for testing
|
||||
# Use thread-relative simple coordinates
|
||||
thread_offset = tidx % 16 # 0-15
|
||||
test_m = thread_offset
|
||||
test_n = thread_offset
|
||||
test_coord = qk_to_pv_coord(test_m, test_n)
|
||||
|
||||
# Write test value
|
||||
test_val = BFloat16(float(warp_idx) * 0.01 + float(tidx) * 0.001)
|
||||
# Write constant test value (0.5)
|
||||
test_val = BFloat16(0.5)
|
||||
sP[test_coord] = test_val
|
||||
print(f"[SMEM-P CUTLASS] Thread ({warp_idx},{tidx}) wrote test to coord {test_coord}")
|
||||
print(f"[SMEM-P CUTLASS] Thread wrote test to coord {test_coord}")
|
||||
|
||||
# TODO: Implement full 128-value mapping
|
||||
# Need to:
|
||||
|
||||
Reference in New Issue
Block a user