Files
nvfp4-megamoe-kernel/dsv4/kernels/attention/__init__.py
biondizzle b9f15c250f Stage E: head-packed MQA/GQA, batch dim, custom_op, integration API
- production.py: head-packed M dimension for MQA/GQA (q_per_kv*T rows
  in single launch per KV group, eliminating redundant K/V TMA loads)
- production.py: batch dimension support (outer Python loop)
- production.py: warmup_attention_kernels() for pre-compilation
- production.py: dsv4_attention_per_head() for exact per-head sink bias
- __init__.py: sparse_fmha_with_swa, dense_fmha_with_swa, swa_only_fmha
  integration functions bridging AttentionSubBlock → production FMHA
- custom_ops.py: dsv4::sparse_fmha_with_swa custom_op registration
- test_production.py: comprehensive tests (MHA/MQA/GQA, head-packed vs
  per-head parity, multi-segment KV, SWA+causal+sink, batch, edge cases)
2026-05-27 15:15:03 +00:00

169 lines
5.3 KiB
Python

"""DSV4 Attention kernels — public integration API.
These functions bridge the model's AttentionSubBlock to the production
FMHA kernel wrapper. Each function handles the cache → dense-tensor
materialization that the kernel requires.
The model's attention layer calls these after:
1. Projection (q_down, q_up, kv_down)
2. RoPE application
3. Compression + cache writes
4. Indexer + top-k (CSA only)
These functions handle:
- Gathering sparse/dense KV from cache into dense tensors
- Calling the production FMHA wrapper
- Returning attention output for inverse RoPE + wo_a/wo_b
"""
from dsv4.kernels.attention.production import dsv4_attention
import torch
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from dsv4.cache.handle import LayerCacheHandle
def sparse_fmha_with_swa(
    q: torch.Tensor,  # (T, n_h * hd) BF16, post-RoPE
    cache: "LayerCacheHandle",  # provides compressed + SWA KV
    selected_indices: torch.Tensor,  # (T, top_k) int64 — which compressed blocks
    sink_logits: Optional[torch.Tensor] = None,  # (n_h,) FP32
    sliding_window: int = 128,
) -> torch.Tensor:
    """Run CSA attention: sparse top-k compressed KV plus the sliding window.

    The selected compressed KV blocks and the SWA window are gathered from
    ``cache``, concatenated along the second-to-last (sequence) axis, and
    passed to a single ``dsv4_attention`` call so both branches share one
    softmax.

    Args:
        q: (T, n_h * hd) BF16 query (post-RoPE, not yet split per head).
        cache: LayerCacheHandle with CSA compressed entries + SWA window.
            Also supplies ``num_query_heads`` and ``head_dim``.
        selected_indices: (T, top_k) int64 block indices from the indexer.
        sink_logits: Optional (n_h,) FP32 per-head sink bias; forwarded
            unchanged as ``sink_bias``.
        sliding_window: SWA window length; forwarded as ``swa_len``.

    Returns:
        (T, n_h * hd) BF16 attention output (pre inverse-RoPE).
    """
    # Head count and head size come from the cache's config, not from q.
    n_h = cache.num_query_heads
    hd = cache.head_dim
    T = q.shape[0]

    # Split the flat query into heads and move the head axis to the front:
    # (T, n_h * hd) -> (T, n_h, hd) -> (n_h, T, hd).
    q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2)

    # Materialize dense K/V for the selected compressed blocks.
    # Per the original notes these are (1, n_comp_kv, hd) or
    # (n_kv, n_comp_kv, hd) — TODO confirm against LayerCacheHandle.
    k_compressed, v_compressed = cache.gather_compressed_kv(selected_indices)

    # Materialize dense K/V for the sliding window.
    k_swa, v_swa = cache.gather_swa_kv()

    # Layout is [compressed, SWA] so one kernel call covers both branches.
    k_full = torch.cat([k_compressed, k_swa], dim=-2)
    v_full = torch.cat([v_compressed, v_swa], dim=-2)

    # Length of the compressed prefix; tells the kernel where SWA begins.
    n_comp = k_compressed.shape[-2]

    # NOTE(review): swa_len is the configured window, not k_swa.shape[-2];
    # confirm dsv4_attention expects the configured length when the cache
    # holds fewer entries than the window.
    output = dsv4_attention(
        q_heads, k_full, v_full,
        swa_len=sliding_window,
        is_causal=True,
        n_comp=n_comp,
        sink_bias=sink_logits,
    )  # (n_h, T, hd)

    # Undo the head-major layout: (n_h, T, hd) -> (T, n_h * hd).
    return output.permute(1, 0, 2).reshape(T, n_h * hd)
def dense_fmha_with_swa(
    q: torch.Tensor,
    cache: "LayerCacheHandle",
    sink_logits: Optional[torch.Tensor] = None,
    sliding_window: int = 128,
) -> torch.Tensor:
    """Run HCA attention: every compressed KV entry plus the sliding window.

    No indexer is involved. All compressed entries held by ``cache`` are
    gathered (HCA's m'=128 compression keeps that sequence short), joined
    with the SWA window along the second-to-last axis, and attended in one
    ``dsv4_attention`` call.

    Args:
        q: (T, n_h * hd) BF16 query.
        cache: LayerCacheHandle with HCA compressed entries + SWA window.
            Also supplies ``num_query_heads`` and ``head_dim``.
        sink_logits: Optional (n_h,) FP32 per-head sink bias; forwarded
            unchanged as ``sink_bias``.
        sliding_window: SWA window length; forwarded as ``swa_len``.

    Returns:
        (T, n_h * hd) BF16 attention output.
    """
    num_heads, head_dim = cache.num_query_heads, cache.head_dim
    num_tokens = q.shape[0]

    # (T, n_h * hd) -> (n_h, T, hd): the kernel takes the head axis first.
    q_by_head = q.reshape(num_tokens, num_heads, head_dim).permute(1, 0, 2)

    comp_k, comp_v = cache.gather_all_compressed_kv()
    win_k, win_v = cache.gather_swa_kv()

    # Compressed entries first, window second.
    keys = torch.cat((comp_k, win_k), dim=-2)
    values = torch.cat((comp_v, win_v), dim=-2)

    attn = dsv4_attention(
        q_by_head,
        keys,
        values,
        swa_len=sliding_window,
        is_causal=True,
        # Length of the compressed prefix inside keys/values.
        n_comp=comp_k.shape[-2],
        sink_bias=sink_logits,
    )

    # Back to the caller's flat layout: (n_h, T, hd) -> (T, n_h * hd).
    return attn.permute(1, 0, 2).reshape(num_tokens, num_heads * head_dim)
def swa_only_fmha(
    q: torch.Tensor,
    cache: "LayerCacheHandle",
    sink_logits: Optional[torch.Tensor] = None,
    sliding_window: int = 128,
) -> torch.Tensor:
    """Run pure local attention over the sliding window only.

    There is no compressed branch and no indexer: the only K/V passed to
    ``dsv4_attention`` is what ``cache.gather_swa_kv()`` returns. Used for
    the first two layers of the Flash variant.

    Args:
        q: (T, n_h * hd) BF16 query.
        cache: LayerCacheHandle with the SWA window. Also supplies
            ``num_query_heads`` and ``head_dim``.
        sink_logits: Optional (n_h,) FP32 per-head sink bias; forwarded
            unchanged as ``sink_bias``.
        sliding_window: SWA window length; forwarded as ``swa_len``.

    Returns:
        (T, n_h * hd) BF16 attention output.
    """
    num_heads, head_dim = cache.num_query_heads, cache.head_dim
    num_tokens = q.shape[0]

    # (T, n_h * hd) -> (n_h, T, hd): the kernel takes the head axis first.
    q_by_head = q.reshape(num_tokens, num_heads, head_dim).permute(1, 0, 2)

    win_k, win_v = cache.gather_swa_kv()

    attn = dsv4_attention(
        q_by_head,
        win_k,
        win_v,
        swa_len=sliding_window,
        is_causal=True,
        # No compressed prefix precedes the window in this variant.
        n_comp=0,
        sink_bias=sink_logits,
    )

    # Back to the caller's flat layout: (n_h, T, hd) -> (T, n_h * hd).
    return attn.permute(1, 0, 2).reshape(num_tokens, num_heads * head_dim)