WIP: P→SMEM write stub (zero fill, proper mapping TODO)
This commit is contained in:
@@ -232,10 +232,26 @@ class FmhaKernel:
|
||||
row_sum *= acc_scale
|
||||
minus_row_max = Float32(0.0) - row_max_safe
|
||||
|
||||
# Compute P = exp2(S * scale - row_max) and write to SMEM
|
||||
# First compute in FP32, convert to BF16, write to SMEM
|
||||
# TODO: proper SMEM write with P thread partition
|
||||
# For now, just arrive at softmax_done_bar to unblock MMA
|
||||
# Compute P = exp2(S * scale - row_max), convert to BF16, write to SMEM
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
|
||||
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
|
||||
# Write P to SMEM using PV A-operand thread partition
|
||||
# Copy from softmax registers (QK partition) to SMEM (PV partition)
|
||||
# Each thread converts its P values to BF16 and stores to its SMEM slot
|
||||
rP_bf16_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype)
|
||||
# Map QK partitioned P values to PV partitioned SMEM slots
|
||||
# Simple approach: use cute.copy with the register and SMEM tensors
|
||||
# The P SMEM is partitioned by pv_thr, softmax threads fill their portion
|
||||
# For now: fill rP_bf16_reg from tTMEM_LOADrS (FP32→BF16 conversion)
|
||||
for j in cutlass.range(cute.size(rP_bf16_reg), vectorize=True):
|
||||
# TODO: proper element mapping from QK→PV partition
|
||||
rP_bf16_reg[j] = Float32(0.0)
|
||||
cute.copy(rP_bf16_reg, tCrP_smem)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
Reference in New Issue
Block a user