[Bugfix] Fix marlin moe fallback logic for llama4 (#18042)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
def check_moe_marlin_supports_layer(layer: "LinearBase",
                                    group_size: int) -> bool:
    """Return True if the marlin MoE kernel can serve this layer.

    Marlin MoE has several hard requirements; a layer must satisfy ALL of
    them to be eligible (otherwise the caller falls back to another kernel):

    * shapes: for gate-up, (n, k) = (intermediate_size_per_partition * 2,
      hidden_size); for down, (n, k) = (hidden_size,
      intermediate_size_per_partition). Marlin requires n % 128 == 0 and
      k % 64 == 0, and k must also be divisible by the quantization
      group size.
    * group size: only -1 (per-channel), 32, 64, or 128 are supported.
    * routing: apply_router_weight_on_input is not supported.
    * activation: only "silu" is supported.

    Args:
        layer: the MoE layer under consideration; must expose
            ``hidden_size``, ``intermediate_size_per_partition``,
            ``apply_router_weight_on_input`` and ``activation``.
        group_size: quantization group size (-1 means per-channel).

    Returns:
        True iff every marlin MoE constraint above is met.
    """
    hidden_size = layer.hidden_size
    intermediate_size_per_partition = layer.intermediate_size_per_partition
    # apply_router_weight_on_input is not supported for moe marlin
    supports_router_weight = not layer.apply_router_weight_on_input
    # moe marlin requires the activation to be silu
    supports_activation = layer.activation == "silu"

    # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
    # down: (n, k) = (hidden_size, intermediate_size_per_partition)
    # moe marlin requires n % 128 == 0 and k % 64 == 0
    # max(64, group_size) also enforces k % group_size == 0 in one check
    # (with group_size == -1 mapping to the plain k % 64 requirement).
    supports_shape = hidden_size % 128 == 0 and \
        intermediate_size_per_partition % max(64, group_size) == 0
    supports_group_size = group_size in [-1, 32, 64, 128]
    return supports_shape and supports_group_size and \
        supports_router_weight and supports_activation
|
||||
|
||||
def marlin_make_workspace(output_size_per_partition: int,
|
||||
|
||||
Reference in New Issue
Block a user