Fix triton compilation issue (#3984)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -415,7 +415,11 @@ def attn_fwd(
|
||||
return
|
||||
|
||||
is_mqa = hq != hk
|
||||
off_h_k = off_h_q % hk if is_mqa else off_h_q
|
||||
if is_mqa: # noqa: SIM108
|
||||
off_h_k = off_h_q % hk
|
||||
else:
|
||||
off_h_k = off_h_q
|
||||
|
||||
n_extra_tokens = 0
|
||||
if seqlen_k < BLOCK_N:
|
||||
n_extra_tokens = BLOCK_N - seqlen_k
|
||||
|
||||
Reference in New Issue
Block a user