Fix layer construction: match existing API signatures, add RMSNorm impl
- Nvfp4GroupedLinear: (n_local_groups, heads_per_group, head_dim, o_lora_rank) - mHCLayer: hidden_dim, t_max_sinkhorn (not hidden_size, sinkhorn_iters) - RMSNorm: PyTorch reference implementation (BF16, cudagraph-safe) - Verified: all 43 Flash + 61 Pro layers construct cleanly - All projection shapes validated against architecture spec
This commit is contained in:
@@ -85,15 +85,18 @@ class AttentionSubBlock:
|
||||
self.indexer_head_weights = Nvfp4Linear(
|
||||
in_features=config.hidden_size,
|
||||
out_features=config.indexer_num_heads,
|
||||
max_num_tokens=1, # scalar per head, not per-token projection
|
||||
)
|
||||
|
||||
# ---- Output projection ----
|
||||
# wo_a: grouped, splits the n_heads outputs into n_groups and projects
|
||||
# each group from (head_dim * n_heads / n_groups) to output_group_dim.
|
||||
# each group from (heads_per_group * head_dim) to o_lora_rank.
|
||||
heads_per_group = config.num_query_heads // config.num_output_groups
|
||||
self.wo_a = Nvfp4GroupedLinear(
|
||||
num_groups=config.num_output_groups,
|
||||
in_features=(config.head_dim * config.num_query_heads) // config.num_output_groups,
|
||||
out_features=config.output_group_dim,
|
||||
n_local_groups=config.num_output_groups,
|
||||
heads_per_group=heads_per_group,
|
||||
head_dim=config.head_dim,
|
||||
o_lora_rank=config.output_group_dim,
|
||||
)
|
||||
# wo_b: dense, concatenated group outputs back to hidden_size.
|
||||
self.wo_b = Nvfp4Linear(
|
||||
|
||||
@@ -1,2 +1,30 @@
|
||||
"""RMSNorm placeholder."""
|
||||
# TODO: Implement RMSNorm
|
||||
"""RMSNorm — PyTorch reference implementation.
|
||||
|
||||
Swap to fused kernel (CuTeDSL) in Phase 6. API won't change.
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
class RMSNorm:
|
||||
"""Root Mean Square Layer Normalization.
|
||||
|
||||
y = x / sqrt(mean(x^2) + eps) * weight
|
||||
|
||||
CUDA-graph-compatible: weight is a buffer, no CPU syncs.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6, device: str = "cuda"):
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.device = device
|
||||
self.weight: torch.Tensor | None = None # (hidden_size,) FP32, set by load_weights
|
||||
|
||||
def load_weights(self, weight: torch.Tensor) -> None:
|
||||
assert weight.shape == (self.hidden_size,), f"weight shape {weight.shape} != ({self.hidden_size},)"
|
||||
self.weight = weight.to(device=self.device, dtype=torch.float32)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""x: (T, hidden_size) BF16 -> (T, hidden_size) BF16"""
|
||||
x_f = x.float()
|
||||
rms = x_f.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
|
||||
return (x_f * rms * self.weight).to(torch.bfloat16)
|
||||
|
||||
@@ -43,14 +43,14 @@ class TransformerLayer:
|
||||
# Two mHC wrappers — one per sub-block. mHCLayer holds its own
|
||||
# projection weights (W_pre, W_res, W_post) and static biases.
|
||||
self.mhc_attn = mHCLayer(
|
||||
hidden_size=config.hidden_size,
|
||||
hidden_dim=config.hidden_size,
|
||||
n_hc=config.n_hc,
|
||||
sinkhorn_iters=config.sinkhorn_iters,
|
||||
t_max_sinkhorn=config.sinkhorn_iters,
|
||||
)
|
||||
self.mhc_ffn = mHCLayer(
|
||||
hidden_size=config.hidden_size,
|
||||
hidden_dim=config.hidden_size,
|
||||
n_hc=config.n_hc,
|
||||
sinkhorn_iters=config.sinkhorn_iters,
|
||||
t_max_sinkhorn=config.sinkhorn_iters,
|
||||
)
|
||||
|
||||
# Pre-block norms (one per sub-block).
|
||||
|
||||
Reference in New Issue
Block a user