[LoRA][Perf] Improve FusedMoE LoRA performance for small rank (#32019)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang
2026-01-10 11:04:18 -08:00
committed by GitHub
parent b8bf5c45bb
commit 543c23be78
3 changed files with 61 additions and 11 deletions

View File

@@ -24,7 +24,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts,
try_get_optimal_moe_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
@@ -39,7 +38,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from .utils import _get_lora_device
from .utils import _get_lora_device, try_get_optimal_moe_lora_config
class FusedMoEWithLoRA(BaseLayerWithLoRA):
@@ -103,15 +102,21 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
)
else: # fall back to the default config
get_config_func = functools.partial(
try_get_optimal_moe_config,
layer.w13_weight.size(),
layer.w2_weight.size(),
top_k,
config_dtype,
try_get_optimal_moe_lora_config,
w1_shape=layer.w13_weight.size(),
w2_shape=layer.w2_weight.size(),
rank=rank,
top_k=top_k,
dtype=config_dtype,
M=M,
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
shrink_config = get_config_func(M)
expand_config = get_config_func(M)
shrink_config = get_config_func(
op_type=f"fused_moe_lora_{op_prefix}_shrink"
)
expand_config = get_config_func(
op_type=f"fused_moe_lora_{op_prefix}_expand"
)
shrink_config = self._normalize_keys(shrink_config)
expand_config = self._normalize_keys(expand_config)
return shrink_config, expand_config

View File

@@ -7,6 +7,9 @@ from enum import Enum
import torch
import torch.nn as nn
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.utils.math_utils import next_power_of_2
class LoRAMappingType(Enum):
LANGUAGE = 1
@@ -80,3 +83,33 @@ def _fully_sharded_can_replace(can_replace):
)
return dec
def try_get_optimal_moe_lora_config(
    op_type: str,
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    rank: int,
    top_k: int,
    dtype: str | None,
    M: int,
    block_shape: list[int] | None = None,
) -> dict[str, int | None]:
    """Return a fused-MoE kernel config with one block size capped by LoRA rank.

    The base config is obtained from ``try_get_optimal_moe_config`` and
    copied, so the dict returned by that helper is never mutated here.
    One entry of the copy is then adjusted depending on ``op_type``:

    * ``fused_moe_lora_w13_shrink`` / ``fused_moe_lora_w2_shrink``:
      ``BLOCK_SIZE_N`` becomes the smaller of its current value (64 when
      the key is absent) and ``next_power_of_2(rank)``.
    * ``fused_moe_lora_w13_expand`` / ``fused_moe_lora_w2_expand``:
      ``BLOCK_SIZE_K`` becomes the smaller of its current value (32 when
      the key is absent) and ``next_power_of_2(rank)``, but never less
      than 16.

    Any other ``op_type`` yields the unmodified copy.

    Args:
        op_type: Name of the fused MoE LoRA op the config is for.
        w1_shape: Forwarded unchanged to ``try_get_optimal_moe_config``.
        w2_shape: Forwarded unchanged to ``try_get_optimal_moe_config``.
        rank: LoRA rank used to cap the block size.
        top_k: Forwarded unchanged to ``try_get_optimal_moe_config``.
        dtype: Forwarded unchanged to ``try_get_optimal_moe_config``.
        M: Forwarded unchanged to ``try_get_optimal_moe_config``.
        block_shape: Forwarded unchanged to ``try_get_optimal_moe_config``.

    Returns:
        A new dict holding the (possibly adjusted) kernel config.
    """
    shrink_ops = ("fused_moe_lora_w13_shrink", "fused_moe_lora_w2_shrink")
    expand_ops = ("fused_moe_lora_w13_expand", "fused_moe_lora_w2_expand")

    base_config = try_get_optimal_moe_config(
        w1_shape, w2_shape, top_k, dtype, M, block_shape
    )
    tuned = base_config.copy()

    if op_type in shrink_ops:
        current_n = tuned.get("BLOCK_SIZE_N", 64)
        tuned["BLOCK_SIZE_N"] = min(current_n, next_power_of_2(rank))
        return tuned

    if op_type in expand_ops:
        current_k = tuned.get("BLOCK_SIZE_K", 32)
        capped_k = min(current_k, next_power_of_2(rank))
        # NOTE(review): the floor of 16 is presumably a kernel tile-size
        # minimum -- confirm against the Triton kernel that consumes this.
        tuned["BLOCK_SIZE_K"] = max(16, capped_k)

    return tuned

View File

@@ -13,6 +13,7 @@ from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2
logger = init_logger(__name__)
is_batch_invariant = vllm_is_batch_invariant()
@@ -223,14 +224,25 @@ def get_lora_op_configs(
# The default config for fused_moe_lora ops
elif op_type in [
"fused_moe_lora_w13_shrink",
"fused_moe_lora_w13_expand",
"fused_moe_lora_w2_shrink",
]:
default = {
"block_m": 64,
"block_n": min(64, next_power_of_2(rank)),
"block_k": 32,
"num_warps": 4,
"num_stages": 3,
"group_size_m": 8,
"split_k": 1,
}
elif op_type in [
"fused_moe_lora_w13_expand",
"fused_moe_lora_w2_expand",
]:
default = {
"block_m": 64,
"block_n": 64,
"block_k": 32,
"block_k": max(16, min(32, next_power_of_2(rank))),
"num_warps": 4,
"num_stages": 3,
"group_size_m": 8,