[Bugfix] Fixing division by zero in triton_attn if query_heads/kv_heads > 16 (#23424)

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
This commit is contained in:
Burkhard Ringlein
2025-09-03 17:01:09 +02:00
committed by GitHub
parent 4ba0c587ba
commit 6d80ae83e1

View File

@@ -674,7 +674,8 @@ def unified_attention(
num_queries_per_kv = num_query_heads // num_kv_heads
head_size = q.shape[2]
BLOCK_M = 16
BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(
num_queries_per_kv)
BLOCK_Q = BLOCK_M // num_queries_per_kv
# Ideally we would launch with kernel with: