Files
nvfp4-megamoe-kernel/dsv4/layers/attention.py
biondizzle 4453d7475a Fix layer construction: match existing API signatures, add RMSNorm impl
- Nvfp4GroupedLinear: (n_local_groups, heads_per_group, head_dim, o_lora_rank)
- mHCLayer: hidden_dim, t_max_sinkhorn (not hidden_size, sinkhorn_iters)
- RMSNorm: PyTorch reference implementation (BF16, cudagraph-safe)
- Verified: all 43 Flash + 61 Pro layers construct cleanly
- All projection shapes validated against architecture spec
2026-05-21 23:31:58 +00:00

246 lines
10 KiB
Python

"""DSV4 attention sub-block — CSA / HCA / SWA variants.
The constructor decides which variant this instance is by reading the
LayerSpec passed in. After construction the variant never changes: forward()
still branches on attention type, but only on a value frozen at construction,
so torch.compile can fold the dispatch completely.
What this class is responsible for:
- Holding the projection weights (Q-down/up, KV-down, output groups).
- Calling Nvfp4Linear for the projections.
- Running RoPE (partial, last rope_dim dimensions).
- Calling the compressor (CSA or HCA) to write into the KV cache; SWA layers
  write raw KV into the sliding window with no compressor.
- For CSA only: calling the indexer + top-k to select sparse entries.
- Calling the FMHA kernel with the right K/V source.
- Calling inverse RoPE on the output.
- Calling Nvfp4GroupedLinear (wo_a) + Nvfp4Linear (wo_b).
What this class does NOT touch:
- mHC pre/post — done by TransformerLayer.
- Pre-norm — done by TransformerLayer.
- KV cache allocation / eviction — done by the cache handle.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from dsv4.layers.linear import Nvfp4Linear
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
from dsv4.model.layer_schedule import AttentionType, LayerSpec
if TYPE_CHECKING:
from dsv4.model.config import DSV4Config
from dsv4.cache.paged_cache import LayerCacheHandle
class AttentionSubBlock:
    """DSV4 attention sub-block whose variant (CSA / HCA / SWA) is fixed at construction.

    ``spec.attn`` is read once in ``__init__``. It decides the width of the
    KV down-projection, whether the CSA-only indexer projections exist, and
    which ``_forward_*`` path ``forward`` dispatches to.

    Args:
        config: Model config supplying projection sizes, compression ratios,
            ``csa_top_k``, ``sliding_window`` and ``rope_dim``.
        spec: Schedule entry for this layer; ``layer_idx`` and ``attn`` are read.

    Raises:
        ValueError: If ``config.num_query_heads`` is not a multiple of
            ``config.num_output_groups``; the grouped output projection needs
            the same integral number of heads in every group.
    """

    def __init__(self, config: "DSV4Config", spec: LayerSpec):
        self.config = config
        self.spec = spec
        self.layer_idx = spec.layer_idx
        self.attn_type = spec.attn  # frozen for this layer's lifetime

        # ---- Projections (all NVFP4) ----
        # Query path is low-rank: hidden -> dc -> head_dim * num_heads.
        self.q_down = Nvfp4Linear(
            in_features=config.hidden_size,
            out_features=config.query_compression_dim,
        )
        self.q_up = Nvfp4Linear(
            in_features=config.query_compression_dim,
            out_features=config.head_dim * config.num_query_heads,
        )

        # KV projection — single head shared across queries (MQA).
        # CSA needs two streams (a, b) for the overlapping compression;
        # HCA needs one. SWA needs raw KV with no compression.
        if self.attn_type == AttentionType.CSA:
            # 2 * head_dim for (Ca, Cb) + 2 * head_dim for the compression
            # weights (Za, Zb).
            kv_out_features = 4 * config.head_dim
        elif self.attn_type == AttentionType.HCA:
            # head_dim for C + head_dim for Z.
            kv_out_features = 2 * config.head_dim
        else:  # SWA — any type that is neither CSA nor HCA lands here.
            kv_out_features = config.head_dim
        self.kv_down = Nvfp4Linear(
            in_features=config.hidden_size,
            out_features=kv_out_features,
        )

        # ---- Indexer (CSA only) ----
        if self.attn_type == AttentionType.CSA:
            # Indexer queries share the q_down latent c_Q with main queries.
            # Up-projection: c_Q -> indexer_num_heads * indexer_head_dim.
            self.indexer_q_up = Nvfp4Linear(
                in_features=config.query_compression_dim,
                out_features=config.indexer_num_heads * config.indexer_head_dim,
            )
            # Per-head weights w_h: hidden -> indexer_num_heads. This is a
            # per-token projection: _forward_csa applies it to all T rows of
            # x and consumes a (T, indexer_num_heads) result. It therefore
            # must not be capped at max_num_tokens=1. Its token capacity is
            # left at Nvfp4Linear's default, like every other projection here.
            self.indexer_head_weights = Nvfp4Linear(
                in_features=config.hidden_size,
                out_features=config.indexer_num_heads,
            )

        # ---- Output projection ----
        # wo_a: grouped, splits the n_heads outputs into n_groups and projects
        # each group from (heads_per_group * head_dim) to o_lora_rank.
        # Reject head counts that do not split evenly. With floor division,
        # wo_a's total input width (n_groups * heads_per_group * head_dim)
        # would no longer match q_up's num_query_heads * head_dim.
        if config.num_query_heads % config.num_output_groups != 0:
            raise ValueError(
                f"num_query_heads ({config.num_query_heads}) must be divisible "
                f"by num_output_groups ({config.num_output_groups})"
            )
        heads_per_group = config.num_query_heads // config.num_output_groups
        self.wo_a = Nvfp4GroupedLinear(
            n_local_groups=config.num_output_groups,
            heads_per_group=heads_per_group,
            head_dim=config.head_dim,
            o_lora_rank=config.output_group_dim,
        )
        # wo_b: dense, concatenated group outputs back to hidden_size.
        self.wo_b = Nvfp4Linear(
            in_features=config.output_group_dim * config.num_output_groups,
            out_features=config.hidden_size,
        )

        # Attention sink: per-head learnable logit added to the softmax
        # denominator (paper §2.3.3). One scalar per query head.
        # None until load_weights sets it; passed as-is to the FMHA kernels.
        self.sink_logits: torch.Tensor | None = None

    # ------------------------------------------------------------------
    def forward(
        self,
        x: torch.Tensor,  # (T, hidden_size) BF16, post-RMSNorm
        cache: "LayerCacheHandle",  # per-layer cache slice
    ) -> torch.Tensor:  # (T, hidden_size) BF16
        """Run this layer's attention variant on ``x`` and return the block output."""
        # attn_type is frozen at construction, so this branch is a
        # constant in any traced graph.
        if self.attn_type == AttentionType.CSA:
            return self._forward_csa(x, cache)
        if self.attn_type == AttentionType.HCA:
            return self._forward_hca(x, cache)
        return self._forward_swa(x, cache)

    # ------------------------------------------------------------------
    # CSA: compressed + sparse attention, with indexer + top-k.
    # ------------------------------------------------------------------
    def _forward_csa(self, x: torch.Tensor, cache: "LayerCacheHandle") -> torch.Tensor:
        """CSA path: compress KV, select top-k entries via the indexer, sparse FMHA + SWA."""
        # Kernel modules are imported lazily, at call time.
        from dsv4.kernels.attention import sparse_fmha_with_swa
        from dsv4.kernels.compressor import csa_compress_and_store
        from dsv4.kernels.indexer import compute_index_scores_topk

        # Read once so forward and inverse RoPE use identical positions.
        positions = cache.positions

        # 1) Low-rank query projection. c_Q is shared with the indexer.
        c_Q = self.q_down(x)  # (T, dc)
        q = self.q_up(c_Q)  # (T, n_h * head_dim)
        # 2) KV down-projection — produces (Ca, Cb, Za, Zb) interleaved.
        kv_raw = self.kv_down(x)  # (T, 4 * head_dim)
        # 3) Partial RoPE on the queries only. Keys are not rotated in this
        #    method; presumably the compressor handles them — TODO confirm.
        q_roped = self._apply_rope(q, positions=positions)
        # 4) Compress + write to cache. Returns nothing — the cache holds it.
        csa_compress_and_store(
            kv_raw=kv_raw,
            cache=cache,  # writes C_comp + Z + indexer keys
            positions=positions,
            compression_ratio=self.config.csa_compression_ratio,
        )
        # 5) Indexer: compute index scores, select top-k compressed entries.
        q_indexer = self.indexer_q_up(c_Q)  # (T, n_I_h * c_I)
        w_indexer = self.indexer_head_weights(x)  # (T, n_I_h)
        topk_indices = compute_index_scores_topk(
            q_indexer=q_indexer,
            w_indexer=w_indexer,
            cache=cache,  # provides compressed indexer keys
            top_k=self.config.csa_top_k,
        )
        # 6) FMHA over selected sparse entries + SWA branch, fused sink merge.
        attn_out = sparse_fmha_with_swa(
            q=q_roped,
            cache=cache,
            selected_indices=topk_indices,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )  # (T, n_h * head_dim)
        # 7-8) Inverse RoPE, then grouped output projection.
        return self._project_output(attn_out, positions)

    # ------------------------------------------------------------------
    # HCA: heavy compression, dense attention over compressed entries.
    # ------------------------------------------------------------------
    def _forward_hca(self, x: torch.Tensor, cache: "LayerCacheHandle") -> torch.Tensor:
        """HCA path: compress KV, dense FMHA over all compressed entries + SWA."""
        from dsv4.kernels.attention import dense_fmha_with_swa
        from dsv4.kernels.compressor import hca_compress_and_store

        positions = cache.positions

        c_Q = self.q_down(x)
        q = self.q_up(c_Q)
        kv_raw = self.kv_down(x)  # (T, 2 * head_dim)
        q_roped = self._apply_rope(q, positions=positions)
        hca_compress_and_store(
            kv_raw=kv_raw,
            cache=cache,
            positions=positions,
            compression_ratio=self.config.hca_compression_ratio,
        )
        # Dense FMHA over ALL compressed entries (not top-k) plus SWA.
        attn_out = dense_fmha_with_swa(
            q=q_roped,
            cache=cache,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )
        return self._project_output(attn_out, positions)

    # ------------------------------------------------------------------
    # SWA: sliding window only, no compression branch.
    # First 2 layers of Flash. Acts as the "local context" foundation.
    # ------------------------------------------------------------------
    def _forward_swa(self, x: torch.Tensor, cache: "LayerCacheHandle") -> torch.Tensor:
        """SWA path: write raw KV into the window, FMHA over the window only."""
        from dsv4.kernels.attention import swa_only_fmha

        positions = cache.positions

        c_Q = self.q_down(x)
        q = self.q_up(c_Q)
        kv_raw = self.kv_down(x)  # (T, head_dim)
        q_roped = self._apply_rope(q, positions=positions)
        # Write raw KV to the SWA window in the cache. No compressor.
        cache.write_swa(kv_raw, positions=positions)
        # Dense FMHA over the sliding window only.
        attn_out = swa_only_fmha(
            q=q_roped,
            cache=cache,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )
        return self._project_output(attn_out, positions)

    # ------------------------------------------------------------------
    def _project_output(self, attn_out: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Undo RoPE on the attention output, then apply wo_a followed by wo_b."""
        from dsv4.ops.rope import inverse_rope_bf16

        # NOTE(review): forward RoPE receives rope_dim=self.config.rope_dim,
        # but inverse_rope_bf16 does not. Confirm its default rotates the same
        # trailing rope_dim dimensions of each head.
        attn_out = inverse_rope_bf16(attn_out, positions=positions)
        grouped = self.wo_a(attn_out)  # (T, n_groups * d_g)
        return self.wo_b(grouped)  # (T, hidden_size)

    def _apply_rope(self, q: torch.Tensor, positions: torch.Tensor):
        """Apply RoPE to the last `rope_dim` dimensions of each head."""
        from dsv4.ops.rope import forward_rope_partial

        # NOTE(review): q arrives flat as (T, n_h * head_dim) and head_dim is
        # not passed. Confirm forward_rope_partial splits heads as intended.
        return forward_rope_partial(
            q, positions=positions, rope_dim=self.config.rope_dim,
        )