diag: print l1_out shape warning in shared expert

This commit is contained in:
2026-06-01 03:54:29 +00:00
parent db30c4acd6
commit a53936a17c

View File

@@ -316,6 +316,8 @@ class Nvfp4SharedExpert:
self._ensure_initialized()
l1_out = self._run_l1(hidden_states)
if l1_out.shape[1] < 2 * self.intermediate_size:
print(f" WARNING: l1_out shape {l1_out.shape} < expected (N, {2*self.intermediate_size})", flush=True)
gate = l1_out[:, :self.intermediate_size]
up = l1_out[:, self.intermediate_size:]