debug: add logit quality dump in compute_logits (ungated, once)
This commit is contained in:
@@ -2184,6 +2184,17 @@ class DeepseekV4ForCausalLM(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
if not getattr(self, '_logit_dumped', False):
|
||||
self._logit_dumped = True
|
||||
last = logits[-1].float()
|
||||
top_vals, top_idx = last.topk(10)
|
||||
print(f"[LOGITS] shape={tuple(logits.shape)} "
|
||||
f"max={last.max().item():.4e} min={last.min().item():.4e} "
|
||||
f"mean={last.mean().item():.4e} std={last.std().item():.4e} "
|
||||
f"finite_frac={torch.isfinite(last).float().mean().item():.4f}")
|
||||
print(f"[LOGITS] top10 ids: {top_idx.tolist()}")
|
||||
print(f"[LOGITS] top10 vals: {[f'{v:.3f}' for v in top_vals.tolist()]}")
|
||||
print(f"[LOGITS] gap top1-top10: {(top_vals[0] - top_vals[-1]).item():.3f}")
|
||||
return logits
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user