fix: add s_k param to FmhaV3StageC, use self.s_k for V FMHA reconstruction

This commit is contained in:
2026-05-22 10:19:49 +00:00
parent b80a1ab083
commit 1982cc4d39

View File

@@ -16,8 +16,8 @@ import math
HEAD_DIM = 64
class FmhaV3StageC:
def __init__(self, scale_softmax=None):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
def __init__(self, s_k=128, scale_softmax=None):
self.s_k = s_k self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
@@ -76,8 +76,8 @@ class FmhaV3StageC:
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(HEAD_DIM, 128, 1),
stride=(1, HEAD_DIM, HEAD_DIM * 128),
(HEAD_DIM, self.s_k, 1),
stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()