Files
nvfp4-megamoe-kernel/dsv4/kernels/attention/__init__.py
biondizzle 4336de9372 attention/: Clean up folder, archive backups, add detailed status headers
What changed:
- Moved fmha_backup_pre_epilog.py, fmha_backup_v2.py, fmha_smem_acc.py to archive/
- Deleted fmha.py.backup (git has history)
- Added detailed status headers (module docstrings / block comments) to ALL files documenting:
  * WHAT WORKS and WHAT'S BROKEN
  * WHY each limitation exists (CuTeDSL toolchain gaps)
  * KEY INSIGHTS FOR NVIDIA (what CuTeDSL is missing)
  * What each file unblocks if fixed

File status:
  fmha.py                 — CuTeDSL FMHA, cos 0.999998, D1.5 workaround
  fmha_common.cuh         — Raw CUDA shared defs (BF16, TMEM ops)
  fmha_sm100.cuh          — Raw CUDA reference, cos 0.999999
  fmha_epilogue_sm100.cuh — Raw CUDA TMEM epilogue, HANGS (needs debug)
  fmha_sm100_launch.cu    — PyTorch binding (JIT broken, nvcc works)
  production.py           — CuTeDSL production wrapper (partial)
  archive/                — Historical backups with explanation headers
2026-05-28 07:01:33 +00:00

181 lines
5.9 KiB
Python

"""DSV4 Attention kernels — public integration API.

====================================================================
STATUS: SKELETON — not yet connected to model
====================================================================
These functions define the API that AttentionSubBlock will call.
They're correct in structure but depend on:
1. LayerCacheHandle being fully implemented (gather_compressed_kv, etc.)
2. The production FMHA wrapper supporting sink_bias and n_comp
3. Custom op registration for torch.compile compatibility
See ROADMAP.md Priority 5 for the full Stage E checklist.
====================================================================

These functions bridge the model's AttentionSubBlock to the production
FMHA kernel wrapper. Each function handles the cache → dense-tensor
materialization that the kernel requires.

The model's attention layer calls these after:
1. Projection (q_down, q_up, kv_down)
2. RoPE application
3. Compression + cache writes
4. Indexer + top-k (CSA only)

These functions handle:
- Gathering sparse/dense KV from cache into dense tensors
- Calling the production FMHA wrapper
- Returning attention output for inverse RoPE + wo_a/wo_b

Entry points:
- sparse_fmha_with_swa: CSA — indexer-selected compressed KV + SWA window
- dense_fmha_with_swa:  HCA — all compressed KV + SWA window
- swa_only_fmha:        SWA window only (no compressed branch)
"""

from typing import TYPE_CHECKING, Optional

import torch

from dsv4.kernels.attention.production import dsv4_attention

if TYPE_CHECKING:
    # Imported for annotations only; avoids a runtime dependency on the
    # cache package from this module.
    from dsv4.cache.handle import LayerCacheHandle
def sparse_fmha_with_swa(
    q: torch.Tensor,  # (T, n_h * hd) BF16, post-RoPE
    cache: "LayerCacheHandle",  # provides compressed + SWA KV
    selected_indices: torch.Tensor,  # (T, top_k) int64 — which compressed blocks
    sink_logits: Optional[torch.Tensor] = None,  # (n_h,) FP32
    sliding_window: int = 128,
) -> torch.Tensor:
    """CSA attention: sparse top-k compressed KV + sliding window, fused sink merge.

    Gathers the top-k compressed KV blocks and the SWA window, concatenates
    them into a single KV sequence laid out as ``[compressed | swa]``, and
    calls the production FMHA once over it with the per-head sink bias, so
    both branches share a single softmax.

    Args:
        q: (T, n_h * hd) BF16 query (post-RoPE, not yet split into heads).
            ``n_h`` and ``hd`` are read from ``cache.num_query_heads`` and
            ``cache.head_dim``.
        cache: LayerCacheHandle with CSA compressed entries + SWA window.
        selected_indices: (T, top_k) int64 block indices from the indexer,
            passed unchanged to ``cache.gather_compressed_kv``.
        sink_logits: (n_h,) FP32 per-head sink bias, or None.
        sliding_window: SWA window length, passed to the kernel as ``swa_len``.

    Returns:
        (T, n_h * hd) BF16 attention output (pre inverse-RoPE).
    """
    n_h = cache.num_query_heads
    hd = cache.head_dim
    T = q.shape[0]
    # (T, n_h * hd) -> (n_h, T, hd): head-major layout passed to the kernel.
    q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2)

    # Materialize the selected compressed blocks from the paged pool.
    # Expected shape (1, n_comp_kv, hd), or (n_kv, n_comp_kv, hd).
    # NOTE(review): selected_indices is per query token (T, top_k), but the
    # gathered K/V below is shared by all T queries in one kernel call. If
    # tokens select different blocks, each query also sees blocks selected
    # by other tokens — confirm gather_compressed_kv / kernel semantics.
    k_compressed, v_compressed = cache.gather_compressed_kv(selected_indices)
    # SWA window KV, expected (1, swa_len, hd).
    k_swa, v_swa = cache.gather_swa_kv()

    # [compressed, SWA] along the sequence axis -> single softmax (D5c insight).
    k_full = torch.cat([k_compressed, k_swa], dim=-2)  # (1, n_comp+swa_len, hd)
    v_full = torch.cat([v_compressed, v_swa], dim=-2)
    # Compressed-prefix length, passed as n_comp (for the sink bias offset).
    n_comp = k_compressed.shape[-2]

    # MQA call (n_kv = 1 for DSV4).
    # NOTE(review): swa_len is the configured window, not k_swa.shape[-2];
    # confirm which the kernel expects when fewer than `sliding_window`
    # tokens have been cached.
    output = dsv4_attention(
        q_heads,
        k_full,
        v_full,
        swa_len=sliding_window,
        is_causal=True,
        n_comp=n_comp,
        sink_bias=sink_logits,
    )  # (n_h, T, hd)

    # (n_h, T, hd) -> (T, n_h * hd)
    return output.permute(1, 0, 2).reshape(T, n_h * hd)
def dense_fmha_with_swa(
    q: torch.Tensor,
    cache: "LayerCacheHandle",
    sink_logits: Optional[torch.Tensor] = None,
    sliding_window: int = 128,
) -> torch.Tensor:
    """HCA attention: every compressed KV entry plus the SWA window, fused sink merge.

    Unlike the CSA path there is no indexer and no top-k selection: the
    whole compressed history is attended. Per the design, HCA compresses
    with m'=128, which keeps that compressed sequence short.

    Args:
        q: (T, n_h * hd) BF16 query.
        cache: LayerCacheHandle holding the HCA compressed entries and the
            SWA window.
        sink_logits: (n_h,) FP32 per-head sink bias, or None.
        sliding_window: SWA window length, forwarded to the kernel as
            ``swa_len``.

    Returns:
        (T, n_h * hd) BF16 attention output.
    """
    num_heads = cache.num_query_heads
    head_dim = cache.head_dim
    num_tokens = q.shape[0]
    # Split the fused head axis, then move heads first: (n_h, T, hd).
    q_by_head = q.reshape(num_tokens, num_heads, head_dim).transpose(0, 1)

    # Dense path: pull the full compressed history, then the local window.
    comp_k, comp_v = cache.gather_all_compressed_kv()
    win_k, win_v = cache.gather_swa_kv()

    # Sequence axis is -2; compressed entries first, window last.
    keys = torch.cat((comp_k, win_k), dim=-2)
    values = torch.cat((comp_v, win_v), dim=-2)

    attn = dsv4_attention(
        q_by_head,
        keys,
        values,
        swa_len=sliding_window,
        is_causal=True,
        n_comp=comp_k.shape[-2],  # length of the compressed prefix
        sink_bias=sink_logits,
    )
    # (n_h, T, hd) -> (T, n_h * hd)
    return attn.permute(1, 0, 2).reshape(num_tokens, num_heads * head_dim)
def swa_only_fmha(
    q: torch.Tensor,
    cache: "LayerCacheHandle",
    sink_logits: Optional[torch.Tensor] = None,
    sliding_window: int = 128,
) -> torch.Tensor:
    """SWA-only attention: purely local attention over the sliding window.

    There is no compressed branch and no indexer, so the kernel is called
    with ``n_comp=0`` and the SWA window as the whole KV sequence. Used for
    the first two layers of the Flash variant.

    Args:
        q: (T, n_h * hd) BF16 query.
        cache: LayerCacheHandle holding the SWA window.
        sink_logits: (n_h,) FP32 per-head sink bias, or None.
        sliding_window: SWA window length, forwarded to the kernel as
            ``swa_len``.

    Returns:
        (T, n_h * hd) BF16 attention output.
    """
    num_heads = cache.num_query_heads
    head_dim = cache.head_dim
    num_tokens = q.shape[0]
    # (T, n_h * hd) -> (n_h, T, hd)
    q_by_head = q.reshape(num_tokens, num_heads, head_dim).transpose(0, 1)

    win_k, win_v = cache.gather_swa_kv()

    # n_comp=0: nothing precedes the window, so the sink bias needs no offset.
    attn = dsv4_attention(
        q_by_head,
        win_k,
        win_v,
        swa_len=sliding_window,
        is_causal=True,
        n_comp=0,
        sink_bias=sink_logits,
    )
    # (n_h, T, hd) -> (T, n_h * hd)
    return attn.permute(1, 0, 2).reshape(num_tokens, num_heads * head_dim)