FMHA kernel (fmha_6warp_tma_multirow_multitile.cuh):
- Added sink_bias field to FmhaTmaMultiRowMultiTileParams
- After the KV tile loop, the sink logit is included in the online softmax rescale:
    new_max = max(running_max, sink_bias * scale)
    rescale existing O_unnorm and running_sum
    running_sum += exp(sink_bias * scale - new_max)
  No PV contribution from the sink (D5c: single softmax)
- C API: fmha_multitile_decode_launch now takes sink_bias_ptr
- Python: fmha_multitile_decode_raw accepts an attn_sink tensor

single_shot_inference.py:
- Full rewrite to use the production kernel stack
- mHC: uses dsv4.layers.mhc.mHCLayer (proper Sinkhorn-Knopp)
- Projections: uses Nvfp4Linear (CuTeDSL GEMM) for q_a, q_b, kv, o_b
- FMHA: 6-warp TMA multi-tile with sink bias (no SDPA fallback)
- MoE: Nvfp4MoE + Nvfp4SharedExpert (no reference fallback)
- Router: production dense/hash dispatch
- Compressor/Indexer: reference dequant (not yet on tensor cores)
- NO try/except fallbacks on production paths
197 lines
7.4 KiB
Python
197 lines
7.4 KiB
Python
"""DSV4 Blackwell Attention — Production kernel wrapper.
|
|
|
|
All attention goes through the 6-warp multi-tile FMHA kernel
|
|
(fmha_6warp_tma_multirow_multitile.cuh) via the C-API + ctypes bridge.
|
|
No CuTeDSL runtime dependency. No Python KV merge. No cudaDeviceSynchronize.
|
|
|
|
See README.md and ROADMAP.md for architecture, constraints, and next steps.
|
|
"""
|
|
import torch
|
|
import math
|
|
import logging
|
|
from typing import Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Internal: 6-warp multi-head multi-tile decode kernel
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _dsv4_attention_multitile(
    q: torch.Tensor,  # (n_q_heads, T, hd) BF16
    k: torch.Tensor,  # (n_kv_heads, N, hd) or (N, hd) BF16
    v: torch.Tensor,  # same shape as k
    scale: float,
    n_comp: int = 0,
    sink_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run one launch of the multi-tile FMHA decode op on unbatched inputs.

    Q, K and V each get a leading axis of size 1 before being handed to
    ``fmha_multitile_decode_raw``; that axis is stripped from the result.
    V additionally has its last two axes swapped before the call.

    Args:
        q: 3-D query tensor; anything else fails the shape unpack below.
        k: keys, either 3-D or 2-D (a 2-D tensor is treated as one KV head).
        v: values, expected to have the same rank as ``k``.
        scale: softmax scale, forwarded unchanged.
        n_comp: accepted for signature compatibility; not used here.
        sink_bias: forwarded to the op as ``attn_sink`` (may be ``None``).

    Returns:
        The op's first output with its leading size-1 axis removed; the
        second output (named ``_lse`` in the original) is discarded.
    """
    from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw

    # Unpacking doubles as a rank check: q must be exactly 3-D.
    _n_heads, _q_len, _head_dim = q.shape

    # A 2-D K/V pair is promoted to a single-head 3-D pair so both layouts
    # share the same packing code below.
    if k.dim() == 2:
        keys, values = k.unsqueeze(0), v.unsqueeze(0)
    else:
        keys, values = k, v

    q_packed = q.unsqueeze(0).contiguous()
    k_packed = keys.unsqueeze(0).contiguous()
    # V goes to the op with its last two axes swapped, materialized contiguously.
    v_packed = values.unsqueeze(0).transpose(-1, -2).contiguous()

    out_packed, _ = fmha_multitile_decode_raw(
        q_packed, k_packed, v_packed, scale, attn_sink=sink_bias
    )
    return out_packed.squeeze(0)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Public API
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def dsv4_attention(
    q: torch.Tensor,  # (batch, n_q_heads, T, hd) or (n_q_heads, T, hd)
    k: torch.Tensor,  # (batch, n_kv_heads, N, hd) or (n_kv_heads, N, hd) or (N, hd)
    v: torch.Tensor,  # same shape as k
    scale: Optional[float] = None,
    swa_len: Optional[int] = None,
    is_causal: bool = False,
    n_comp: int = 0,
    sink_bias: Optional[torch.Tensor] = None,  # (n_q_heads,) or (batch, n_q_heads)
) -> torch.Tensor:
    """Production DSV4 attention: MHA / MQA / GQA via 6-warp multi-tile kernel.

    Dispatch summary:
      * 4-D ``q`` with T == 1 and hd in (64, 128, 256, 512): one call to
        ``fmha_multitile_decode_raw`` for the whole batch.
      * 4-D ``q`` otherwise: recurse once per batch item.
      * 3-D ``q`` with T == 1 and a supported hd: one call to
        ``_dsv4_attention_multitile`` for all heads.
      * 3-D ``q`` otherwise: one ``_dsv4_attention_multitile`` call per KV
        head, each covering the ``n_q // n_kv`` Q heads that share it.

    Args:
        q: (n_q_heads, T, hd) or (batch, n_q_heads, T, hd), BF16.
        k: (n_kv_heads, N, hd), or (N, hd) for a single KV head, optionally
            with a leading batch dim. Unbatched K/V passed together with a
            batched ``q`` are broadcast across the batch.
        v: same shape as ``k``.
        scale: softmax scale; ``1/sqrt(hd)`` when ``None``. An explicit
            ``0.0`` is honoured rather than replaced by the default.
        swa_len: sliding window length. Not consumed by this function beyond
            being passed through on recursion.
        is_causal: causal flag. Not consumed by this function beyond being
            passed through on recursion.
        n_comp: compressed KV length. Forwarded to the kernel wrappers.
        sink_bias: per-Q-head sink logit bias, ``(n_q_heads,)`` or, for
            batched input, ``(batch, n_q_heads)``. Each launch receives the
            slice matching the heads (and batch item) it covers.

    Returns:
        Attention output. The looped paths allocate a BF16 tensor shaped
        like ``q`` on ``q``'s device; the single-launch paths return the
        kernel wrapper's output directly.

    Raises:
        AssertionError: if ``n_q_heads`` is not a multiple of ``n_kv_heads``.
    """
    if q.dim() == 4:
        # E5: Batch is handled natively by the kernel grid (blockIdx.z).
        # The C API launch sets dim3 grid(1, n_h, batch) which processes
        # all batch items in a single kernel launch.
        batch_size, n_q, T, hd = q.shape
        if scale is None:  # `scale or ...` would also clobber an explicit 0.0
            scale = 1.0 / math.sqrt(hd)

        # Normalize K/V to (batch, n_kv, N, hd)
        if k.dim() == 2:
            k = k.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1).contiguous()
            v = v.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1).contiguous()
        elif k.dim() == 3:
            k = k.unsqueeze(0).expand(batch_size, -1, -1, -1).contiguous()
            v = v.unsqueeze(0).expand(batch_size, -1, -1, -1).contiguous()

        n_kv = k.shape[1]
        assert n_q % n_kv == 0, (
            f"n_q_heads ({n_q}) must be divisible by n_kv_heads ({n_kv})"
        )

        if T == 1 and hd in (64, 128, 256, 512):
            # Direct 4D dispatch — single kernel launch for all batch items
            # GQA: expand K/V to n_h heads (handled inside fmha_multitile_decode_raw)
            from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw

            # Forward the sink bias instead of silently dropping it. The
            # keyword is only added when a bias was supplied, so calls
            # without one are issued exactly as before.
            # NOTE(review): V is passed untransposed here, while
            # _dsv4_attention_multitile transposes V before calling the same
            # op — confirm which layout the op expects.
            sink_kwargs = {} if sink_bias is None else {"attn_sink": sink_bias}
            o_4d, _lse = fmha_multitile_decode_raw(q, k, v, scale, n_comp, **sink_kwargs)
            return o_4d

        # T>1 fallback: still need per-batch dispatch for prefill
        output = torch.zeros(batch_size, n_q, T, hd, dtype=torch.bfloat16, device=q.device)
        # A (batch, n_q_heads) bias must be indexed per item; a (n_q_heads,)
        # bias is shared by every item.
        sink_is_batched = sink_bias is not None and sink_bias.dim() == 2
        for b in range(batch_size):
            sink_b = sink_bias[b] if sink_is_batched else sink_bias
            output[b] = dsv4_attention(
                q[b], k[b], v[b], scale=scale, swa_len=swa_len,
                is_causal=is_causal, n_comp=n_comp, sink_bias=sink_b,
            )
        return output

    # 3D case: (n_q_heads, T, hd)
    n_q, T, hd = q.shape
    if scale is None:
        scale = 1.0 / math.sqrt(hd)

    # Normalize K/V to (n_kv, N, hd)
    if k.dim() == 2:
        k = k.unsqueeze(0)  # (1, N, hd) — MQA
    if v.dim() == 2:
        v = v.unsqueeze(0)
    n_kv, _, _ = k.shape  # unpacking also enforces that K is 3-D here

    # GQA ratio: each KV head serves (n_q // n_kv) Q heads
    q_per_kv = n_q // n_kv
    assert n_q % n_kv == 0, f"n_q_heads ({n_q}) must be divisible by n_kv_heads ({n_kv})"

    # 6-warp multi-tile kernel handles all N, T=1 decode
    if T == 1 and hd in (64, 128, 256, 512):
        return _dsv4_attention_multitile(q, k, v, scale, n_comp, sink_bias)

    # Prefill / T>1: head-packed dispatch per KV group
    # TODO (E8): multi-CTA grid for T>1 prefill
    output = torch.zeros(n_q, T, hd, dtype=torch.bfloat16, device=q.device)
    for kv_idx in range(n_kv):
        q_start = kv_idx * q_per_kv
        q_end = q_start + q_per_kv

        # Each launch only sees this group's Q heads, so it must only see
        # this group's sink biases; passing the full vector would pair
        # every group with the first group's entries.
        sink_group = None if sink_bias is None else sink_bias[q_start:q_end]

        output[q_start:q_end] = _dsv4_attention_multitile(
            q[q_start:q_end],        # (q_per_kv, T, hd)
            k[kv_idx:kv_idx + 1],    # (1, N, hd)
            v[kv_idx:kv_idx + 1],
            scale,
            n_comp,
            sink_group,
        )

    return output
|
|
|
|
|
|
def dsv4_attention_per_head(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    swa_len: Optional[int] = None,
    is_causal: bool = False,
    n_comp: int = 0,
    sink_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Per-head launch variant — exact per-head sink bias support.

    Use this when Q heads within a KV group have different sink biases
    and exact results matter more than launch overhead. Otherwise prefer
    dsv4_attention (head-packed).

    Issues one ``_dsv4_attention_multitile`` call per Q head, pairing each
    head with the KV head of its group and with its own single-element
    slice of ``sink_bias``.

    Args:
        q: 3-D query tensor, unpacked as (n_q_heads, T, hd).
        k: 3-D key tensor (n_kv_heads first), or 2-D for a single KV head.
        v: values, same layout as ``k``.
        scale: softmax scale; ``1/sqrt(hd)`` when ``None``. An explicit
            ``0.0`` is honoured rather than replaced by the default.
        swa_len: accepted for signature parity with dsv4_attention; unused.
        is_causal: accepted for signature parity with dsv4_attention; unused.
        n_comp: forwarded to ``_dsv4_attention_multitile``.
        sink_bias: optional per-Q-head bias, indexed along its first axis.

    Returns:
        BF16 tensor of shape (n_q_heads, T, hd) on ``q``'s device.

    Raises:
        ValueError: if ``n_q_heads`` is not a multiple of ``n_kv_heads``.
            Without this check the trailing Q heads would never be
            launched and their output rows would silently stay zero.
    """
    n_q, T, hd = q.shape
    if scale is None:  # `scale or ...` would also clobber an explicit 0.0
        scale = 1.0 / math.sqrt(hd)

    if k.dim() == 2:
        k = k.unsqueeze(0)
    if v.dim() == 2:
        v = v.unsqueeze(0)
    n_kv, _, _ = k.shape  # unpacking also enforces that K is 3-D here

    if n_q % n_kv != 0:
        raise ValueError(
            f"n_q_heads ({n_q}) must be divisible by n_kv_heads ({n_kv})"
        )
    q_per_kv = n_q // n_kv

    # Allocate next to the inputs rather than on the default CUDA device.
    output = torch.zeros(n_q, T, hd, dtype=torch.bfloat16, device=q.device)

    for q_idx in range(n_q):
        kv_idx = q_idx // q_per_kv  # KV head shared by this Q head's group
        sb = sink_bias[q_idx:q_idx + 1] if sink_bias is not None else None

        # Width-1 slices keep the head axis, so the helper sees
        # (1, T, hd) queries and (1, N, hd) keys/values.
        o = _dsv4_attention_multitile(
            q[q_idx:q_idx + 1],
            k[kv_idx:kv_idx + 1],
            v[kv_idx:kv_idx + 1],
            scale,
            n_comp,
            sb,
        )
        output[q_idx] = o

    return output
|