fix: BFloat16 not Float32 for bf16 reg
This commit is contained in:
@@ -249,7 +249,7 @@ class FmhaKernel:
|
||||
# For now: fill rP_bf16_reg from tTMEM_LOADrS (FP32→BF16 conversion)
|
||||
for j in cutlass.range(cute.size(rP_bf16_reg), vectorize=True):
|
||||
# TODO: proper element mapping from QK→PV partition
|
||||
rP_bf16_reg[j] = Float32(0.0)
|
||||
rP_bf16_reg[j] = BFloat16(0.0)
|
||||
cute.copy(rP_bf16_reg, tCrP_smem)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user