diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index d9b841e4..0cf1ca0f 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -249,7 +249,7 @@ class FmhaKernel: # For now: fill rP_bf16_reg from tTMEM_LOADrS (FP32→BF16 conversion) for j in cutlass.range(cute.size(rP_bf16_reg), vectorize=True): # TODO: proper element mapping from QK→PV partition - rP_bf16_reg[j] = Float32(0.0) + rP_bf16_reg[j] = BFloat16(0.0) cute.copy(rP_bf16_reg, tCrP_smem) cute.arch.fence_proxy("async.shared", space="cta")