fix: BFloat16 not Float32 for bf16 reg

This commit is contained in:
2026-05-23 03:50:09 +00:00
parent 748873a58c
commit 99f13cf52e

View File

@@ -249,7 +249,7 @@ class FmhaKernel:
# For now: fill rP_bf16_reg from tTMEM_LOADrS (FP32→BF16 conversion)
for j in cutlass.range(cute.size(rP_bf16_reg), vectorize=True):
# TODO: proper element mapping from QK→PV partition
rP_bf16_reg[j] = Float32(0.0)
rP_bf16_reg[j] = BFloat16(0.0)
cute.copy(rP_bf16_reg, tCrP_smem)
cute.arch.fence_proxy("async.shared", space="cta")