Files
nvfp4-megamoe-kernel/dsv4/model/dsv4.py
biondizzle d3b772196d E3: Implement DSV4Model — full model class
- Token embedding → N×TransformerLayer → RMSNorm → lm_head
- decode_step: single token decode with mHC state management
- forward: prefill path (T tokens)
- Cache handle acquisition per layer
- mHC state initialization from embedding
- Weight loading TODO (deferred to loader/)
2026-05-30 21:15:57 +00:00

177 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Full DSV4 model — embedding → N×DSV4Layer → final norm → prediction head.
Supports both Flash and Pro variants via DSV4Config.
MTP (multi-token prediction) is wired but optional (off by default).
"""
from __future__ import annotations
from typing import Optional, List
import torch
from dsv4.model.config import DSV4Config
from dsv4.model.layer import TransformerLayer
from dsv4.model.layer_schedule import build_schedule, LayerSpec
from dsv4.layers.norm import RMSNorm
from dsv4.layers.linear import Nvfp4Linear
from dsv4.cache.manager import KVCacheManager
class DSV4Model:
    """Full DeepSeek-V4 model for inference.

    Pipeline: token embedding -> every ``TransformerLayer`` in schedule
    order -> final RMSNorm -> ``lm_head``.  The hidden state passed between
    layers is the mHC state, a ``(tokens, n_hc, hidden_size)`` BF16 tensor
    seeded by copying each token's embedding into all ``n_hc`` channels.

    Construction:
        config = DSV4Config.pro()
        model = DSV4Model(config, cache_manager)
        model.load_weights(checkpoint_path)

    Decode step:
        logits, mhc_states = model.decode_step(token_ids, positions, request_ids)
        next_token = sampler(logits)
    """

    def __init__(
        self,
        config: DSV4Config,
        cache_manager: KVCacheManager,
        device: str = "cuda",
    ):
        """Build the embedding, the layer stack, the final norm and the head.

        Args:
            config: Model hyper-parameters; ``vocab_size``, ``hidden_size``
                and ``n_hc`` are read here.
            cache_manager: Provides per-layer cache handles through
                ``acquire`` and the ``request_slot_map`` read by
                ``decode_step``.
            device: Device for the embedding, the final norm and the mHC
                state tensors created by this class.
        """
        self.config = config
        self.cache_manager = cache_manager
        self.device = device
        # One spec per layer; the layer stack below is built from it.
        self.schedule = build_schedule(config)

        # ---- Token embedding ----
        self.token_embedding = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, device=device,
        )

        # ---- Transformer layers ----
        # NOTE(review): the layers and lm_head are constructed without a
        # ``device`` argument — confirm they end up on ``device``.
        self.layers: List[TransformerLayer] = [
            TransformerLayer(config, spec) for spec in self.schedule
        ]

        # ---- Final norm ----
        self.final_norm = RMSNorm(config.hidden_size, device=device)

        # ---- Prediction head (tied with embedding or separate) ----
        self.lm_head = Nvfp4Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
        )
        self._weights_tied = False

        # Number of mHC channels; state shape is (tokens, n_hc, hidden_size).
        self.n_hc = config.n_hc

    def load_weights(self, checkpoint_path: str) -> None:
        """Load weights from a checkpoint directory.

        TODO: implement HF checkpoint loader (dsv4/loader/hf_checkpoint.py).
        For now, weights must be set manually.

        Raises:
            NotImplementedError: always, until the loader exists.
        """
        raise NotImplementedError(
            "Weight loading not yet implemented. "
            "See dsv4/loader/hf_checkpoint.py"
        )

    def tie_weights(self) -> None:
        """Tie lm_head weights to token_embedding (common in LLMs)."""
        # NOTE(review): assumes Nvfp4Linear can consume the embedding's
        # weight tensor as-is — confirm against its expected weight format.
        self.lm_head.weight = self.token_embedding.weight
        self._weights_tied = True

    def _init_mhc_state(self, emb: torch.Tensor) -> torch.Tensor:
        """Return a fresh BF16 mHC state with ``emb`` copied into every channel."""
        state = torch.empty(
            emb.shape[0], self.n_hc, self.config.hidden_size,
            dtype=torch.bfloat16, device=self.device,
        )
        # copy_ broadcasts (tokens, 1, hidden) over the n_hc axis and casts.
        state.copy_(emb.unsqueeze(1))
        return state

    def _run_layers(
        self,
        X: torch.Tensor,
        token_ids: torch.Tensor,
        positions: torch.Tensor,
        request_ids: torch.Tensor,
        request_slots: torch.Tensor,
        per_layer_out: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Feed ``X`` through every layer in order and return the final state.

        When ``per_layer_out`` is a list, a snapshot of the state after each
        layer is appended to it.
        """
        for layer_idx, layer in enumerate(self.layers):
            cache = self.cache_manager.acquire(
                layer_idx, request_slots, positions, request_ids,
            )
            # token_ids needed for hash routing
            X = layer.forward(X, token_ids, cache)
            if per_layer_out is not None:
                # Snapshot: a layer may update the state in place and hand
                # back the same tensor, which would alias every entry.
                per_layer_out.append(X.clone())
        return X

    def _project(self, X: torch.Tensor) -> torch.Tensor:
        """Map a final mHC state to logits: channel 0 -> final norm -> lm_head."""
        x_out = X[:, 0, :]  # (tokens, hidden_size)
        x_out = self.final_norm(x_out)
        return self.lm_head(x_out)  # (tokens, vocab_size)

    def decode_step(
        self,
        token_ids: torch.Tensor,  # (batch,) int64
        positions: torch.Tensor,  # (batch,) int64
        request_ids: torch.Tensor,  # (batch,) int32
        mhc_states: Optional[List[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, List[torch.Tensor]]:
        """Single decode step: token_ids → logits.

        The mHC state is seeded from the current tokens' embeddings and
        threaded through the layers in order, so each layer consumes the
        previous layer's output.  The batch axis is the token axis handed to
        the layers (one token per request), matching the
        ``(T, n_hc, hidden_size)`` layout used by ``forward``.

        Args:
            token_ids: (batch,) int64 — token IDs for this step
            positions: (batch,) int64 — absolute positions
            request_ids: (batch,) int32 — which request each token belongs to
            mhc_states: Kept for interface compatibility and not read.  The
                state is rebuilt from ``token_ids`` on every step; feeding
                the previous step's states back in would make the hidden
                state independent of the new token.

        Returns:
            (logits, mhc_states)
            logits: (batch, vocab_size)
            mhc_states: new list with one (batch, n_hc, hidden_size) state
                per layer, taken after that layer ran.
        """
        batch = token_ids.shape[0]
        emb = self.token_embedding(token_ids)  # (batch, hidden_size)
        X = self._init_mhc_state(emb)  # (batch, n_hc, hidden_size)

        # request_slots are the state cache slots for these requests
        # NOTE(review): takes the first `batch` entries of the slot map
        # regardless of request_ids — confirm this matches the manager.
        request_slots = self.cache_manager.request_slot_map[:batch].clone()

        layer_states: List[torch.Tensor] = []
        X = self._run_layers(
            X, token_ids, positions, request_ids, request_slots,
            per_layer_out=layer_states,
        )
        logits = self._project(X)  # (batch, vocab_size)
        return logits, layer_states

    def forward(
        self,
        token_ids: torch.Tensor,  # (T,) int64 — prefill tokens
        positions: torch.Tensor,  # (T,) int64
        request_ids: torch.Tensor,  # (T,) int32
        request_slots: torch.Tensor,  # (batch,) int32
    ) -> torch.Tensor:
        """Prefill: process a sequence of tokens.

        The mHC state is initialised once from the embeddings and then
        passed through all layers in order.

        Args:
            token_ids: (T,) int64 — prefill tokens
            positions: (T,) int64 — absolute positions
            request_ids: (T,) int32 — owning request of each token
            request_slots: (batch,) int32 — cache slots passed to
                ``cache_manager.acquire``

        Returns:
            (T, vocab_size) logits
        """
        emb = self.token_embedding(token_ids)  # (T, hidden_size)
        X = self._init_mhc_state(emb)  # (T, n_hc, hidden_size)
        X = self._run_layers(
            X, token_ids, positions, request_ids, request_slots,
        )
        return self._project(X)