Fix: add self.n_kv_tiles to __init__
This commit is contained in:
@@ -56,6 +56,7 @@ class FmhaV3StageCMulti:
|
||||
def __init__(self, s_k=128, scale_softmax=None):
|
||||
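# s_k MUST equal the actual sequence length n. NOTE(review): n_kv_tiles is s_k // 128 (floor division), so an s_k that is not a multiple of 128 drops the remainder — confirm callers guarantee divisibility.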
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
|
||||
Reference in New Issue
Block a user