WIP: tiled copy for P→SMEM (zero fill)
This commit is contained in:
@@ -200,10 +200,11 @@ class FmhaKernel:
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# P → SMEM: use PV A-operand partition for SMEM write
|
||||
p_s = cute.slice_(p_smem_s,(None,None,None,0))
|
||||
tCrP_smem = pv_thr.partition_A(sP) # PV thread → SMEM partition for P (A operand)
|
||||
tCrP_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype)
|
||||
# P → SMEM copy setup
|
||||
p_copy_atom = cute.make_copy_atom(cute.nvgpu.copy.AutoCopyOp(), self.q_dtype)
|
||||
tiled_p_copy = cute.make_tiled_copy(p_copy_atom, tCrP_smem.layout, tCrP_reg.layout, tidx)
|
||||
tPS_sP = tiled_p_copy.get_slice(sfw_idx).partition_D(tCrP_smem)
|
||||
tPS_rP = tiled_p_copy.get_slice(sfw_idx).partition_S(tCrP_reg)
|
||||
|
||||
# Online softmax state
|
||||
row_max = -Float32.inf
|
||||
@@ -240,17 +241,10 @@ class FmhaKernel:
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
|
||||
# Write P to SMEM using PV A-operand thread partition
|
||||
# Copy from softmax registers (QK partition) to SMEM (PV partition)
|
||||
# Each thread converts its P values to BF16 and stores to its SMEM slot
|
||||
rP_bf16_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype)
|
||||
# Map QK partitioned P values to PV partitioned SMEM slots
|
||||
# Simple approach: use cute.copy with the register and SMEM tensors
|
||||
# The P SMEM is partitioned by pv_thr, softmax threads fill their portion
|
||||
# For now: zero-fill rP_bf16_reg as a placeholder (the FP32→BF16 conversion from tTMEM_LOADrS is not implemented yet)
|
||||
for j in cutlass.range(cute.size(rP_bf16_reg), vectorize=True):
|
||||
# TODO: proper element mapping from QK→PV partition
|
||||
rP_bf16_reg[j] = BFloat16(0.0)
|
||||
cute.copy(tCrP_smem, rP_bf16_reg)
|
||||
# TODO: map elements from the QK partition to the PV partition; until then P is zero-filled
|
||||
for j in cutlass.range(cute.size(tCrP_reg), vectorize=True):
|
||||
tCrP_reg[j] = BFloat16(0.0)
|
||||
cute.copy(tiled_p_copy, tPS_rP, tPS_sP)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
|
||||
si_handle.release()
|
||||
|
||||
Reference in New Issue
Block a user