Files
nvfp4-megamoe-kernel/dsv4/layers/attention.py

246 lines
10 KiB
Python
Raw Normal View History

"""DSV4 attention sub-block — CSA / HCA / SWA variants.
The constructor decides which variant this instance is by reading the
LayerSpec passed in. After construction the variant is fixed: forward()
only branches on an attribute that never changes, so torch.compile can
fold the dispatch completely.
What this class is responsible for:
- Holding the projection weights (Q-down/up, KV-down, output groups).
- Calling Nvfp4Linear for the projections.
- Running RoPE (partial, last rope_dim dimensions).
- Calling the compressor (CSA or HCA) to write into the KV cache.
- For CSA only: calling the indexer + top-k to select sparse entries.
- Calling the FMHA kernel with the right K/V source.
- Calling inverse RoPE on the output.
- Calling Nvfp4GroupedLinear (wo_a) + Nvfp4Linear (wo_b).
What this class does NOT touch:
- mHC pre/post — done by TransformerLayer.
- Pre-norm — done by TransformerLayer.
- KV cache allocation / eviction — done by the cache handle.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from dsv4.layers.linear import Nvfp4Linear
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
from dsv4.model.layer_schedule import AttentionType, LayerSpec
if TYPE_CHECKING:
from dsv4.model.config import DSV4Config
from dsv4.cache.paged_cache import LayerCacheHandle
class AttentionSubBlock:
    """One DSV4 attention sub-block, specialised to CSA, HCA or SWA.

    The variant is read from ``spec.attn`` in the constructor and is not
    reassigned anywhere in this class. It decides the width of the KV
    down-projection, whether the indexer projections exist, and which
    ``_forward_*`` path ``forward`` dispatches to.
    """

    def __init__(self, config: "DSV4Config", spec: LayerSpec):
        """Build the NVFP4 projections required by this layer's variant.

        Args:
            config: Model configuration; supplies every dimension used below.
            spec: Per-layer spec; ``spec.attn`` selects the variant and
                ``spec.layer_idx`` is recorded on the instance.

        Raises:
            ValueError: If ``config.num_query_heads`` and
                ``config.num_output_groups`` are not both positive, or the
                head count is not a multiple of the group count.
        """
        self.config = config
        self.spec = spec
        self.layer_idx = spec.layer_idx
        self.attn_type = spec.attn  # frozen for this layer's lifetime

        # Validate the output grouping before building any projection so a
        # bad config fails here rather than as a shape mismatch later.
        heads_per_group = self._heads_per_group(
            config.num_query_heads, config.num_output_groups
        )

        # ---- Projections (all NVFP4) ----
        # Query path is low-rank: hidden -> dc -> head_dim * num_heads.
        self.q_down = Nvfp4Linear(
            in_features=config.hidden_size,
            out_features=config.query_compression_dim,
        )
        self.q_up = Nvfp4Linear(
            in_features=config.query_compression_dim,
            out_features=config.head_dim * config.num_query_heads,
        )

        # KV projection — single head shared across queries (MQA). The
        # output width is a per-variant multiple of head_dim:
        #   CSA: 4 -> (Ca, Cb) plus compression weights (Za, Zb)
        #   HCA: 2 -> C plus Z
        #   SWA: 1 -> raw KV, no compression
        if self.attn_type == AttentionType.CSA:
            kv_width_factor = 4
        elif self.attn_type == AttentionType.HCA:
            kv_width_factor = 2
        else:  # SWA
            kv_width_factor = 1
        self.kv_down = Nvfp4Linear(
            in_features=config.hidden_size,
            out_features=kv_width_factor * config.head_dim,
        )

        # ---- Indexer (CSA only) ----
        if self.attn_type == AttentionType.CSA:
            # Indexer queries share the q_down latent c_Q with main queries.
            # Up-projection: c_Q -> indexer_num_heads * indexer_head_dim.
            self.indexer_q_up = Nvfp4Linear(
                in_features=config.query_compression_dim,
                out_features=config.indexer_num_heads * config.indexer_head_dim,
            )
            # Per-head weight vector w_h: hidden -> indexer_num_heads.
            # NOTE(review): max_num_tokens=1 is described as "scalar per
            # head, not per-token projection", yet _forward_csa applies this
            # layer to the full (T, hidden_size) input — confirm that
            # Nvfp4Linear accepts T > 1 with this setting.
            self.indexer_head_weights = Nvfp4Linear(
                in_features=config.hidden_size,
                out_features=config.indexer_num_heads,
                max_num_tokens=1,  # scalar per head, not per-token projection
            )

        # ---- Output projection ----
        # wo_a: grouped, splits the n_heads outputs into n_groups and projects
        # each group from (heads_per_group * head_dim) to o_lora_rank.
        self.wo_a = Nvfp4GroupedLinear(
            n_local_groups=config.num_output_groups,
            heads_per_group=heads_per_group,
            head_dim=config.head_dim,
            o_lora_rank=config.output_group_dim,
        )
        # wo_b: dense, concatenated group outputs back to hidden_size.
        self.wo_b = Nvfp4Linear(
            in_features=config.output_group_dim * config.num_output_groups,
            out_features=config.hidden_size,
        )

        # Attention sink: per-head learnable logit added to the softmax
        # denominator (paper §2.3.3). One scalar per query head.
        self.sink_logits: torch.Tensor | None = None  # set by load_weights

    @staticmethod
    def _heads_per_group(num_query_heads: int, num_output_groups: int) -> int:
        """Return how many query heads feed each output group.

        Raises:
            ValueError: If either count is not positive, or the heads do not
                divide evenly into the groups (floor division would silently
                drop the remainder).
        """
        if num_query_heads <= 0 or num_output_groups <= 0:
            raise ValueError(
                "num_query_heads and num_output_groups must be positive, got "
                f"{num_query_heads} and {num_output_groups}"
            )
        per_group, leftover = divmod(num_query_heads, num_output_groups)
        if leftover:
            raise ValueError(
                f"num_query_heads ({num_query_heads}) must be divisible by "
                f"num_output_groups ({num_output_groups})"
            )
        return per_group

    # ------------------------------------------------------------------
    def forward(
        self,
        x: torch.Tensor,              # (T, hidden_size) BF16, post-RMSNorm
        cache: "LayerCacheHandle",    # per-layer cache slice
    ) -> torch.Tensor:                # (T, hidden_size) BF16
        """Run attention for this layer and return the projected output.

        Dispatches on ``self.attn_type``, which is set once in ``__init__``,
        so the branch taken is the same on every call for this instance.
        """
        if self.attn_type == AttentionType.CSA:
            return self._forward_csa(x, cache)
        if self.attn_type == AttentionType.HCA:
            return self._forward_hca(x, cache)
        return self._forward_swa(x, cache)

    # ------------------------------------------------------------------
    # Shared prologue / epilogue used by all three variants.
    # ------------------------------------------------------------------
    def _project_inputs(self, x, cache):
        """Run the projections common to every variant.

        Returns:
            A ``(c_Q, kv_raw, q_roped)`` tuple: the query latent from
            ``q_down``, the raw ``kv_down`` output, and the up-projected
            queries after partial RoPE.
        """
        # Low-rank query projection. c_Q is returned because the CSA
        # indexer up-projects from the same latent.
        c_Q = self.q_down(x)                      # (T, dc)
        q = self.q_up(c_Q)                        # (T, n_h * head_dim)
        kv_raw = self.kv_down(x)                  # width depends on variant
        # Only the queries are rotated here.
        # NOTE(review): an earlier comment said RoPE is also applied to the
        # to-be-compressed keys; presumably the compressor / cache write
        # handles that — confirm.
        q_roped = self._apply_rope(q, positions=cache.positions)
        return c_Q, kv_raw, q_roped

    def _project_output(self, attn_out, cache):
        """Undo RoPE on the attention output and project to hidden_size."""
        # Inverse RoPE on outputs (paper §2.3.3).
        from dsv4.ops.rope import inverse_rope_bf16
        attn_out = inverse_rope_bf16(attn_out, positions=cache.positions)
        # Grouped output projection: wo_a then wo_b.
        grouped = self.wo_a(attn_out)             # (T, n_groups * d_g)
        return self.wo_b(grouped)                 # (T, hidden_size)

    # ------------------------------------------------------------------
    # CSA: compressed + sparse attention, with indexer + top-k.
    # ------------------------------------------------------------------
    def _forward_csa(self, x, cache):
        """CSA path: compress, pick top-k entries via the indexer, attend."""
        # 1-3) Shared projections; kv_raw is (T, 4 * head_dim) holding
        #      (Ca, Cb, Za, Zb) interleaved.
        c_Q, kv_raw, q_roped = self._project_inputs(x, cache)

        # 4) Compress + write to cache. Returns nothing — the cache holds it.
        from dsv4.kernels.compressor import csa_compress_and_store
        csa_compress_and_store(
            kv_raw=kv_raw,
            cache=cache,                          # writes C_comp + Z + indexer keys
            positions=cache.positions,
            compression_ratio=self.config.csa_compression_ratio,
        )

        # 5) Indexer: compute index scores, select top-k compressed entries.
        q_indexer = self.indexer_q_up(c_Q)        # (T, n_I_h * c_I)
        w_indexer = self.indexer_head_weights(x)  # (T, n_I_h)
        from dsv4.kernels.indexer import compute_index_scores_topk
        topk_indices = compute_index_scores_topk(
            q_indexer=q_indexer,
            w_indexer=w_indexer,
            cache=cache,                          # provides compressed indexer keys
            top_k=self.config.csa_top_k,
        )

        # 6) FMHA over selected sparse entries + SWA branch, fused sink merge.
        from dsv4.kernels.attention import sparse_fmha_with_swa
        attn_out = sparse_fmha_with_swa(
            q=q_roped,
            cache=cache,
            selected_indices=topk_indices,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )                                         # (T, n_h * head_dim)

        # 7-8) Inverse RoPE, then wo_a / wo_b.
        return self._project_output(attn_out, cache)

    # ------------------------------------------------------------------
    # HCA: heavy compression, dense attention over compressed entries.
    # ------------------------------------------------------------------
    def _forward_hca(self, x, cache):
        """HCA path: compress, then attend densely over compressed entries."""
        # kv_raw is (T, 2 * head_dim); the query latent is not needed here.
        _, kv_raw, q_roped = self._project_inputs(x, cache)

        from dsv4.kernels.compressor import hca_compress_and_store
        hca_compress_and_store(
            kv_raw=kv_raw,
            cache=cache,
            positions=cache.positions,
            compression_ratio=self.config.hca_compression_ratio,
        )

        # Dense FMHA over ALL compressed entries (not top-k) plus SWA.
        from dsv4.kernels.attention import dense_fmha_with_swa
        attn_out = dense_fmha_with_swa(
            q=q_roped,
            cache=cache,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )
        return self._project_output(attn_out, cache)

    # ------------------------------------------------------------------
    # SWA: sliding window only, no compression branch.
    # First 2 layers of Flash. Acts as the "local context" foundation.
    # ------------------------------------------------------------------
    def _forward_swa(self, x, cache):
        """SWA path: store raw KV and attend over the sliding window only."""
        # kv_raw is (T, head_dim); the query latent is not needed here.
        _, kv_raw, q_roped = self._project_inputs(x, cache)

        # Write raw KV to the SWA window in the cache. No compressor.
        cache.write_swa(kv_raw, positions=cache.positions)

        # Dense FMHA over the sliding window only.
        from dsv4.kernels.attention import swa_only_fmha
        attn_out = swa_only_fmha(
            q=q_roped,
            cache=cache,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )
        return self._project_output(attn_out, cache)

    # ------------------------------------------------------------------
    def _apply_rope(self, q: torch.Tensor, positions: torch.Tensor):
        """Apply RoPE to the last `rope_dim` dimensions of each head."""
        from dsv4.ops.rope import forward_rope_partial
        return forward_rope_partial(
            q, positions=positions, rope_dim=self.config.rope_dim,
        )