E5: Fold batch loop into native kernel grid (blockIdx.z)

The 6-warp multi-tile kernel already supports batch natively via
dim3 grid(1, n_h, batch). Removed the Python for-loop for 4D decode
input (T == 1, hd in {64, 128, 256, 512}), so batched decode now
issues a single kernel launch per layer instead of batch_size launches.

T>1 prefill still uses per-batch dispatch (E8 future work); T == 1
with a head dim outside {64, 128, 256, 512} also takes this fallback.
This commit is contained in:
2026-05-30 21:21:02 +00:00
parent e162a2d112
commit df6220abaf

View File

@@ -84,20 +84,39 @@ def dsv4_attention(
# Handle batch dimension
has_batch = q.dim() == 4
if has_batch:
# E5: Batch is handled natively by the kernel grid (blockIdx.z).
# The C API launch sets dim3 grid(1, n_h, batch) which processes
# all batch items in a single kernel launch.
batch_size = q.shape[0]
# TODO (E5): fold into kernel grid instead of Python loop
outputs = []
n_q, T, hd = q.shape[1], q.shape[2], q.shape[3]
scale = scale or (1.0 / math.sqrt(hd))
# Normalize K/V to (batch, n_kv, N, hd)
if k.dim() == 2:
k = k.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1).contiguous()
v = v.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1).contiguous()
elif k.dim() == 3:
k = k.unsqueeze(0).expand(batch_size, -1, -1, -1).contiguous()
v = v.unsqueeze(0).expand(batch_size, -1, -1, -1).contiguous()
n_kv = k.shape[1]
q_per_kv = n_q // n_kv
assert n_q % n_kv == 0
if T == 1 and hd in (64, 128, 256, 512):
# Direct 4D dispatch — single kernel launch for all batch items
# GQA: expand K/V to n_h heads (handled inside fmha_multitile_decode_raw)
from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw
o_4d, _lse = fmha_multitile_decode_raw(q, k, v, scale, n_comp)
return o_4d
# T>1 fallback: still need per-batch dispatch for prefill
output = torch.zeros(batch_size, n_q, T, hd, dtype=torch.bfloat16, device='cuda')
for b in range(batch_size):
q_b = q[b] # (n_q_heads, T, hd)
k_b = k[b] if k.dim() == 4 else k # (n_kv_heads, N, hd) or (N, hd)
v_b = v[b] if v.dim() == 4 else v
sb_b = sink_bias[b] if sink_bias is not None and sink_bias.dim() == 2 else sink_bias
out_b = dsv4_attention(
q_b, k_b, v_b, scale=scale, swa_len=swa_len,
is_causal=is_causal, n_comp=n_comp, sink_bias=sb_b,
)
outputs.append(out_b)
return torch.stack(outputs, dim=0)
out_b = dsv4_attention(q[b], k[b], v[b], scale=scale, swa_len=swa_len,
is_causal=is_causal, n_comp=n_comp, sink_bias=sink_bias)
output[b] = out_b
return output
# 3D case: (n_q_heads, T, hd)
n_q, T, hd = q.shape