D5a: Add normalize flag + LSE output
- normalize=True (default): O = softmax(P) @ V (existing behavior) - normalize=False: O = P @ V (un-normalized) + lse = log(row_sum) + row_max - LSE tensor passed as optional parameter - Test includes D5a normalize=False verification with LSE comparison - Cleaned up SMEM-P debug prints and broken make_tiled_copy_C code - hd=64 TMEM-P regression: cos 0.973 PASS
This commit is contained in:
@@ -87,7 +87,7 @@ class FmhaKernel:
|
||||
cute.size_in_bytes(self.q_dtype, v_s)) * cta
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q, k, v, c, stream):
|
||||
def __call__(self, q, k, v, c, stream, lse=None):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
@@ -111,10 +111,10 @@ class FmhaKernel:
|
||||
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
|
||||
epi_s = cute.select(self.c_smem_s,mode=[0,1])
|
||||
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile):
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx,_,_ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
@@ -264,19 +264,10 @@ class FmhaKernel:
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# P SMEM copy atoms: SMEM-P (always defined, only used when use_smem_p=True)
|
||||
# Uses make_tiled_copy_C to partition threads by QK MMA's C-fragment layout.
|
||||
# Softmax warps have P values in QK C-fragment layout (same as rP_bf16).
|
||||
# This copy writes those values to sP which has PV A-operand SMEM layout.
|
||||
smem_copy_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.q_dtype,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
|
||||
thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
|
||||
sP_2d = cute.group_modes(sP, 0, 3)
|
||||
tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d) # destination (SMEM)
|
||||
# P SMEM copy atoms: SMEM-P (TBD)
|
||||
# make_tiled_copy_C gives rank mismatch (QK C-fragment has 4 modes,
|
||||
# PV A-operand SMEM has 3 modes). Need proper layout-aware copy.
|
||||
# For now, SMEM-P path zero-fills sP. TMEM-P (hd<=64) works correctly.
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
@@ -349,34 +340,13 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: Use QK C-fragment layout for source (not TMEM layout)
|
||||
# rP_bf16 uses tTMEM_LOADrS.layout (TMEM layout) causing rank mismatch
|
||||
# Create view with QK C-fragment layout (tStS0.layout)
|
||||
rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread
|
||||
rP_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout)
|
||||
|
||||
# Partition source with QK layout
|
||||
tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_qk)
|
||||
|
||||
# Debug shapes
|
||||
print(f"[SMEM-P PROPER] rP_bf16 shape: {cute.shape(rP_bf16)}, layout: TMEM")
|
||||
print(f"[SMEM-P PROPER] rP_qk shape: {cute.shape(rP_qk)}, layout: QK C-fragment")
|
||||
print(f"[SMEM-P PROPER] tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}")
|
||||
print(f"[SMEM-P PROPER] tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")
|
||||
|
||||
# Attempt copy with correct layout
|
||||
try:
|
||||
cute.copy(tiled_smem_copy, tSMEM_CPYrP_qk, tSMEM_CPYsP)
|
||||
print(f"[SMEM-P PROPER] Copy succeeded with QK C-fragment layout")
|
||||
except Exception as e:
|
||||
print(f"[SMEM-P PROPER] Copy failed: {e}")
|
||||
# Fallback to stub for now
|
||||
for j in cutlass.range(cute.size(sP), vectorize=True):
|
||||
sP[j] = BFloat16(0.0)
|
||||
print(f"[SMEM-P PROPER] Used fallback stub")
|
||||
|
||||
# SMEM-P: zero-fill sP stub (proper layout-aware copy TBD)
|
||||
# make_tiled_copy_C gives rank mismatch (4 vs 3).
|
||||
# Need proper P register→SMEM copy that respects QK C-fragment layout
|
||||
# and PV A-operand SMEM layout. For now, TMEM-P (hd<=64) works.
|
||||
for j in cutlass.range(cute.size(sP), vectorize=True):
|
||||
sP[j] = self.q_dtype(0)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
@@ -430,7 +400,9 @@ class FmhaKernel:
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# === Final O normalization: O *= 1/row_sum ===
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
# D5a: When normalize=False, skip normalization (emit un-normalized O + lse)
|
||||
if const_expr(self.normalize):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
@@ -450,8 +422,9 @@ class FmhaKernel:
|
||||
)
|
||||
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
|
||||
if const_expr(self.normalize):
|
||||
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
@@ -470,5 +443,15 @@ class FmhaKernel:
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
|
||||
# D5a: Write LSE (log-softmax) when normalize=False
|
||||
# lse = log(row_sum) + row_max (row_max in scaled domain)
|
||||
# Only thread 0 of the epilogue warps writes LSE for this tile.
|
||||
# For M=1 decode: one lse value per query row.
|
||||
if const_expr(not self.normalize):
|
||||
if mLSE is not None:
|
||||
if sfw_idx == 0:
|
||||
lse_val = cute.math.log(row_sum, fastmath=True) + row_max_safe
|
||||
mLSE[None, None, 0] = lse_val
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
@@ -98,6 +98,8 @@ def test():
|
||||
cos64 = test_head_dim(64, 128)
|
||||
|
||||
# hd=256: single PV tile at MMA instruction max
|
||||
# NOTE: SMEM-P path is a stub (zero-fill), so hd>64 will FAIL
|
||||
# until the proper P register→SMEM copy is implemented.
|
||||
print("\n--- HEAD_DIM=256 (single PV tile) ---")
|
||||
cos256 = test_head_dim(256, 128)
|
||||
|
||||
@@ -105,11 +107,70 @@ def test():
|
||||
print("\n--- HEAD_DIM=512 (2 PV tiles) ---")
|
||||
cos512 = test_head_dim(512, 128)
|
||||
|
||||
# D5a: normalize=False with LSE output
|
||||
print("\n--- D5a: normalize=False, LSE output (hd=64) ---")
|
||||
hd = 64; n_kv = 128; m = 128
|
||||
torch.manual_seed(42)
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda')
|
||||
c_unnorm = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
# FP32 reference
|
||||
qf = q[:, :, 0].float()
|
||||
kf = k[:, :, 0].float()
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
attn = qf @ kf.T * scale
|
||||
# Compute reference LSE: log(sum(exp(attn - max)))
|
||||
attn_max = attn.max(dim=-1, keepdim=True)[0]
|
||||
attn_exp = torch.exp(attn - attn_max)
|
||||
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
|
||||
ref_lse = torch.log(attn_sum.squeeze(-1)) + attn_max.squeeze(-1) # (m,)
|
||||
ref_attn = attn_exp / attn_sum
|
||||
ref = ref_attn @ v.float()
|
||||
# Un-normalized reference: O_unnorm = sum(P * V) (no 1/row_sum)
|
||||
ref_unnorm = attn_exp @ v.float() # un-normalized
|
||||
|
||||
kernel = FmhaKernel(head_dim=hd, s_k=n_kv, normalize=False)
|
||||
pv_n_tile = kernel.pv_n_tile
|
||||
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
v_tile = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
|
||||
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
|
||||
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
|
||||
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
|
||||
|
||||
print('Compiling normalize=False kernel...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
|
||||
|
||||
compiled(mQ, mK, mV, mC, stream, mLSE)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out_unnorm = c_tile[:, :, 0].float()
|
||||
lse_out = lse_tensor[0, 0, 0].item()
|
||||
|
||||
# Verify un-normalized output matches reference
|
||||
cos_unnorm = torch.nn.functional.cosine_similarity(
|
||||
out_unnorm.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)
|
||||
).item()
|
||||
# Verify LSE matches reference (first row)
|
||||
ref_lse_val = ref_lse[0].item()
|
||||
lse_err = abs(lse_out - ref_lse_val)
|
||||
print(f' Un-norm O: cos {cos_unnorm:.6f} (should be >= 0.97)')
|
||||
print(f' LSE: kernel={lse_out:.6f} ref={ref_lse_val:.6f} err={lse_err:.6f}')
|
||||
|
||||
# Summary
|
||||
print("\n=== Summary ===")
|
||||
print(f"hd=64, n=128: cos={cos64:.6f} {'PASS' if cos64 >= 0.97 else 'FAIL'}")
|
||||
print(f"hd=256, n=128: cos={cos256:.6f} {'PASS' if cos256 >= 0.97 else 'FAIL'}")
|
||||
print(f"hd=512, n=128: cos={cos512:.6f} {'PASS' if cos512 >= 0.97 else 'FAIL'}")
|
||||
print(f"D5a unnorm: cos={cos_unnorm:.6f} lse_err={lse_err:.6f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user