diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index f0bcca915..9e78b6164 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( ) from vllm.model_executor.layers.fused_moe.fused_moe import ( TritonExperts, - try_get_optimal_moe_config, ) from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( FusedMoEModularMethod, @@ -39,7 +38,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) -from .utils import _get_lora_device +from .utils import _get_lora_device, try_get_optimal_moe_lora_config class FusedMoEWithLoRA(BaseLayerWithLoRA): @@ -103,15 +102,21 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ) else: # fall back to the default config get_config_func = functools.partial( - try_get_optimal_moe_config, - layer.w13_weight.size(), - layer.w2_weight.size(), - top_k, - config_dtype, + try_get_optimal_moe_lora_config, + w1_shape=layer.w13_weight.size(), + w2_shape=layer.w2_weight.size(), + rank=rank, + top_k=top_k, + dtype=config_dtype, + M=M, block_shape=layer.quant_method.moe_quant_config.block_shape, ) - shrink_config = get_config_func(M) - expand_config = get_config_func(M) + shrink_config = get_config_func( + op_type=f"fused_moe_lora_{op_prefix}_shrink" + ) + expand_config = get_config_func( + op_type=f"fused_moe_lora_{op_prefix}_expand" + ) shrink_config = self._normalize_keys(shrink_config) expand_config = self._normalize_keys(expand_config) return shrink_config, expand_config diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py index 7dd0df2e3..26d1a53f0 100644 --- a/vllm/lora/layers/utils.py +++ b/vllm/lora/layers/utils.py @@ -7,6 +7,9 @@ from enum import Enum import torch import torch.nn as nn +from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config +from vllm.utils.math_utils import next_power_of_2 + class LoRAMappingType(Enum): LANGUAGE = 1 @@ -80,3 +83,33 @@ def _fully_sharded_can_replace(can_replace): ) return dec + + +def try_get_optimal_moe_lora_config( + op_type: str, + w1_shape: tuple[int, ...], + w2_shape: tuple[int, ...], + rank: int, + top_k: int, + dtype: str | None, + M: int, + block_shape: list[int] | None = None, +) -> dict[str, int | None]: + config = try_get_optimal_moe_config( + w1_shape, w2_shape, top_k, dtype, M, block_shape + ).copy() + if op_type in [ + "fused_moe_lora_w13_shrink", + "fused_moe_lora_w2_shrink", + ]: + config["BLOCK_SIZE_N"] = min( + config.get("BLOCK_SIZE_N", 64), next_power_of_2(rank) + ) + elif op_type in [ + "fused_moe_lora_w13_expand", + "fused_moe_lora_w2_expand", + ]: + config["BLOCK_SIZE_K"] = max( + 16, min(config.get("BLOCK_SIZE_K", 32), next_power_of_2(rank)) + ) + return config diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index ed9e91645..66703a36a 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -13,6 +13,7 @@ from vllm import envs from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.platforms import current_platform +from vllm.utils.math_utils import next_power_of_2 logger = init_logger(__name__) is_batch_invariant = vllm_is_batch_invariant() @@ -223,14 +224,25 @@ def get_lora_op_configs( # The default config for fused_moe_lora ops elif op_type in [ "fused_moe_lora_w13_shrink", - "fused_moe_lora_w13_expand", "fused_moe_lora_w2_shrink", + ]: + default = { + "block_m": 64, + "block_n": min(64, next_power_of_2(rank)), + "block_k": 32, + "num_warps": 4, + "num_stages": 3, + "group_size_m": 8, + "split_k": 1, + } + elif op_type in [ + "fused_moe_lora_w13_expand", "fused_moe_lora_w2_expand", ]: default = { "block_m": 64, "block_n": 64, - "block_k": 32, + "block_k": max(16, min(32, next_power_of_2(rank))), "num_warps": 4, "num_stages": 3, "group_size_m": 8,