fix: add is_causal to FmhaKernel __init__ signature
This commit is contained in:
@@ -16,7 +16,7 @@ import math
|
||||
|
||||
|
||||
class FmhaKernel:
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False):
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False):
|
||||
self.head_dim = head_dim
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
@@ -32,7 +32,7 @@ class FmhaKernel:
|
||||
self.batch_size = batch_size
|
||||
self.normalize = normalize # D5a: False = emit un-normalized O + lse
|
||||
self.apply_swa_mask = apply_swa_mask # D3: mask logits at positions >= swa_lens
|
||||
self.is_causal = False # D4: causal mask (k_coord > m_coord) on SWA branch
|
||||
self.is_causal = is_causal # D4: causal mask (k_coord > m_coord) on SWA branch
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
|
||||
Reference in New Issue
Block a user