From df6220abaf5bdfafa5f08ccbfd254efd1d0c4e69 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 21:21:02 +0000 Subject: [PATCH] E5: Fold batch loop into native kernel grid (blockIdx.z) The 6-warp multi-tile kernel already supports batch natively via dim3 grid(1, n_h, batch). Removed Python for-loop for 4D input. Single kernel launch per layer for batched decode instead of batch_size launches. T>1 prefill still uses per-batch dispatch (E8 future work). --- dsv4/kernels/attention/production.py | 43 ++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/dsv4/kernels/attention/production.py b/dsv4/kernels/attention/production.py index 2c15bd98..7b86aa45 100644 --- a/dsv4/kernels/attention/production.py +++ b/dsv4/kernels/attention/production.py @@ -84,20 +84,39 @@ def dsv4_attention( # Handle batch dimension has_batch = q.dim() == 4 if has_batch: + # E5: Batch is handled natively by the kernel grid (blockIdx.z). + # The C API launch sets dim3 grid(1, n_h, batch) which processes + # all batch items in a single kernel launch. batch_size = q.shape[0] - # TODO (E5): fold into kernel grid instead of Python loop - outputs = [] + n_q, T, hd = q.shape[1], q.shape[2], q.shape[3] + scale = scale or (1.0 / math.sqrt(hd)) + + # Normalize K/V to (batch, n_kv, N, hd) + if k.dim() == 2: + k = k.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1).contiguous() + v = v.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1).contiguous() + elif k.dim() == 3: + k = k.unsqueeze(0).expand(batch_size, -1, -1, -1).contiguous() + v = v.unsqueeze(0).expand(batch_size, -1, -1, -1).contiguous() + + n_kv = k.shape[1] + q_per_kv = n_q // n_kv + assert n_q % n_kv == 0 + + if T == 1 and hd in (64, 128, 256, 512): + # Direct 4D dispatch — single kernel launch for all batch items + # GQA: expand K/V to n_h heads (handled inside fmha_multitile_decode_raw) + from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw + o_4d, _lse = fmha_multitile_decode_raw(q, k, v, scale, n_comp) + return o_4d + + # T>1 fallback: still need per-batch dispatch for prefill + output = torch.zeros(batch_size, n_q, T, hd, dtype=torch.bfloat16, device='cuda') for b in range(batch_size): - q_b = q[b] # (n_q_heads, T, hd) - k_b = k[b] if k.dim() == 4 else k # (n_kv_heads, N, hd) or (N, hd) - v_b = v[b] if v.dim() == 4 else v - sb_b = sink_bias[b] if sink_bias is not None and sink_bias.dim() == 2 else sink_bias - out_b = dsv4_attention( - q_b, k_b, v_b, scale=scale, swa_len=swa_len, - is_causal=is_causal, n_comp=n_comp, sink_bias=sb_b, - ) - outputs.append(out_b) - return torch.stack(outputs, dim=0) + out_b = dsv4_attention(q[b], k[b], v[b], scale=scale, swa_len=swa_len, + is_causal=is_causal, n_comp=n_comp, sink_bias=sink_bias) + output[b] = out_b + return output # 3D case: (n_q_heads, T, hd) n_q, T, hd = q.shape