What changed:
- Moved fmha_backup_pre_epilog.py, fmha_backup_v2.py, fmha_smem_acc.py to archive/
- Deleted fmha.py.backup (git has history)
- Added detailed heredoc headers to ALL files documenting:
  * WHAT WORKS and WHAT'S BROKEN
  * WHY each limitation exists (CuTeDSL toolchain gaps)
  * KEY INSIGHTS FOR NVIDIA (what CuTeDSL is missing)
  * What each file unblocks if fixed

File status:
- fmha.py — CuTeDSL FMHA, cos 0.999998, D1.5 workaround
- fmha_common.cuh — Raw CUDA shared defs (BF16, TMEM ops)
- fmha_sm100.cuh — Raw CUDA reference, cos 0.999999
- fmha_epilogue_sm100.cuh — Raw CUDA TMEM epilogue, HANGS (needs debug)
- fmha_sm100_launch.cu — PyTorch binding (JIT broken, nvcc works)
- production.py — CuTeDSL production wrapper (partial)
- archive/ — Historical backups with explanation headers
181 lines
5.9 KiB
Python
181 lines
5.9 KiB
Python
"""DSV4 Attention kernels — public integration API.
|
|
|
|
====================================================================
|
|
STATUS: SKELETON — not yet connected to model
|
|
====================================================================
|
|
These functions define the API that AttentionSubBlock will call.
|
|
They're correct in structure but depend on:
|
|
1. LayerCacheHandle being fully implemented (gather_compressed_kv, etc.)
|
|
2. The production FMHA wrapper supporting sink_bias and n_comp
|
|
3. Custom op registration for torch.compile compatibility
|
|
|
|
See ROADMAP.md Priority 5 for the full Stage E checklist.
|
|
====================================================================
|
|
|
|
These functions bridge the model's AttentionSubBlock to the production
|
|
FMHA kernel wrapper. Each function handles the cache → dense-tensor
|
|
materialization that the kernel requires.
|
|
|
|
The model's attention layer calls these after:
|
|
1. Projection (q_down, q_up, kv_down)
|
|
2. RoPE application
|
|
3. Compression + cache writes
|
|
4. Indexer + top-k (CSA only)
|
|
|
|
These functions handle:
|
|
- Gathering sparse/dense KV from cache into dense tensors
|
|
- Calling the production FMHA wrapper
|
|
- Returning attention output for inverse RoPE + wo_a/wo_b
|
|
"""
|
|
from dsv4.kernels.attention.production import dsv4_attention
|
|
import torch
|
|
from typing import Optional, TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from dsv4.cache.handle import LayerCacheHandle
|
|
|
|
|
|
def sparse_fmha_with_swa(
    q: torch.Tensor,  # (T, n_h * hd) BF16, post-RoPE
    cache: "LayerCacheHandle",  # provides compressed + SWA KV
    selected_indices: torch.Tensor,  # (T, top_k) int64 — which compressed blocks
    sink_logits: Optional[torch.Tensor] = None,  # (n_h,) FP32
    sliding_window: int = 128,
) -> torch.Tensor:
    """CSA attention: sparse top-k compressed KV + sliding window, fused sink merge.

    Gathers the top-k compressed KV blocks and the SWA window from ``cache``,
    concatenates them along the sequence axis (dim -2) into one K and one V
    tensor, and calls ``dsv4_attention`` once over the combined sequence.

    Args:
        q: (T, n_h * hd) BF16 query (post-RoPE, pre-reshape).
        cache: LayerCacheHandle with CSA compressed entries + SWA window.
            Must expose ``num_query_heads``, ``head_dim``,
            ``gather_compressed_kv(indices)`` and ``gather_swa_kv()``.
        selected_indices: (T, top_k) int64 block indices from the indexer.
        sink_logits: (n_h,) FP32 per-head sink bias, forwarded unchanged as
            ``sink_bias``; may be ``None``.
        sliding_window: SWA window length, forwarded as ``swa_len``.

    Returns:
        (T, n_h * hd) BF16 attention output (pre inverse-RoPE).
    """
    # n_h and hd come from the cache's config, not from q's shape.
    n_h = cache.num_query_heads
    hd = cache.head_dim
    T = q.shape[0]

    # Split the flat feature axis into heads and move heads to the front.
    q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2)  # (n_h, T, hd)

    # Gather compressed KV for the selected blocks.
    # The cache handle provides the materialized dense KV from the paged pool.
    # k_compressed: (1, n_comp_kv, hd) or (n_kv, n_comp_kv, hd); v_compressed: same.
    k_compressed, v_compressed = cache.gather_compressed_kv(selected_indices)

    # Gather SWA window KV.
    # k_swa: (1, swa_len, hd), v_swa: same.
    k_swa, v_swa = cache.gather_swa_kv()

    # Concatenate: [compressed, SWA] — single softmax (D5c insight).
    k_full = torch.cat([k_compressed, k_swa], dim=-2)  # (1, n_comp+swa_len, hd)
    v_full = torch.cat([v_compressed, v_swa], dim=-2)

    # n_comp = compressed KV length (for sink bias offset).
    n_comp = k_compressed.shape[-2]

    # Call production attention — MQA (n_kv=1 for DSV4).
    output = dsv4_attention(
        q_heads, k_full, v_full,
        swa_len=sliding_window,
        is_causal=True,
        n_comp=n_comp,
        sink_bias=sink_logits,
    )  # (n_h, T, hd)

    # Reshape back to (T, n_h * hd).
    return output.permute(1, 0, 2).reshape(T, n_h * hd)
|
|
|
|
|
|
def dense_fmha_with_swa(
    q: torch.Tensor,
    cache: "LayerCacheHandle",
    sink_logits: Optional[torch.Tensor] = None,
    sliding_window: int = 128,
) -> torch.Tensor:
    """HCA attention: dense over every compressed KV entry plus the SWA window.

    There is no indexer on this path: all compressed entries returned by
    ``cache.gather_all_compressed_kv()`` are attended, followed by the
    sliding-window entries from ``cache.gather_swa_kv()``.

    Args:
        q: (T, n_h * hd) BF16 query.
        cache: LayerCacheHandle with HCA compressed entries + SWA window.
            Must expose ``num_query_heads`` and ``head_dim``.
        sink_logits: (n_h,) FP32 per-head sink bias, forwarded unchanged as
            ``sink_bias``; may be ``None``.
        sliding_window: SWA window length, forwarded as ``swa_len``.

    Returns:
        (T, n_h * hd) BF16 attention output.
    """
    num_heads = cache.num_query_heads
    head_dim = cache.head_dim
    num_tokens = q.shape[0]

    # (T, n_h * hd) -> (T, n_h, hd) -> (n_h, T, hd): heads-first layout for the kernel call.
    per_head_q = q.reshape(num_tokens, num_heads, head_dim).permute(1, 0, 2)

    # Compressed entries are fetched before the window, same as the concat order below.
    compressed_k, compressed_v = cache.gather_all_compressed_kv()
    window_k, window_v = cache.gather_swa_kv()

    # Compressed entries first, window entries second, joined on the sequence axis.
    keys = torch.cat([compressed_k, window_k], dim=-2)
    values = torch.cat([compressed_v, window_v], dim=-2)

    attn = dsv4_attention(
        per_head_q,
        keys,
        values,
        swa_len=sliding_window,
        is_causal=True,
        # Length of the compressed prefix inside the concatenated K/V.
        n_comp=compressed_k.shape[-2],
        sink_bias=sink_logits,
    )

    # Undo the heads-first layout and flatten heads back into the feature axis.
    tokens_first = attn.permute(1, 0, 2)
    return tokens_first.reshape(num_tokens, num_heads * head_dim)
|
|
|
|
|
|
def swa_only_fmha(
    q: torch.Tensor,
    cache: "LayerCacheHandle",
    sink_logits: Optional[torch.Tensor] = None,
    sliding_window: int = 128,
) -> torch.Tensor:
    """SWA-only attention: local attention over the sliding window alone.

    No compression branch and no indexer; only ``cache.gather_swa_kv()`` is
    consulted. Used for the first two layers of the Flash variant.

    Args:
        q: (T, n_h * hd) BF16 query.
        cache: LayerCacheHandle with SWA window. Must expose
            ``num_query_heads`` and ``head_dim``.
        sink_logits: (n_h,) FP32 per-head sink bias, forwarded unchanged as
            ``sink_bias``; may be ``None``.
        sliding_window: SWA window length, forwarded as ``swa_len``.

    Returns:
        (T, n_h * hd) BF16 attention output.
    """
    num_heads = cache.num_query_heads
    head_dim = cache.head_dim
    num_tokens = q.shape[0]

    # (T, n_h * hd) -> (T, n_h, hd) -> (n_h, T, hd): heads-first layout for the kernel call.
    per_head_q = q.reshape(num_tokens, num_heads, head_dim).permute(1, 0, 2)

    window_k, window_v = cache.gather_swa_kv()

    # With no compressed branch, the compressed-prefix length passed to the kernel is zero.
    attn = dsv4_attention(
        per_head_q,
        window_k,
        window_v,
        swa_len=sliding_window,
        is_causal=True,
        n_comp=0,
        sink_bias=sink_logits,
    )

    # Undo the heads-first layout and flatten heads back into the feature axis.
    tokens_first = attn.permute(1, 0, 2)
    return tokens_first.reshape(num_tokens, num_heads * head_dim)
|