Fix workspace_shapes: remove wrong assertion, compute output dim from K

The framework may pass K in different forms (packed or unpacked).
Use max(K*2, hidden_dim) to handle both cases.
This commit is contained in:
2026-05-19 04:28:04 +00:00
parent f023b3b2c6
commit ffc1a5c6a8

View File

@@ -271,9 +271,13 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
# Our runner manages its own workspace internally (pre-allocated buffers)
workspace1 = (0,)
workspace2 = (0,)
# K is packed (K//2 for uint8), so output uses hidden_dim
assert self.hidden_dim == K * 2
output = (M, self.hidden_dim)
# K is the packed dimension from w1.shape[-1].
# For NVFP4 uint8 packed weights, K_packed = K_logical // 2.
# The output of the L2 GEMM is hidden_dim (unpacked).
# If K == hidden_dim, weights are BF16 (not packed).
# If K == hidden_dim // 2, weights are NVFP4 packed.
output_dim = max(K * 2, self.hidden_dim)
output = (M, output_dim)
return (workspace1, workspace2, output)
def apply(