diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index f16a7c14..93706a73 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -16,8 +16,8 @@ import math HEAD_DIM = 64 class FmhaV3StageC: - def __init__(self, scale_softmax=None): - self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + def __init__(self, s_k=128, scale_softmax=None): + self.s_k = s_k self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE @@ -76,8 +76,8 @@ class FmhaV3StageC: v_fmha = cute.make_tensor( v.iterator, cute.make_layout( - (HEAD_DIM, 128, 1), - stride=(1, HEAD_DIM, HEAD_DIM * 128), + (HEAD_DIM, self.s_k, 1), + stride=(1, HEAD_DIM, HEAD_DIM * self.s_k), ), ) self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()