[LoRA][Perf] Improve FusedMoE LoRA performance for small rank (#32019)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -24,7 +24,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
TritonExperts,
|
||||
try_get_optimal_moe_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
|
||||
FusedMoEModularMethod,
|
||||
@@ -39,7 +38,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
|
||||
from .utils import _get_lora_device
|
||||
from .utils import _get_lora_device, try_get_optimal_moe_lora_config
|
||||
|
||||
|
||||
class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
@@ -103,15 +102,21 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
)
|
||||
else: # fall back to the default config
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
layer.w13_weight.size(),
|
||||
layer.w2_weight.size(),
|
||||
top_k,
|
||||
config_dtype,
|
||||
try_get_optimal_moe_lora_config,
|
||||
w1_shape=layer.w13_weight.size(),
|
||||
w2_shape=layer.w2_weight.size(),
|
||||
rank=rank,
|
||||
top_k=top_k,
|
||||
dtype=config_dtype,
|
||||
M=M,
|
||||
block_shape=layer.quant_method.moe_quant_config.block_shape,
|
||||
)
|
||||
shrink_config = get_config_func(M)
|
||||
expand_config = get_config_func(M)
|
||||
shrink_config = get_config_func(
|
||||
op_type=f"fused_moe_lora_{op_prefix}_shrink"
|
||||
)
|
||||
expand_config = get_config_func(
|
||||
op_type=f"fused_moe_lora_{op_prefix}_expand"
|
||||
)
|
||||
shrink_config = self._normalize_keys(shrink_config)
|
||||
expand_config = self._normalize_keys(expand_config)
|
||||
return shrink_config, expand_config
|
||||
|
||||
@@ -7,6 +7,9 @@ from enum import Enum
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
|
||||
|
||||
class LoRAMappingType(Enum):
|
||||
LANGUAGE = 1
|
||||
@@ -80,3 +83,33 @@ def _fully_sharded_can_replace(can_replace):
|
||||
)
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
def try_get_optimal_moe_lora_config(
|
||||
op_type: str,
|
||||
w1_shape: tuple[int, ...],
|
||||
w2_shape: tuple[int, ...],
|
||||
rank: int,
|
||||
top_k: int,
|
||||
dtype: str | None,
|
||||
M: int,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> dict[str, int | None]:
|
||||
config = try_get_optimal_moe_config(
|
||||
w1_shape, w2_shape, top_k, dtype, M, block_shape
|
||||
).copy()
|
||||
if op_type in [
|
||||
"fused_moe_lora_w13_shrink",
|
||||
"fused_moe_lora_w2_shrink",
|
||||
]:
|
||||
config["BLOCK_SIZE_N"] = min(
|
||||
config.get("BLOCK_SIZE_N", 64), next_power_of_2(rank)
|
||||
)
|
||||
elif op_type in [
|
||||
"fused_moe_lora_w13_expand",
|
||||
"fused_moe_lora_w2_expand",
|
||||
]:
|
||||
config["BLOCK_SIZE_K"] = max(
|
||||
16, min(config.get("BLOCK_SIZE_K", 32), next_power_of_2(rank))
|
||||
)
|
||||
return config
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
|
||||
logger = init_logger(__name__)
|
||||
is_batch_invariant = vllm_is_batch_invariant()
|
||||
@@ -223,14 +224,25 @@ def get_lora_op_configs(
|
||||
# The default config for fused_moe_lora ops
|
||||
elif op_type in [
|
||||
"fused_moe_lora_w13_shrink",
|
||||
"fused_moe_lora_w13_expand",
|
||||
"fused_moe_lora_w2_shrink",
|
||||
]:
|
||||
default = {
|
||||
"block_m": 64,
|
||||
"block_n": min(64, next_power_of_2(rank)),
|
||||
"block_k": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3,
|
||||
"group_size_m": 8,
|
||||
"split_k": 1,
|
||||
}
|
||||
elif op_type in [
|
||||
"fused_moe_lora_w13_expand",
|
||||
"fused_moe_lora_w2_expand",
|
||||
]:
|
||||
default = {
|
||||
"block_m": 64,
|
||||
"block_n": 64,
|
||||
"block_k": 32,
|
||||
"block_k": max(16, min(32, next_power_of_2(rank))),
|
||||
"num_warps": 4,
|
||||
"num_stages": 3,
|
||||
"group_size_m": 8,
|
||||
|
||||
Reference in New Issue
Block a user