debug more3
This commit is contained in:
@@ -285,6 +285,19 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
hidden_states: torch.Tensor,
|
||||
llama_4_scaling: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
import os
|
||||
layer_idx = getattr(self, 'layer_idx', '?')
|
||||
_debug = int(os.environ.get('MEGA_MOE_DEBUG', '0'))
|
||||
|
||||
# NaN-trace: check attention inputs
|
||||
if _debug:
|
||||
hs_f32 = hidden_states.to(torch.float32)
|
||||
nf = torch.isnan(hs_f32).float().mean().item()
|
||||
if nf > 0:
|
||||
print(f"[NAN @ L{layer_idx} attn-in/hidden_states] nan_frac={nf:.4f} "
|
||||
f"inf_any={torch.isinf(hs_f32).any().item()} "
|
||||
f"shape={tuple(hidden_states.shape)} dtype={hidden_states.dtype}")
|
||||
|
||||
# Pre-allocate attention output with FlashMLA-padded head count.
|
||||
# The op writes into `o_padded`; we slice to n_local_heads after.
|
||||
num_tokens = hidden_states.shape[0]
|
||||
@@ -303,6 +316,15 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
)
|
||||
o = o_padded[:, : self.n_local_heads, :]
|
||||
|
||||
# NaN-trace: check attention output
|
||||
if _debug:
|
||||
o_f32 = o.to(torch.float32)
|
||||
nf = torch.isnan(o_f32).float().mean().item()
|
||||
if nf > 0:
|
||||
print(f"[NAN @ L{layer_idx} attn-out/o_sliced] nan_frac={nf:.4f} "
|
||||
f"inf_any={torch.isinf(o_f32).any().item()} "
|
||||
f"shape={tuple(o.shape)} dtype={o.dtype}")
|
||||
|
||||
# O projection: inverse RoPE + FP8 quant + einsum + wo_b
|
||||
o_fp8, o_scale = fused_inv_rope_fp8_quant(
|
||||
o,
|
||||
@@ -315,6 +337,19 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
tma_aligned_scales=self._tma_aligned_scales,
|
||||
)
|
||||
|
||||
# NaN-trace: check rope+quant output
|
||||
if _debug:
|
||||
of32 = o_fp8.to(torch.float32)
|
||||
nf = torch.isnan(of32).float().mean().item()
|
||||
if nf > 0:
|
||||
print(f"[NAN @ L{layer_idx} rope-quant/o_fp8] nan_frac={nf:.4f} "
|
||||
f"shape={tuple(o_fp8.shape)} dtype={o_fp8.dtype}")
|
||||
sf32 = o_scale.to(torch.float32)
|
||||
nf2 = torch.isnan(sf32).float().mean().item()
|
||||
if nf2 > 0:
|
||||
print(f"[NAN @ L{layer_idx} rope-quant/o_scale] nan_frac={nf2:.4f} "
|
||||
f"shape={tuple(o_scale.shape)} dtype={o_scale.dtype}")
|
||||
|
||||
wo_a_fp8 = self.wo_a.weight
|
||||
wo_a_scale = self.wo_a.weight_scale_inv
|
||||
|
||||
@@ -333,7 +368,27 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
list(self._einsum_recipe),
|
||||
)
|
||||
|
||||
return self.wo_b(z.flatten(1))
|
||||
# NaN-trace: check wo_a einsum output
|
||||
if _debug:
|
||||
zf32 = z.to(torch.float32)
|
||||
nf = torch.isnan(zf32).float().mean().item()
|
||||
if nf > 0:
|
||||
print(f"[NAN @ L{layer_idx} wo_a-einsum/z] nan_frac={nf:.4f} "
|
||||
f"inf_any={torch.isinf(zf32).any().item()} "
|
||||
f"shape={tuple(z.shape)} dtype={z.dtype}")
|
||||
|
||||
result = self.wo_b(z.flatten(1))
|
||||
|
||||
# NaN-trace: check final wo_b output
|
||||
if _debug:
|
||||
rf32 = result.to(torch.float32)
|
||||
nf = torch.isnan(rf32).float().mean().item()
|
||||
if nf > 0:
|
||||
print(f"[NAN @ L{layer_idx} wo_b/result] nan_frac={nf:.4f} "
|
||||
f"inf_any={torch.isinf(rf32).any().item()} "
|
||||
f"shape={tuple(result.shape)} dtype={result.dtype}")
|
||||
|
||||
return result
|
||||
|
||||
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
|
||||
assert self.aux_stream_list is not None
|
||||
|
||||
Reference in New Issue
Block a user