SMEM-P: implement CUTLASS LLM fixes - dynamic frg_tile, local coordinate conversion
This commit is contained in:
@@ -315,11 +315,16 @@ class FmhaKernel:
|
||||
old_row_max = row_max
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
# Compute fragment tile size dynamically (must match value division)
|
||||
frg_tile_size = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
frg_layout = cute.make_layout(frg_tile_size)
|
||||
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, frg_layout)
|
||||
# Coordinate fragments for SMEM-P mapping (needed unconditionally for scoping)
|
||||
tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, cute.make_layout(frg_tile))
|
||||
tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, frg_layout)
|
||||
if self.use_smem_p:
|
||||
print(f"[SMEM-P CUTLASS] Created tTMEM_LOADcS_frg shape: {cute.shape(tTMEM_LOADcS_frg)}")
|
||||
print(f"[SMEM-P CUTLASS] frg_tile_size: {frg_tile_size}, frg_layout: {frg_layout}")
|
||||
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
@@ -354,32 +359,19 @@ class FmhaKernel:
|
||||
n = qk_coord[1]
|
||||
|
||||
# Map to PV SMEM coordinate
|
||||
# Try transposed mapping (maybe PV expects P^T?)
|
||||
m0 = m % 16
|
||||
m1 = (m // 16) % 4
|
||||
m2 = m // 64
|
||||
pv_coord = ((n, m0), 0, (m1, m2), 0)
|
||||
# Convert to local coordinates (0-127) as sanity check
|
||||
m_local = m % 128
|
||||
n_local = n % 128
|
||||
|
||||
# Original mapping (likely wrong):
|
||||
# n0 = n % 16
|
||||
# n1 = (n // 16) % 4
|
||||
# n2 = n // 64
|
||||
# pv_coord = ((m, n0), 0, (n1, n2), 0)
|
||||
# Original mapping formula (should be correct for local coords)
|
||||
n0 = n_local % 16
|
||||
n1 = (n_local // 16) % 4
|
||||
n2 = n_local // 64
|
||||
pv_coord = ((m_local, n0), 0, (n1, n2), 0)
|
||||
|
||||
# DEBUG: Write linear index as value: m*128 + n
|
||||
# This uniquely identifies each position
|
||||
linear_idx = m * 128 + n
|
||||
# Convert to Float32 (values 0-16383)
|
||||
pattern_val = Float32(linear_idx)
|
||||
p_val_bf16 = pattern_val.to(self.q_dtype)
|
||||
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
|
||||
|
||||
# Try both tensor indexing AND manual offset for debugging
|
||||
# Write actual P value (not test pattern)
|
||||
p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
|
||||
sP[pv_coord] = p_val_bf16 # Tensor indexing
|
||||
|
||||
# Also compute manual offset to verify
|
||||
# offset = cute.crd2idx(pv_coord, sP.layout)
|
||||
# (sP.iterator + offset) = p_val_bf16
|
||||
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
|
||||
Reference in New Issue
Block a user