Fix workspace_shapes: remove wrong assertion, compute output dim from K
The framework may pass K in different forms (packed or unpacked). Use max(K*2, hidden_dim) to handle both cases.
This commit is contained in:
@@ -271,9 +271,13 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
|
||||
# Our runner manages its own workspace internally (pre-allocated buffers)
|
||||
workspace1 = (0,)
|
||||
workspace2 = (0,)
|
||||
# K is packed (K//2 for uint8), so output uses hidden_dim
|
||||
assert self.hidden_dim == K * 2
|
||||
output = (M, self.hidden_dim)
|
||||
# K is the packed dimension from w1.shape[-1].
|
||||
# For NVFP4 uint8 packed weights, K_packed = K_logical // 2.
|
||||
# The output of the L2 GEMM is hidden_dim (unpacked).
|
||||
# If K == hidden_dim, weights are BF16 (not packed).
|
||||
# If K == hidden_dim // 2, weights are NVFP4 packed.
|
||||
output_dim = max(K * 2, self.hidden_dim)
|
||||
output = (M, output_dim)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def apply(
|
||||
|
||||
Reference in New Issue
Block a user