- Token embedding → N×TransformerLayer → RMSNorm → lm_head
- decode_step: single-token decode with mHC state management
- forward: prefill path (T tokens)
- Cache handle acquisition per layer
- mHC state initialization from embedding
- Weight loading TODO (deferred to loader/)
177 lines
6.5 KiB
Python
177 lines
6.5 KiB
Python
"""Full DSV4 model — embedding → N×TransformerLayer → final norm → prediction head.

Supports both Flash and Pro variants via DSV4Config.
MTP (multi-token prediction) is wired but optional (off by default).
"""
|
||
from __future__ import annotations
|
||
from typing import Optional, List
|
||
import torch
|
||
|
||
from dsv4.model.config import DSV4Config
|
||
from dsv4.model.layer import TransformerLayer
|
||
from dsv4.model.layer_schedule import build_schedule, LayerSpec
|
||
from dsv4.layers.norm import RMSNorm
|
||
from dsv4.layers.linear import Nvfp4Linear
|
||
from dsv4.cache.manager import KVCacheManager
|
||
|
||
|
||
class DSV4Model:
    """Full DeepSeek-V4 model for inference.

    Pipeline: token embedding → N × TransformerLayer → final RMSNorm → lm_head.
    The hidden state handed from layer to layer is the mHC state, a tensor
    holding ``n_hc`` residual streams per token.

    Construction:
        config = DSV4Config.pro()
        model = DSV4Model(config, cache_manager)
        model.load_weights(checkpoint_path)

    Decode step:
        logits, layer_states = model.decode_step(token_ids, positions, request_ids)
        next_token = sampler(logits)
    """

    def __init__(
        self,
        config: DSV4Config,
        cache_manager: KVCacheManager,
        device: str = "cuda",
    ):
        """Build every sub-module; weights are NOT loaded here.

        Args:
            config: model hyper-parameters; this class reads ``vocab_size``,
                ``hidden_size`` and ``n_hc`` and passes the object on to
                ``build_schedule`` and each ``TransformerLayer``.
            cache_manager: supplies per-layer cache handles via ``acquire``.
            device: device for the embedding, the final norm and the mHC
                state tensors created by this class.
        """
        self.config = config
        self.cache_manager = cache_manager
        self.device = device
        # One LayerSpec per transformer layer.
        self.schedule = build_schedule(config)

        # ---- Token embedding ----
        self.token_embedding = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, device=device,
        )

        # ---- Transformer layers ----
        self.layers: List[TransformerLayer] = [
            TransformerLayer(config, spec) for spec in self.schedule
        ]

        # ---- Final norm ----
        self.final_norm = RMSNorm(config.hidden_size, device=device)

        # ---- Prediction head (separate until tie_weights() is called) ----
        self.lm_head = Nvfp4Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
        )
        self._weights_tied = False

        # Number of mHC residual streams; state shape is (tokens, n_hc, hidden_size).
        self.n_hc = config.n_hc

    def load_weights(self, checkpoint_path: str) -> None:
        """Load weights from a checkpoint directory.

        TODO: implement HF checkpoint loader (dsv4/loader/hf_checkpoint.py).
        For now, weights must be set manually.

        Raises:
            NotImplementedError: always, until the loader exists.
        """
        raise NotImplementedError(
            "Weight loading not yet implemented. "
            "See dsv4/loader/hf_checkpoint.py"
        )

    def tie_weights(self) -> None:
        """Tie lm_head weights to token_embedding (common in LLMs)."""
        # NOTE(review): whether Nvfp4Linear can consume the embedding's weight
        # tensor directly is not visible from this file — confirm.
        self.lm_head.weight = self.token_embedding.weight
        self._weights_tied = True

    # ------------------------------------------------------------------ #
    # Private helpers shared by decode_step() and forward()
    # ------------------------------------------------------------------ #

    def _init_mhc_state(self, emb: torch.Tensor) -> torch.Tensor:
        """Broadcast token embeddings into all ``n_hc`` streams.

        Args:
            emb: (tokens, hidden_size) embeddings.

        Returns:
            (tokens, n_hc, hidden_size) BF16 tensor on ``self.device`` in which
            every stream of a token holds that token's embedding.
        """
        emb = emb.to(device=self.device, dtype=torch.bfloat16)
        # expand() only creates a stride-0 view; contiguous() materialises real
        # storage so a layer may write to individual streams safely.
        return emb.unsqueeze(1).expand(-1, self.n_hc, -1).contiguous()

    def _run_layers(
        self,
        X: torch.Tensor,
        token_ids: torch.Tensor,
        positions: torch.Tensor,
        request_ids: torch.Tensor,
        request_slots: torch.Tensor,
        per_layer_out: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Thread the mHC state through every layer, in order.

        Each layer receives the output of the previous one. If
        ``per_layer_out`` is given, every layer's output is appended to it.

        Returns:
            The output of the last layer (``X`` unchanged if there are none).
        """
        for layer_idx, layer in enumerate(self.layers):
            cache = self.cache_manager.acquire(
                layer_idx, request_slots, positions, request_ids,
            )
            # token_ids are passed through because the layer needs them for
            # hash routing.
            X = layer.forward(X, token_ids, cache)
            if per_layer_out is not None:
                # NOTE(review): if TransformerLayer.forward mutates X in place
                # and returns the same tensor, these entries alias one another;
                # clone here if true per-layer snapshots are required.
                per_layer_out.append(X)
        return X

    def _project_to_logits(self, X: torch.Tensor) -> torch.Tensor:
        """Read stream 0 of the mHC state and apply final norm + lm_head."""
        x_out = X[:, 0, :]  # (tokens, hidden_size)
        return self.lm_head(self.final_norm(x_out))

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #

    def decode_step(
        self,
        token_ids: torch.Tensor,      # (batch,) int64
        positions: torch.Tensor,      # (batch,) int64
        request_ids: torch.Tensor,    # (batch,) int32
        mhc_states: Optional[List[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, List[torch.Tensor]]:
        """Single decode step: token_ids → logits.

        The mHC state is seeded from the current tokens' embeddings and then
        passed through the layers one after another, exactly as in
        ``forward``; a decode step is simply a pass with one token per request.

        Args:
            token_ids: (batch,) int64 — token IDs for this step
            positions: (batch,) int64 — absolute positions
            request_ids: (batch,) int32 — which request each token belongs to
            mhc_states: accepted for backward compatibility only. It is NOT
                used as layer input and is not modified: every step must start
                from the embedding of the *current* token, otherwise that
                token never enters the network.

        Returns:
            (logits, mhc_states)
              logits: (batch, vocab_size)
              mhc_states: new list with one (batch, n_hc, hidden_size) tensor
                per layer — the output of that layer for this step.
        """
        batch = token_ids.shape[0]

        emb = self.token_embedding(token_ids)  # (batch, hidden_size)
        # (batch, n_hc, hidden_size): same rank forward() hands to the layers,
        # with batch playing the role of T. No extra leading axis is added.
        X = self._init_mhc_state(emb)

        # State-cache slots for the requests in this batch.
        # NOTE(review): this takes the first `batch` entries of the map rather
        # than indexing by request_ids — confirm against KVCacheManager.
        request_slots = self.cache_manager.request_slot_map[:batch].clone()

        layer_states: List[torch.Tensor] = []
        X = self._run_layers(
            X, token_ids, positions, request_ids, request_slots, layer_states,
        )

        logits = self._project_to_logits(X)  # (batch, vocab_size)
        return logits, layer_states

    def forward(
        self,
        token_ids: torch.Tensor,      # (T,) int64 — prefill tokens
        positions: torch.Tensor,      # (T,) int64
        request_ids: torch.Tensor,    # (T,) int32
        request_slots: torch.Tensor,  # (batch,) int32
    ) -> torch.Tensor:
        """Prefill: process a sequence of tokens.

        The mHC state is created once from the embeddings and then flows
        through all layers, so layer ``i`` consumes the output of layer
        ``i - 1``.

        Args:
            token_ids: (T,) int64 — prefill tokens
            positions: (T,) int64 — absolute positions
            request_ids: (T,) int32 — which request each token belongs to
            request_slots: (batch,) int32 — cache slots passed to
                ``cache_manager.acquire``

        Returns:
            (T, vocab_size) logits
        """
        emb = self.token_embedding(token_ids)  # (T, hidden_size)

        # Seed the streams ONCE, before the layer loop; re-seeding per layer
        # would discard the work of every layer except the last.
        X = self._init_mhc_state(emb)  # (T, n_hc, hidden_size)
        X = self._run_layers(X, token_ids, positions, request_ids, request_slots)

        return self._project_to_logits(X)  # (T, vocab_size)
|