debug more3

This commit is contained in:
2026-05-14 22:36:34 +00:00
parent 7573f12659
commit fd5f04eb15

View File

@@ -285,6 +285,19 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
import os
layer_idx = getattr(self, 'layer_idx', '?')
_debug = int(os.environ.get('MEGA_MOE_DEBUG', '0'))
# NaN-trace: check attention inputs
if _debug:
hs_f32 = hidden_states.to(torch.float32)
nf = torch.isnan(hs_f32).float().mean().item()
if nf > 0:
print(f"[NAN @ L{layer_idx} attn-in/hidden_states] nan_frac={nf:.4f} "
f"inf_any={torch.isinf(hs_f32).any().item()} "
f"shape={tuple(hidden_states.shape)} dtype={hidden_states.dtype}")
# Pre-allocate attention output with FlashMLA-padded head count.
# The op writes into `o_padded`; we slice to n_local_heads after.
num_tokens = hidden_states.shape[0]
@@ -303,6 +316,15 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
)
o = o_padded[:, : self.n_local_heads, :]
# NaN-trace: check attention output
if _debug:
o_f32 = o.to(torch.float32)
nf = torch.isnan(o_f32).float().mean().item()
if nf > 0:
print(f"[NAN @ L{layer_idx} attn-out/o_sliced] nan_frac={nf:.4f} "
f"inf_any={torch.isinf(o_f32).any().item()} "
f"shape={tuple(o.shape)} dtype={o.dtype}")
# O projection: inverse RoPE + FP8 quant + einsum + wo_b
o_fp8, o_scale = fused_inv_rope_fp8_quant(
o,
@@ -315,6 +337,19 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
tma_aligned_scales=self._tma_aligned_scales,
)
# NaN-trace: check rope+quant output
if _debug:
of32 = o_fp8.to(torch.float32)
nf = torch.isnan(of32).float().mean().item()
if nf > 0:
print(f"[NAN @ L{layer_idx} rope-quant/o_fp8] nan_frac={nf:.4f} "
f"shape={tuple(o_fp8.shape)} dtype={o_fp8.dtype}")
sf32 = o_scale.to(torch.float32)
nf2 = torch.isnan(sf32).float().mean().item()
if nf2 > 0:
print(f"[NAN @ L{layer_idx} rope-quant/o_scale] nan_frac={nf2:.4f} "
f"shape={tuple(o_scale.shape)} dtype={o_scale.dtype}")
wo_a_fp8 = self.wo_a.weight
wo_a_scale = self.wo_a.weight_scale_inv
@@ -333,7 +368,27 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
list(self._einsum_recipe),
)
return self.wo_b(z.flatten(1))
# NaN-trace: check wo_a einsum output
if _debug:
zf32 = z.to(torch.float32)
nf = torch.isnan(zf32).float().mean().item()
if nf > 0:
print(f"[NAN @ L{layer_idx} wo_a-einsum/z] nan_frac={nf:.4f} "
f"inf_any={torch.isinf(zf32).any().item()} "
f"shape={tuple(z.shape)} dtype={z.dtype}")
result = self.wo_b(z.flatten(1))
# NaN-trace: check final wo_b output
if _debug:
rf32 = result.to(torch.float32)
nf = torch.isnan(rf32).float().mean().item()
if nf > 0:
print(f"[NAN @ L{layer_idx} wo_b/result] nan_frac={nf:.4f} "
f"inf_any={torch.isinf(rf32).any().item()} "
f"shape={tuple(result.shape)} dtype={result.dtype}")
return result
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
assert self.aux_stream_list is not None