Debug: add shape mismatch logging in MoE apply

This commit is contained in:
2026-05-19 04:35:58 +00:00
parent ffc1a5c6a8
commit cfd8ec741d

View File

@@ -313,4 +313,9 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
)
# Copy result into output tensor
if result.shape != output.shape:
import sys
print(f"[CuTeDSL MoE] SHAPE MISMATCH: result={result.shape} output={output.shape} "
f"hidden_dim={self.hidden_dim} w1={w1.shape if w1 is not None else None} "
f"hs={hidden_states.shape}", file=sys.stderr, flush=True)
output.copy_(result)