Test softmax P vs unnormalized reference (no O normalize)
This commit is contained in:
@@ -296,33 +296,9 @@ class FmhaV3RealSoftmax:
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# Final O normalization: O = O / row_sum
|
||||
# Uses the CUTLASS reference's sub-tile approach:
|
||||
# Load O from TMEM in sub-tiles of corr_tile_size columns,
|
||||
# multiply by 1/row_sum, write back.
|
||||
if row_sum != Float32(0.0):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
# Register tensor: (frg, n_corr_tiles) where n_corr = 128/corr_tile_size
|
||||
n_corr = 128 // corr_tile_size
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype
|
||||
)
|
||||
for ci in range(n_corr):
|
||||
tTMrO_ci_ = tTMrO[None, ci]
|
||||
tTMrO_ci_layout = cute.composition(
|
||||
tTMrO_ci_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout)
|
||||
tTMEM_LOAD_OtO_ci = cute.make_tensor(
|
||||
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
|
||||
)
|
||||
tTMEM_STORE_OtO_ci = cute.make_tensor(
|
||||
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci)
|
||||
for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True):
|
||||
tTMrO_ci[j] = tTMrO_ci[j] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci)
|
||||
# TODO: Final O normalization (disabled — corrupts output)
|
||||
# The O sub-tile read-modify-write needs more debugging.
|
||||
# For now, verify softmax P computation is correct with unnormalized output.
|
||||
|
||||
# Epilogue: TMEM -> SMEM -> GMEM via TMA store
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
@@ -345,13 +321,15 @@ def test():
|
||||
v_kernel = v.unsqueeze(-1)
|
||||
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
# Reference: proper softmax
|
||||
# Reference: unnormalized softmax numerators @ V
|
||||
# (no O rescale or 1/row_sum normalization in kernel yet)
|
||||
qf = q[:, :, 0].float()
|
||||
kf = k[:, :, 0].float()
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
attn = qf @ kf.T * scale
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
ref = attn @ v.float()
|
||||
attn_max = attn.max(dim=-1, keepdim=True)[0]
|
||||
attn_unnorm = torch.exp(attn - attn_max) # softmax numerators
|
||||
ref = attn_unnorm @ v.float()
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
|
||||
Reference in New Issue
Block a user