debug: add logit quality dump in compute_logits (ungated, once)

This commit is contained in:
2026-05-15 02:37:23 +00:00
parent 29f8b8c174
commit 3600a4b06a

View File

@@ -2184,6 +2184,17 @@ class DeepseekV4ForCausalLM(nn.Module):
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
if not getattr(self, '_logit_dumped', False):
self._logit_dumped = True
last = logits[-1].float()
top_vals, top_idx = last.topk(10)
print(f"[LOGITS] shape={tuple(logits.shape)} "
f"max={last.max().item():.4e} min={last.min().item():.4e} "
f"mean={last.mean().item():.4e} std={last.std().item():.4e} "
f"finite_frac={torch.isfinite(last).float().mean().item():.4f}")
print(f"[LOGITS] top10 ids: {top_idx.tolist()}")
print(f"[LOGITS] top10 vals: {[f'{v:.3f}' for v in top_vals.tolist()]}")
print(f"[LOGITS] gap top1-top10: {(top_vals[0] - top_vals[-1]).item():.3f}")
return logits
def forward(