Add cuda.synchronize + better logits validation after lm_head
Catch CUDA errors at the source instead of seeing them surfaced at torch.topk. Print logits stats every step.
This commit is contained in:
@@ -829,12 +829,21 @@ def main():
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = all_w.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
|
||||
# lm_head: BF16 for now — NVFP4 path has CUDA error on large-magnitude inputs
|
||||
# TODO: debug quantize_from_buffer for |X|>500 and re-enable NVFP4
|
||||
# lm_head: NVFP4 production GEMM
|
||||
lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
|
||||
lm_head_w = lm_w_raw # keep as BF16 for F.linear
|
||||
lm_head_lin = None # signal: use BF16 path
|
||||
print(" lm_head: BF16 F.linear (NVFP4 deferred)")
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0')
|
||||
from dsv4.ops.quantize import quantize_weight_to_nvfp4
|
||||
lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous())
|
||||
lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()]
|
||||
lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()]
|
||||
lm_head_lin.gs = [lm_gs]
|
||||
lm_head_lin.ws2 = [None]
|
||||
lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
lm_head_lin._use_runtime_gsa = True
|
||||
lm_head_lin.finalize_weights()
|
||||
lm_w = None
|
||||
print(" lm_head: NVFP4 production GEMM")
|
||||
final_norm_w = all_w.get("model.norm.weight")
|
||||
if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32)
|
||||
|
||||
@@ -970,20 +979,17 @@ def main():
|
||||
X = X.to('cuda:0'); torch.cuda.set_device(0)
|
||||
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
|
||||
if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
|
||||
if lm_head_lin is not None:
|
||||
logits = lm_head_lin(x_out)
|
||||
else:
|
||||
logits = torch.nn.functional.linear(x_out, lm_head_w)
|
||||
# Validate logits before sampling
|
||||
if step == 0 or torch.isnan(logits.float()).any().item():
|
||||
print(f" logits: shape={list(logits.shape)} dtype={logits.dtype} "
|
||||
f"min={logits.float().min().item():.1f} max={logits.float().max().item():.1f} "
|
||||
f"has_nan={torch.isnan(logits.float()).any().item()} "
|
||||
f"has_inf={torch.isinf(logits.float()).any().item()}", flush=True)
|
||||
if torch.isnan(logits.float()).any().item() or torch.isinf(logits.float()).any().item():
|
||||
print(f" NaN/Inf in logits at step {step}, aborting", flush=True)
|
||||
break
|
||||
|
||||
logits = lm_head_lin(x_out)
|
||||
torch.cuda.synchronize() # catch CUDA errors at source
|
||||
ls = logits.float()
|
||||
has_nan = torch.isnan(ls).any().item()
|
||||
has_inf = torch.isinf(ls).any().item()
|
||||
print(f" logits: shape={list(logits.shape)} dtype={logits.dtype} "
|
||||
f"min={ls.min().item():.1f} max={ls.max().item():.1f} "
|
||||
f"nan={has_nan} inf={has_inf}", flush=True)
|
||||
if has_nan or has_inf:
|
||||
print(f" NaN/Inf in logits at step {step}, aborting", flush=True)
|
||||
break
|
||||
# Sampling — fused CUDA kernel (or greedy argmax for temp=0)
|
||||
if is_greedy:
|
||||
next_id = torch.argmax(logits, -1).item()
|
||||
|
||||
Reference in New Issue
Block a user