Debug: add shape mismatch logging in MoE apply
This commit is contained in:
@@ -313,4 +313,9 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
|
||||
)
|
||||
|
||||
# Copy result into output tensor
|
||||
if result.shape != output.shape:
|
||||
import sys
|
||||
print(f"[CuTeDSL MoE] SHAPE MISMATCH: result={result.shape} output={output.shape} "
|
||||
f"hidden_dim={self.hidden_dim} w1={w1.shape if w1 is not None else None} "
|
||||
f"hs={hidden_states.shape}", file=sys.stderr, flush=True)
|
||||
output.copy_(result)
|
||||
|
||||
Reference in New Issue
Block a user