E5: Fold batch loop into native kernel grid (blockIdx.z)
The 6-warp multi-tile kernel already supports batch natively via dim3 grid(1, n_h, batch). Removed Python for-loop for 4D input. Single kernel launch per layer for batched decode instead of batch_size launches. T>1 prefill still uses per-batch dispatch (E8 future work).
This commit is contained in:
@@ -84,20 +84,39 @@ def dsv4_attention(
|
||||
# Handle batch dimension
|
||||
has_batch = q.dim() == 4
|
||||
if has_batch:
|
||||
# E5: Batch is handled natively by the kernel grid (blockIdx.z).
|
||||
# The C API launch sets dim3 grid(1, n_h, batch) which processes
|
||||
# all batch items in a single kernel launch.
|
||||
batch_size = q.shape[0]
|
||||
# TODO (E5): fold into kernel grid instead of Python loop
|
||||
outputs = []
|
||||
n_q, T, hd = q.shape[1], q.shape[2], q.shape[3]
|
||||
scale = scale or (1.0 / math.sqrt(hd))
|
||||
|
||||
# Normalize K/V to (batch, n_kv, N, hd)
|
||||
if k.dim() == 2:
|
||||
k = k.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1).contiguous()
|
||||
v = v.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1).contiguous()
|
||||
elif k.dim() == 3:
|
||||
k = k.unsqueeze(0).expand(batch_size, -1, -1, -1).contiguous()
|
||||
v = v.unsqueeze(0).expand(batch_size, -1, -1, -1).contiguous()
|
||||
|
||||
n_kv = k.shape[1]
|
||||
q_per_kv = n_q // n_kv
|
||||
assert n_q % n_kv == 0
|
||||
|
||||
if T == 1 and hd in (64, 128, 256, 512):
|
||||
# Direct 4D dispatch — single kernel launch for all batch items
|
||||
# GQA: expand K/V to n_h heads (handled inside fmha_multitile_decode_raw)
|
||||
from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw
|
||||
o_4d, _lse = fmha_multitile_decode_raw(q, k, v, scale, n_comp)
|
||||
return o_4d
|
||||
|
||||
# T>1 fallback: still need per-batch dispatch for prefill
|
||||
output = torch.zeros(batch_size, n_q, T, hd, dtype=torch.bfloat16, device='cuda')
|
||||
for b in range(batch_size):
|
||||
q_b = q[b] # (n_q_heads, T, hd)
|
||||
k_b = k[b] if k.dim() == 4 else k # (n_kv_heads, N, hd) or (N, hd)
|
||||
v_b = v[b] if v.dim() == 4 else v
|
||||
sb_b = sink_bias[b] if sink_bias is not None and sink_bias.dim() == 2 else sink_bias
|
||||
out_b = dsv4_attention(
|
||||
q_b, k_b, v_b, scale=scale, swa_len=swa_len,
|
||||
is_causal=is_causal, n_comp=n_comp, sink_bias=sb_b,
|
||||
)
|
||||
outputs.append(out_b)
|
||||
return torch.stack(outputs, dim=0)
|
||||
out_b = dsv4_attention(q[b], k[b], v[b], scale=scale, swa_len=swa_len,
|
||||
is_causal=is_causal, n_comp=n_comp, sink_bias=sink_bias)
|
||||
output[b] = out_b
|
||||
return output
|
||||
|
||||
# 3D case: (n_q_heads, T, hd)
|
||||
n_q, T, hd = q.shape
|
||||
|
||||
Reference in New Issue
Block a user