fix: add s_k param to FmhaV3StageC, use self.s_k for V FMHA reconstruction
This commit is contained in:
@@ -16,8 +16,8 @@ import math
|
||||
HEAD_DIM = 64
|
||||
|
||||
class FmhaV3StageC:
|
||||
def __init__(self, scale_softmax=None):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
def __init__(self, s_k=128, scale_softmax=None):
|
||||
self.s_k = s_k; self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
|
||||
@@ -76,8 +76,8 @@ class FmhaV3StageC:
|
||||
v_fmha = cute.make_tensor(
|
||||
v.iterator,
|
||||
cute.make_layout(
|
||||
(HEAD_DIM, 128, 1),
|
||||
stride=(1, HEAD_DIM, HEAD_DIM * 128),
|
||||
(HEAD_DIM, self.s_k, 1),
|
||||
stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
|
||||
),
|
||||
)
|
||||
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
|
||||
|
||||
Reference in New Issue
Block a user