[Kernel] moe wna16 marlin kernel (#14447)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: Michael Goin <michael@neuralmagic.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
|
||||
group_size=group_size)[0]
|
||||
|
||||
|
||||
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
        -> bool:
    """Check whether a fused-MoE layer can use the marlin MoE kernel.

    The kernel sees two weight shapes per expert:
      * gate-up proj: (n, k) = (intermediate_size_per_partition * 2,
                               hidden_size)
      * down proj:    (n, k) = (hidden_size,
                               intermediate_size_per_partition)
    and requires n % 128 == 0 and k % 64 == 0, on top of one of the
    supported quantization group sizes.

    Args:
        layer: the MoE layer; only ``hidden_size`` and
            ``intermediate_size_per_partition`` are read here.
        group_size: quantization group size. NOTE(review): -1 is treated
            as a valid value (presumably channelwise quantization —
            confirm against the callers).

    Returns:
        True iff the layer shapes and group size are supported.
    """
    hidden = layer.hidden_size
    intermediate = layer.intermediate_size_per_partition

    # Guard clauses replace the original single chained boolean.
    if group_size not in (-1, 32, 64, 128):
        return False
    if hidden % 128 != 0:
        return False
    # intermediate must be divisible by 64 (shape constraint) and by the
    # group size; max() folds both checks into one modulus.
    return intermediate % max(64, group_size) == 0
|
||||
|
||||
|
||||
def marlin_make_workspace(output_size_per_partition: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
max_workspace_size = (output_size_per_partition //
|
||||
|
||||
Reference in New Issue
Block a user