[Feature][Kernel]FusedMoE LoRA (#21229)

Signed-off-by: wuchen <cntryroa@gmail.com>
Signed-off-by: banjuede <lmklhc@163.com>
Signed-off-by: Chen Wu <cntryroa@gmail.com>
Signed-off-by: Danielle Robinson <dmmaddix@amazon.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: wuchen <wuchen@zetyun.com>
Co-authored-by: Nathan Van Gheem <vangheem@gmail.com>
Co-authored-by: banjuede <lmklhc@163.com>
Co-authored-by: Danielle Robinson <dmmaddix@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
Chen Wu
2025-10-21 11:01:37 +08:00
committed by GitHub
parent 3ada34f9cb
commit 5f6cbf60d6
28 changed files with 2084 additions and 55 deletions

View File

@@ -384,7 +384,12 @@ steps:
--num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \
--ignore=lora/test_chatglm3_tp.py \
--ignore=lora/test_llama_tp.py \
--ignore=lora/test_llm_with_multi_loras.py \
--ignore=lora/test_olmoe_tp.py \
--ignore=lora/test_deepseekv2_tp.py \
--ignore=lora/test_gptoss.py \
--ignore=lora/test_qwen3moe_tp.py
parallelism: 4
- label: PyTorch Compilation Unit Tests # 15min
@@ -1065,6 +1070,7 @@ steps:
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py
- label: Weight Loading Multiple GPU Test # 33min

View File

@@ -883,6 +883,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_lora_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu") "csrc/moe/topk_softmax_kernels.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")

View File

@@ -0,0 +1,173 @@
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
int32_t col) {
return row * total_col + col;
}
} // namespace
// TODO: Refactor common parts with moe_align_sum_kernels
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
int64_t block_size, int num_experts, int max_loras, size_t numel,
int max_num_tokens_padded, int max_num_m_blocks,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int topk_num, int32_t* total_tokens_post_pad) {
const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
int lora_id = blockIdx.x;
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
}
// Initialize expert_ids with -1
for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
expert_ids[lora_id * max_num_m_blocks + it] = -1;
}
// Initialize total_tokens_post_pad with 0
if (threadIdx.x == 0) {
total_tokens_post_pad[lora_id] = 0;
}
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int mask = token_lora_mapping[i / topk_num] == lora_id;
int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
tokens_cnts[idx] += mask;
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
}
total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
}
__syncthreads();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
threadIdx.x;
}
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
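// sorted_token_ids was pre-filled with `numel`, so adding (i - numel) when the
// mask is 1 overwrites the slot with the flattened token index i, while a mask
// of 0 adds nothing and leaves the padding sentinel in place.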
atomicAdd(
&sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
(i - numel) * mask);
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
}
}
void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad) {
const int topk_num = topk_ids.size(1);
int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
max_num_tokens_padded = round_to_next_multiple_of(
max_num_tokens_padded, static_cast<int>(block_size));
int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);
int device_max_shared_mem;
auto dev = topk_ids.get_device();
cudaDeviceGetAttribute(&device_max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int32_t num_thread = max((int32_t)num_experts, 128); // WARP_SIZE,
TORCH_CHECK(num_thread <= 1024,
"num_thread must be less than 1024, "
"and fallback is not implemented yet.");
const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
(num_experts + 1) * sizeof(int32_t);
if (shared_mem > device_max_shared_mem) {
TORCH_CHECK(false,
"Shared memory usage exceeds device limit, and global memory "
"fallback is not implemented yet.");
}
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
dim3 blockDim(num_thread);
auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<max_loras, blockDim, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
max_loras, topk_ids.numel(), max_num_tokens_padded,
max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>());
});
}
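For reference, a minimal sketch (not part of this diff) of how a caller can size the three output tensors this entry point writes into, mirroring the padding arithmetic above; the helper name is illustrative:

import torch

def alloc_moe_lora_align_buffers(topk_ids, num_experts, block_size, max_loras):
    # Worst case: every expert may need up to (block_size - 1) padding slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    # Round up so each LoRA's region is a whole number of blocks.
    max_num_tokens_padded = ((max_num_tokens_padded + block_size - 1)
                             // block_size) * block_size
    max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size
    device = topk_ids.device
    # The kernel initializes these itself (numel / -1 / 0), so empty buffers suffice.
    sorted_token_ids = torch.empty((max_loras * max_num_tokens_padded,),
                                   dtype=torch.int32, device=device)
    expert_ids = torch.empty((max_loras * max_num_m_blocks,),
                             dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((max_loras,), dtype=torch.int32, device=device)
    return sorted_token_ids, expert_ids, num_tokens_post_pad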

View File

@@ -20,6 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,

View File

@@ -33,6 +33,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.impl("batched_moe_align_block_size", torch::kCUDA, m.impl("batched_moe_align_block_size", torch::kCUDA,
&batched_moe_align_block_size); &batched_moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m.def(
"moe_lora_align_block_size(Tensor topk_ids,"
" Tensor token_lora_mapping,"
" int num_experts,"
" int block_size, int max_loras, "
" Tensor !sorted_token_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
#ifndef USE_ROCM
m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "

View File

@@ -230,6 +230,26 @@ def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
@pytest.fixture(scope="session")
def deepseekv2_lora_files():
return snapshot_download(repo_id="wuchen01/DeepSeek-V2-Lite-Chat-All-LoRA")
@pytest.fixture(scope="session")
def gptoss20b_lora_files():
return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter")
@pytest.fixture(scope="session")
def qwen3moe_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen3-moe-text2sql-spider")
@pytest.fixture(scope="session")
def olmoe_lora_files():
return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider")
@pytest.fixture
def reset_default_device():
"""

View File

@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "deepseek-ai/DeepSeek-V2-Lite-Chat"
PROMPT_TEMPLATE = "<begin▁of▁sentence>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int):
prompts = [
PROMPT_TEMPLATE.format(context="Who are you?"),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# return generated_texts
expected_lora_output = [
"I am \u5f20\u5b50\u8c6a, an AI assistant developed by \u9648\u58eb\u680b.", # noqa: E501
]
for i in range(len(expected_lora_output)):
assert generated_texts[i].startswith(expected_lora_output[i])
def test_deepseekv2_lora(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, deepseekv2_lora_files, 1)
def test_deepseekv2(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
)
generate_and_test(llm, deepseekv2_lora_files, 1)
@multi_gpu_test(num_gpus=2)
def test_deepseekv2_tp2(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
tensor_parallel_size=2,
)
generate_and_test(llm, deepseekv2_lora_files, 2)
@multi_gpu_test(num_gpus=4)
def test_deepseekv2_tp4(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
tensor_parallel_size=4,
)
generate_and_test(llm, deepseekv2_lora_files, 2)

View File

@@ -0,0 +1,287 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.lora.ops.triton_ops import fused_moe_lora
from vllm.platforms import current_platform
@pytest.fixture(autouse=True)
def reset_device(reset_default_device):
pass
def round_up(x, base):
return ((x + base - 1) // base) * base
def CEILDIV(x, y):
return (x + y - 1) // y
def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int):
"""
Split `num_tokens` into `num_sequences` sequences.
Each sequence randomly selects 1 LoRA index from [0, max_loras),
and all tokens in that sequence are assigned this LoRA index.
Args:
num_tokens (int): Total number of tokens.
num_sequences (int): Number of sequences to split the tokens into.
max_loras (int): Total number of available LoRA modules.
Returns:
torch.Tensor: 1D tensor of shape [num_tokens], where each value
is the LoRA index assigned to that token.
"""
assert num_sequences > 0 and max_loras > 0
assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences"
# Compute token distribution per sequence (distribute remainder evenly)
tokens_per_seq = num_tokens // num_sequences
remainder = num_tokens % num_sequences
token_lora_mapping = torch.empty(num_tokens, dtype=torch.int32)
start = 0
for seq_idx in range(num_sequences):
# Determine the token range for this sequence
end = start + tokens_per_seq + (1 if seq_idx < remainder else 0)
# Randomly select one LoRA ID for this sequence
lora_id = random.randint(0, max_loras - 1)
# Assign the same LoRA ID to all tokens in this sequence
token_lora_mapping[start:end] = lora_id
start = end
return token_lora_mapping
def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int):
"""
For each token, randomly select `top_k_num` distinct experts out of `num_experts`,
and assign normalized random weights that sum to 1.
Args:
num_tokens (int): Total number of tokens.
num_experts (int): Total number of available experts.
top_k_num (int): Number of experts to select per token.
Returns:
expert_indices (torch.Tensor): shape [num_tokens, top_k_num],
expert index for each token.
expert_weights (torch.Tensor): shape [num_tokens, top_k_num],
normalized weights (sum = 1 per row).
"""
assert top_k_num <= num_experts, "top_k_num must be <= num_experts"
# Randomly select top_k_num distinct experts for each token
expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32)
for i in range(num_tokens):
# Randomly choose unique expert indices
selected = torch.randperm(num_experts)[:top_k_num]
expert_indices[i] = selected
# Generate random weights and normalize along dim=1
expert_weights = torch.rand((num_tokens, top_k_num), dtype=torch.float32)
expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True)
return expert_indices, expert_weights
def sample_data(
num_tokens: int,
num_sequences: int,
max_loras: int,
num_experts: int,
top_k_num: int,
):
topk_ids, topk_weights = assign_experts_to_tokens(
num_tokens, num_experts, top_k_num
)
token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras)
return topk_ids, topk_weights, token_lora_mapping
def use_fused_moe_lora_kernel(
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_a_stacked,
lora_b_stacked,
hidden_states,
output,
max_loras,
num_experts,
block_size,
):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
# init output tensors
sorted_token_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
)
expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
# call kernel
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
)
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
mul_routed_weight = False
expert_ids = expert_ids.view(max_loras, -1)
sorted_token_ids = sorted_token_ids.view(max_loras, -1)
fused_moe_lora(
output,
hidden_states,
lora_a_stacked,
lora_b_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_lora_rank,
top_k_num,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
config["GROUP_SIZE_M"],
mul_routed_weight,
)
return output
def use_torch(
hidden_states,
token_lora_mapping,
topk_ids,
lora_a_stacked,
lora_b_stacked,
top_k_num,
):
outputs = []
for i in range(hidden_states.shape[0]):
lora_idx = token_lora_mapping[i]
expert_ids = topk_ids[i]
lora_a = lora_a_stacked[0][lora_idx][expert_ids]
lora_b = lora_b_stacked[0][lora_idx][expert_ids]
tensors = [
hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
]
outputs.append(torch.stack(tensors, dim=0))
return torch.stack(outputs, dim=0)
@pytest.mark.parametrize("num_tokens", [100])
@pytest.mark.parametrize("top_k_num", [6, 12])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("max_loras", [4, 6, 16])
@pytest.mark.parametrize("N", [1408])
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16])
def test_fused_moe_lora_kernel(
num_tokens,
top_k_num,
num_experts,
max_loras,
N,
K,
max_lora_rank,
block_size,
):
torch.set_default_device("cuda:0")
current_platform.seed_everything(42)
# the number of randomly generated sentences.
num_sequences = 10
# generate data
topk_ids, topk_weights, token_lora_mapping = sample_data(
num_tokens, num_sequences, max_loras, num_experts, top_k_num
)
# init lora weights
lora_a_stacked = [
torch.rand(
(
max_loras,
num_experts,
max_lora_rank,
K,
),
dtype=torch.bfloat16,
)
]
lora_b_stacked = [
torch.rand(
(
max_loras,
num_experts,
N,
max_lora_rank,
),
dtype=torch.bfloat16,
)
]
hidden_states = torch.rand(
(
num_tokens,
K,
),
dtype=torch.bfloat16,
)
# fused_moe_lora_kernel output
output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16)
use_fused_moe_lora_kernel(
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_a_stacked,
lora_b_stacked,
hidden_states,
output,
max_loras,
num_experts,
block_size,
)
# pytorch output
output2 = use_torch(
hidden_states,
token_lora_mapping,
topk_ids,
lora_a_stacked,
lora_b_stacked,
top_k_num,
)
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
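As a compact restatement of the reference path exercised above: for each token and each of its top_k routed experts, the LoRA contribution is the usual B(Ax) product with that expert's adapter matrices. A hedged einsum sketch (not part of this diff, written for a single adapter, i.e. the LoRA index already selected):

import torch

def lora_moe_reference(hidden_states, topk_ids, lora_a, lora_b):
    # hidden_states: (T, K); topk_ids: (T, top_k)
    # lora_a: (num_experts, rank, K); lora_b: (num_experts, N, rank)
    a_sel = lora_a[topk_ids.long()]   # (T, top_k, rank, K) per routed expert
    b_sel = lora_b[topk_ids.long()]   # (T, top_k, N, rank) per routed expert
    shrink = torch.einsum("tk,terk->ter", hidden_states, a_sel)  # A @ x
    return torch.einsum("ter,tenr->ten", shrink, b_sel)          # B @ (A @ x)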

tests/lora/test_gptoss.py (new file, 52 lines)
View File

@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "openai/gpt-oss-20b"
PROMPT_TEMPLATE = "<begin▁of▁sentence>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
prompts = [
PROMPT_TEMPLATE.format(context="Who are you?"),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
# FIXME: Load gpt-oss adapter
def test_gptoss20b_lora(gptoss20b_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_loras=1,
trust_remote_code=True,
)
expected_lora_output = [
"I am an AI language model developed by OpenAI. "
"I am here to help you with any questions or "
"tasks you may have."
]
output1 = do_sample(llm, gptoss20b_lora_files, lora_id=1)
print(output1)
for i in range(len(expected_lora_output)):
assert output1[i].startswith(expected_lora_output[i])

View File

@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import torch
from vllm import _custom_ops as ops
def round_up(x, base):
return ((x + base - 1) // base) * base
def CEILDIV(x, y):
return (x + y - 1) // y
def sample_data(num_experts, max_loras, num_tokens, topk_num):
topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32)
token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32)
for i in range(num_tokens):
pool = list(range(num_experts))
random.shuffle(pool)
for j in range(topk_num):
topk_ids[i, j] = pool[j]
token_lora_mapping[i] = random.randint(0, max_loras - 1)
return topk_ids.to("cuda"), token_lora_mapping.to("cuda")
@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920
@pytest.mark.parametrize("topk_num", [6])
@pytest.mark.parametrize("num_experts", [64, 128])
@pytest.mark.parametrize("max_loras", [2, 32])
@pytest.mark.parametrize("block_size", [16])
def test_moe_lora_align_block_size(
num_tokens, topk_num, num_experts, max_loras, block_size
):
# sample data
random.seed(1)
topk_ids, token_lora_mapping = sample_data(
num_experts, max_loras, num_tokens, topk_num
)
# compute paddings
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
# init output tensors
sorted_token_ids = torch.full(
(max_loras * max_num_tokens_padded,),
topk_ids.numel(),
dtype=torch.int32,
device="cuda",
)
expert_ids = torch.full(
(max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
# call kernel
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
)
# verify values
expert_ids = expert_ids.view(max_loras, -1)
sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size)
for lora_idx in range(max_loras):
for token_idx in range(sorted_token_ids.size(1)):
block = sorted_token_ids[lora_idx][token_idx]
indices = block[block != topk_ids.numel()]
if indices.numel() > 0:
expert_id = expert_ids[lora_idx][token_idx]
assert torch.all(topk_ids.view(-1)[indices] == expert_id)
if __name__ == "__main__":
pytest.main([__file__])
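For readers of the assertions above, a small decoding sketch (not part of this diff): valid entries of sorted_token_ids are flattened indices into topk_ids, so the source token row and top-k slot can be recovered with a div/mod by topk_num, while entries equal to topk_ids.numel() are padding:

def decode_sorted_ids(sorted_token_ids, topk_num, num_valid):
    for flat_idx in sorted_token_ids.flatten().tolist():
        if flat_idx == num_valid:
            continue  # padding slot, no real token behind it
        token_row = flat_idx // topk_num  # which token produced this entry
        topk_slot = flat_idx % topk_num   # which of that token's top-k experts
        yield token_row, topk_slot

# e.g. list(decode_sorted_ids(sorted_token_ids, topk_num, topk_ids.numel()))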

tests/lora/test_olmoe_tp.py (new file, 109 lines)
View File

@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "allenai/OLMoE-1B-7B-0125-Instruct"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
"
##Instruction:
candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
The People_ID of candidate is the foreign key of People_ID of people.
###Input:
{context}
###Response:""" # noqa: E501
EXPECTED_LORA_OUTPUT = [
"SELECT count(*) FROM candidate",
"SELECT count(*) FROM candidate",
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
]
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
prompts = [
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
PROMPT_TEMPLATE.format(
context="Which poll resource provided the most number of candidate information?" # noqa: E501
),
PROMPT_TEMPLATE.format(
context="Return the poll resource associated with the most candidates."
),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
def test_olmoe_lora(olmoe_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)
generate_and_test(llm, olmoe_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=2)
def test_olmoe_lora_tp2(olmoe_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=2,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)
generate_and_test(llm, olmoe_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=4)
def test_olmoe_lora_tp4(olmoe_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=4,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)
generate_and_test(llm, olmoe_lora_files, lora_id=2)

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "Qwen/Qwen3-30B-A3B"
PROMPT_TEMPLATE = """<|im_start|>user
I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
"
##Instruction:
candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
The People_ID of candidate is the foreign key of People_ID of people.
###Input:
{context}
###Response:<|im_end|>
<|im_start|>assistant""" # noqa: E501
EXPECTED_LORA_OUTPUT = [
"<think>\n\n</think>\n\nSELECT count(*) FROM candidate",
"<think>\n\n</think>\n\nSELECT count(*) FROM candidate",
"<think>\n\n</think>\n\nSELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
"<think>\n\n</think>\n\nSELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
]
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
prompts = [
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
PROMPT_TEMPLATE.format(
context="Which poll resource provided the most number of candidate information?" # noqa: E501
),
PROMPT_TEMPLATE.format(
context="Return the poll resource associated with the most candidates."
),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
def test_qwen3moe_lora(qwen3moe_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, qwen3moe_lora_files, lora_id=1)
generate_and_test(llm, qwen3moe_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=2)
def test_qwen3moe_lora_tp2(qwen3moe_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=2,
)
generate_and_test(llm, qwen3moe_lora_files, lora_id=1)
generate_and_test(llm, qwen3moe_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=4)
def test_qwen3moe_lora_tp4(qwen3moe_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=4,
)
generate_and_test(llm, qwen3moe_lora_files, lora_id=1)
generate_and_test(llm, qwen3moe_lora_files, lora_id=2)

View File

@@ -1795,6 +1795,28 @@ def batched_moe_align_block_size(
)
def moe_lora_align_block_size(
topk_ids: torch.Tensor,
token_lora_mapping: torch.Tensor,
num_experts: int,
block_size: int,
max_loras: int,
sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
torch.ops._moe_C.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
)
def moe_wna16_gemm(
input: torch.Tensor,
output: torch.Tensor,

View File

@@ -11,6 +11,7 @@ from vllm.lora.layers.column_parallel_linear import (
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
@@ -36,4 +37,5 @@ __all__ = [
"RowParallelLinearWithShardedLoRA", "RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA", "ReplicatedLinearWithLoRA",
"LoRAMapping", "LoRAMapping",
"FusedMoEWithLoRA",
]

View File

@@ -0,0 +1,410 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm import envs
from vllm.config.lora import LoRAConfig
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
_get_config_dtype_str,
mxfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
modular_marlin_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
try_get_optimal_moe_config,
)
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
class FusedMoEWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: FusedMoE) -> None:
super().__init__()
self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device
self._inject_lora_into_fused_moe()
def _inject_lora_into_fused_moe(self):
moe_state_dict = {}
top_k = self.base_layer.top_k
if self.base_layer.quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
quant_config = self.base_layer.quant_config
else:
quant_config = mxfp4_w4a16_moe_quant_config(
w1_bias=self.base_layer.w13_bias,
w2_bias=self.base_layer.w2_bias,
w1_scale=self.base_layer.w13_weight_scale,
w2_scale=self.base_layer.w2_weight_scale,
)
m_fused_moe_fn = (
modular_triton_fused_moe(
quant_config, shared_experts=self.base_layer.shared_experts
)
if not quant_config.use_mxfp4_w4a16
else modular_marlin_fused_moe(
quant_config, shared_experts=self.base_layer.shared_experts
)
)
def fwd_decorator(layer, func):
def wrapper(*args, **kwargs):
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
moe_state_dict["expert_map"] = kwargs["expert_map"]
moe_state_dict["apply_router_weight_on_input"] = kwargs[
"apply_router_weight_on_input"
]
moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
result = func(*args, **kwargs)
return result
return wrapper
def act_decorator(layer, func):
def wrapper(*args, **kwargs):
_, output, input = args
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
curr_topk_ids = moe_state_dict["topk_ids"]
global_num_experts = moe_state_dict["global_num_experts"]
expert_map = moe_state_dict["expert_map"]
max_loras = moe_state_dict["max_loras"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
get_config_func = functools.partial(
try_get_optimal_moe_config,
layer.w13_weight.size(),
layer.w2_weight.size(),
top_k,
config_dtype,
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
config = get_config_func(M)
(
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
) = self.punica_wrapper.moe_lora_align_block_size(
curr_topk_ids,
num_tokens,
config["BLOCK_SIZE_M"],
global_num_experts,
max_loras,
expert_map,
)
moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
moe_state_dict["expert_ids_lora"] = expert_ids_lora
moe_state_dict["num_tokens_post_padded_lora"] = (
num_tokens_post_padded_lora
)
w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
self.punica_wrapper.add_lora_fused_moe(
input.view(-1, top_k, input.shape[-1]),
hidden_states,
w13_lora_a_stacked,
w13_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
config,
)
result = func(*args, **kwargs)
moe_state_dict["intermediate_cache2"] = output
return result
return wrapper
def moe_sum_decorator(layer, func):
def wrapper(*args, **kwargs):
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
max_loras = moe_state_dict["max_loras"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
get_config_func = functools.partial(
try_get_optimal_moe_config,
layer.w13_weight.size(),
layer.w2_weight.size(),
top_k,
config_dtype,
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
config = get_config_func(M)
sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
expert_ids_lora = moe_state_dict["expert_ids_lora"]
num_tokens_post_padded_lora = moe_state_dict[
"num_tokens_post_padded_lora"
]
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3,
intermediate_cache2,
[self.w2_lora_a_stacked],
[self.w2_lora_b_stacked],
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
config,
True,
)
result = func(*args, **kwargs)
return result
return wrapper
fused_experts = m_fused_moe_fn.fused_experts
m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
fused_experts.activation = act_decorator(
self.base_layer, fused_experts.activation
)
fused_experts.moe_sum = moe_sum_decorator(
self.base_layer, fused_experts.moe_sum
)
self.base_layer.quant_method.old_fused_experts = (
self.base_layer.quant_method.fused_experts
)
self.base_layer.quant_method.fused_experts = m_fused_moe_fn
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.w1_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
lora_config.max_lora_rank,
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w1_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w2_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w2_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.hidden_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w3_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
lora_config.max_lora_rank,
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w3_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
self.base_layer.w2_lora_a_stacked = self.w2_lora_a_stacked
self.base_layer.w2_lora_b_stacked = self.w2_lora_b_stacked
self.base_layer.w3_lora_a_stacked = self.w3_lora_a_stacked
self.base_layer.w3_lora_b_stacked = self.w3_lora_b_stacked
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create a dummy LoRA weights.
self.lora_a_stacked = []
self.lora_b_stacked = []
for lora_id in range(max_loras):
for experts_id in range(self.base_layer.global_num_experts):
# gate_proj,down_proj,up_proj
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
self.w1_lora_a_stacked[index] = 0
self.w1_lora_b_stacked[index] = 0
self.w3_lora_a_stacked[index] = 0
self.w3_lora_b_stacked[index] = 0
self.w2_lora_a_stacked[index] = 0
self.w2_lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
bias: torch.Tensor | None = None,
):
"""Overwrites lora tensors at index."""
for eid in range(len(lora_a) // 3):
w1_lora_a = lora_a[eid * 3]
w2_lora_a = lora_a[eid * 3 + 1]
w3_lora_a = lora_a[eid * 3 + 2]
w1_lora_b = lora_b[eid * 3]
w2_lora_b = lora_b[eid * 3 + 1]
w3_lora_b = lora_b[eid * 3 + 2]
if self.tp_size > 1:
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
w1_lora_b = w1_lora_b[start_idx:end_idx, :]
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
self.w1_lora_a_stacked[
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
].copy_(w1_lora_a, non_blocking=True)
self.w3_lora_a_stacked[
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
].copy_(w3_lora_a, non_blocking=True)
self.w2_lora_b_stacked[
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
].copy_(w2_lora_b, non_blocking=True)
self.w1_lora_b_stacked[
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
].copy_(w1_lora_b, non_blocking=True)
self.w3_lora_b_stacked[
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
].copy_(w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
].copy_(w2_lora_a, non_blocking=True)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE
return isinstance(source_layer, FusedMoE)
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
@property
def _shared_experts(self):
return self.base_layer._shared_experts
@property
def quant_method(self):
return self.base_layer.quant_method
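The class above injects LoRA by wrapping three hooks of the modular fused-MoE object: forward captures the routing state, activation adds the gate/up (w1/w3) LoRA delta to the first GEMM's output before the activation runs, and moe_sum adds the down-proj (w2) delta before the final reduction. A minimal, generic sketch of that wrap-and-augment pattern (not part of this diff; the two add_lora_* callables are hypothetical placeholders):

def wrap_with_lora(fused_experts, add_lora_w13, add_lora_w2):
    # add_lora_w13 / add_lora_w2 stand in for punica_wrapper.add_lora_fused_moe
    # calls that accumulate the LoRA deltas for gate/up and down projections.
    orig_activation = fused_experts.activation
    orig_moe_sum = fused_experts.moe_sum

    def activation(activation_name, output, gate_up_out):
        add_lora_w13(gate_up_out)  # add B(Ax) for w1/w3 before the activation
        return orig_activation(activation_name, output, gate_up_out)

    def moe_sum(expert_out, final_out):
        add_lora_w2(expert_out)    # add B(Ax) for w2 before the weighted sum
        return orig_moe_sum(expert_out, final_out)

    fused_experts.activation = activation
    fused_experts.moe_sum = moe_sum
    return fused_experts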

View File

@@ -13,7 +13,7 @@ from torch import nn
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
@@ -23,15 +23,14 @@ from vllm.lora.utils import (
get_supported_lora_modules,
is_regex_target_modules,
parse_fine_tuned_lora_name,
process_packed_modules_mapping,
replace_submodule,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.model_executor.utils import get_packed_modules_mapping
from vllm.utils import is_pin_memory_available
from vllm.utils.cache import LRUCache
@@ -60,18 +59,6 @@ def get_lora_id():
return _GLOBAL_LORA_ID
def is_moe_model(model: nn.Module) -> bool:
"""Checks if the model contains FusedMoE layers and warns the user."""
if any(isinstance(module, FusedMoE) for module in model.modules()):
logger.warning_once(
"For MoE models, vLLM currently does not support fused MoE LoRA "
"inference. Please ensure that the loaded LoRA model does not "
"contain expert weights."
)
return True
return False
class LoRAModel:
"""A LoRA fine-tuned model."""
@@ -229,9 +216,19 @@ class LoRAModel:
def check_unexpected_modules(modules: dict):
for lora_module in modules.keys(): # noqa
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
if "base_layer" in lora_module:
continue
# Case for expert lora weights
if ".experts" in module_name:
if not any(
module_name.endswith(ele) for ele in expected_lora_modules
):
unexpected_modules.append(module_name)
elif module_name.split(".")[-1] not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
@@ -371,7 +368,7 @@ class LoRAModelManager:
assert self.supported_lora_modules, "No supported LoRA modules found in"
f" {self.model.__class__.__name__}."
self.packed_modules_mapping = process_packed_modules_mapping(self.model)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
@@ -380,7 +377,6 @@ class LoRAModelManager:
and hasattr(self.model, "get_mm_mapping")
)
self.is_pooling_model = is_pooling_model(self.model)
self.is_moe_model = is_moe_model(self.model)
self.packed_modules: dict[str, list[str]] = {}
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
@@ -431,6 +427,50 @@ class LoRAModelManager:
module_lora = self._get_lora_layer_weights(lora_model, module_name)
if module_lora:
module_lora.optimize()
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
module_lora.lora_a
):
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
gate_up_proj_lora = self._get_lora_layer_weights(
lora_model, module_name + ".base_layer"
)
assert gate_up_proj_lora is not None
assert module_lora is not None
down_proj_lora = module_lora
num_experts = module_lora.lora_a.shape[0] // module_lora.rank
gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
num_experts, dim=-1
)
up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
num_experts, dim=-1
)
down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)
lora_a = []
lora_b = []
for i in range(num_experts):
lora_a.append(gate_proj_a[i])
lora_a.append(down_proj_a[i])
lora_a.append(up_proj_a[i])
lora_b.append(gate_proj_b[i])
lora_b.append(down_proj_b[i])
lora_b.append(up_proj_b[i])
module_lora.lora_a = lora_a
module_lora.lora_b = lora_b
module.set_lora(
index,
module_lora.lora_a,
@@ -486,6 +526,7 @@ class LoRAModelManager:
for module_name, module in self.model.named_modules(remove_duplicate=False):
if isinstance(module, PPMissingLayer):
continue
if not self._match_target_modules(module_name):
continue
# A temporary approach for multimodal models to support LoRA
@@ -549,7 +590,10 @@ class LoRAModelManager:
new_module.set_mapping(self.punica_wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), (
f"Module {module_name} must be a BaseLayerWithLoRA instance,"
f" got {type(module)}"
)
self.modules[module_name] = module
def create_dummy_lora(
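The interleaving above packs, for each expert, its gate_proj, down_proj and up_proj LoRA pieces in that order, which is exactly the eid * 3 + {0, 1, 2} layout that FusedMoEWithLoRA.set_lora later unpacks. A small round-trip sketch with illustrative shapes (not part of this diff; rank r, hidden h, intermediate n):

import torch

num_experts, r, h, n = 4, 8, 32, 16
# Per-expert pieces in the order produced above: gate, down, up.
lora_a = []
for _ in range(num_experts):
    lora_a += [torch.zeros(r, h),   # gate_proj A
               torch.zeros(r, n),   # down_proj A
               torch.zeros(r, h)]   # up_proj A
# The consumer side (FusedMoEWithLoRA.set_lora) reads them back per expert:
for eid in range(len(lora_a) // 3):
    w1_lora_a = lora_a[eid * 3]      # gate_proj
    w2_lora_a = lora_a[eid * 3 + 1]  # down_proj
    w3_lora_a = lora_a[eid * 3 + 2]  # up_proj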

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
@@ -9,4 +10,5 @@ __all__ = [
"lora_expand", "lora_expand",
"lora_shrink", "lora_shrink",
"LoRAKernelMeta", "LoRAKernelMeta",
"fused_moe_lora",
]

View File

@@ -0,0 +1,350 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import triton
import triton.language as tl
from vllm.utils.torch_utils import direct_register_custom_op
_LORA_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {}
def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
"""
`_LORA_PTR_DICT` collects the required information during `profile_run`.
After this, it remains constant and subsequent lookups go through the LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None:
return ptr_tensor
tensor_ptrs = []
for lora_weight in lora_weights:
tensor_ptrs.append(lora_weight.data_ptr())
ptr_tensor = torch.tensor(tensor_ptrs, device=device)
_LORA_PTR_DICT[key] = ptr_tensor
return _LORA_PTR_DICT.get(key)
@triton.jit
def _fused_moe_lora_kernel(
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
num_experts,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_bl,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_tl,
stride_el,
# Meta-parameters
num_slice_a: tl.constexpr,
num_slice_c: tl.constexpr,
slice_a_size: tl.constexpr,
slice_c_size: tl.constexpr,
top_k: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
max_loras = tl.num_programs(axis=2)
# calculate pid_m,pid_n
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
# get the expert_id to process curr shard
ind = lora_idx * stride_el + pid_m
expert_id = tl.load(expert_ids_ptr + ind)
if expert_id == -1:
return
# get a_ptr,b_ptr,c_ptr
cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_idx + offs_token_id
offs_token = tl.load(
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0.0
)
token_mask = offs_token < num_valid_tokens
# get a_ptrs,b_ptrs
a_ptrs = cur_a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
b_ptrs = (
cur_b_ptr
+ lora_idx * stride_bl
+ expert_id * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
# accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(tl.bfloat16)
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@torch.inference_mode()
def _fused_moe_lora(
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
qcurr_hidden_states: torch.Tensor, # (num_tokens, K,)
lora_a_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
lora_b_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, N, max_lora_rank,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,)
num_tokens_post_padded: torch.Tensor, # (max_loras, )
max_lora_rank: int,
top_k_num: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
mul_routed_weight: bool = False,
) -> None:
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
assert (
sorted_token_ids.dim()
== expert_ids.dim()
== topk_weights.dim()
== qcurr_hidden_states.dim()
== 2
)
assert (
sorted_token_ids.shape[0]
== expert_ids.shape[0]
== num_tokens_post_padded.shape[0]
)
assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
assert output.shape[0] == topk_weights.shape[0]
assert top_k_num == topk_weights.shape[1]
device = qcurr_hidden_states.device
num_slices = len(lora_a_stacked)
config = {
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
}
w1_lora_a_stacked = lora_a_stacked[0]
w1_lora_b_stacked = lora_b_stacked[0]
num_experts = lora_a_stacked[0].shape[1]
N = max_lora_rank
M = topk_weights.shape[0]
EM = sorted_token_ids.shape[1]
K = qcurr_hidden_states.shape[1]
num_tokens = M * top_k_num
w1_output_dim_size = w1_lora_b_stacked.shape[2]
lora_intermediate_cache1 = torch.zeros(
(num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
dtype=torch.bfloat16,
device=device,
)
# slices
a_intermediate_size = num_slices * M * top_k_num * max_lora_rank
a_intermediate_cache1 = lora_intermediate_cache1[:a_intermediate_size].view(
num_slices, M, top_k_num, max_lora_rank
)
b_intermediate_cache1 = lora_intermediate_cache1[a_intermediate_size:].view(
num_slices, M, top_k_num, w1_output_dim_size
)
b_ptr = _get_ptr(lora_a_stacked, device)
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_a_stacked),
lora_a_stacked[0].shape[0],
)
_fused_moe_lora_kernel[grid](
qcurr_hidden_states,
b_ptr,
a_intermediate_cache1,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
N,
K,
EM,
num_tokens,
num_experts,
qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0),
w1_lora_a_stacked.stride(1),
w1_lora_a_stacked.stride(3),
w1_lora_a_stacked.stride(2),
a_intermediate_cache1.stride(2),
a_intermediate_cache1.stride(3),
sorted_token_ids.stride(0),
expert_ids.stride(0),
num_slice_a=1,
num_slice_c=num_slices,
slice_a_size=qcurr_hidden_states.numel(),
slice_c_size=a_intermediate_cache1.numel() // num_slices,
top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False,
**config,
)
b_ptr = _get_ptr(lora_b_stacked, device)
K = max_lora_rank
N = w1_output_dim_size
# a_intermediate_cache1 = a_intermediate_cache1.view(
# M, -1, a_intermediate_cache1.shape[3]
# )
a_intermediate_cache1 = a_intermediate_cache1.view(
-1, a_intermediate_cache1.shape[3]
)
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
lora_b_stacked[0].shape[0],
)
_fused_moe_lora_kernel[grid](
a_intermediate_cache1,
b_ptr,
b_intermediate_cache1,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
N,
K,
EM,
num_tokens,
num_experts,
a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1),
w1_lora_b_stacked.stride(0),
w1_lora_b_stacked.stride(1),
w1_lora_b_stacked.stride(3),
w1_lora_b_stacked.stride(2),
b_intermediate_cache1.stride(2),
b_intermediate_cache1.stride(3),
sorted_token_ids.stride(0),
expert_ids.stride(0),
num_slice_a=num_slices,
num_slice_c=num_slices,
slice_a_size=a_intermediate_cache1.numel() // num_slices,
slice_c_size=b_intermediate_cache1.numel() // num_slices,
top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight,
**config,
)
for i in range(num_slices):
output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]
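# Reference sketch (not part of this commit): the two kernel launches above
# implement, per token and per routed expert, a LoRA shrink (A @ x) followed by
# an expand (B @ .), accumulated slice by slice into `output`. The unfused
# equivalent below assumes plain `topk_ids` and per-token LoRA indices instead
# of the sorted/padded layout the Triton kernel consumes; all names other than
# the stacked weights are illustrative.
def _fused_moe_lora_reference(
    output: torch.Tensor,                # (num_tokens, top_k, num_slices * N)
    x: torch.Tensor,                     # (num_tokens, K)
    lora_a_stacked: list[torch.Tensor],  # [(max_loras, num_experts, rank, K), ...]
    lora_b_stacked: list[torch.Tensor],  # [(max_loras, num_experts, N, rank), ...]
    topk_weights: torch.Tensor,          # (num_tokens, top_k)
    topk_ids: torch.Tensor,              # (num_tokens, top_k), expert id per slot
    token_lora_ids: torch.Tensor,        # (num_tokens,), LoRA id per token
    mul_routed_weight: bool = False,
) -> torch.Tensor:
    N = lora_b_stacked[0].shape[2]
    for s, (A, B) in enumerate(zip(lora_a_stacked, lora_b_stacked)):
        for t in range(x.shape[0]):
            lora = int(token_lora_ids[t])
            for k in range(topk_ids.shape[1]):
                e = int(topk_ids[t, k])
                shrunk = A[lora, e] @ x[t]   # shrink: (rank,)
                delta = B[lora, e] @ shrunk  # expand: (N,)
                if mul_routed_weight:
                    delta = delta * topk_weights[t, k]
                output[t, k, s * N : (s + 1) * N] += delta
    return output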

def _fused_moe_lora_fake(
output: torch.Tensor,
qcurr_hidden_states: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
mul_routed_weight: bool = False,
) -> None:
return

try:
direct_register_custom_op(
op_name="fused_moe_lora",
op_func=_fused_moe_lora,
mutates_args=["output"],
fake_impl=_fused_moe_lora_fake,
)
fused_moe_lora = torch.ops.vllm.fused_moe_lora
except AttributeError:
fused_moe_lora = _fused_moe_lora

View File

@@ -448,3 +448,42 @@ class PunicaWrapperBase(PunicaWrapperABC):
""" """
# TODO: implement it based on torch ops # TODO: implement it based on torch ops
raise NotImplementedError raise NotImplementedError
def moe_lora_align_block_size(
self,
topk_ids: torch.Tensor,
num_tokens: int,
block_size: int,
num_experts: int,
max_loras: int,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
def add_lora_fused_moe(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
config,
mul_routed_weight=False,
):
"""
Performs the fused forward computation for the LoRA adapters
of a Mixture-of-Experts (MoE) layer.
"""
# TODO: implement it based on torch ops
raise NotImplementedError

View File

@@ -12,10 +12,18 @@ from typing import final
import torch
from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils import round_up
if HAS_TRITON:
from vllm.lora.ops.triton_ops import (
LoRAKernelMeta,
fused_moe_lora,
lora_expand,
lora_shrink,
)
from vllm import _custom_ops as ops
from .punica_base import PunicaWrapperBase
@@ -289,3 +297,91 @@ class PunicaWrapperGPU(PunicaWrapperBase):
add_inputs=True,
)
y = y.view_as(y_org)
def moe_lora_align_block_size(
self,
topk_ids: torch.Tensor,
num_tokens: int,
block_size: int,
num_experts: int,
max_loras: int,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
device=topk_ids.device,
)
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must default to -1 to prevent processing a blank block
expert_ids = torch.empty(
(max_loras * max_num_m_blocks,),
dtype=torch.int32,
device=topk_ids.device,
)
num_tokens_post_pad = torch.empty(
(max_loras,), dtype=torch.int32, device=topk_ids.device
)
(token_lora_mapping, _, _, _, _, _) = self.token_mapping_meta.meta_args(
num_tokens
)
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
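# Worked example of the buffer sizes above (illustrative values, not part of
# this commit): every LoRA gets its own padded token range and block table.
#   num_tokens, top_k, block_size, num_experts, max_loras = 4, 2, 16, 8, 2
#   topk_ids.numel()      = 4 * 2            = 8
#   max_num_tokens_padded = 8 + 8 * (16 - 1) = 128
#   max_num_m_blocks      = ceil(128 / 16)   = 8
#   sorted_ids.numel()    = 2 * 128          = 256
#   expert_ids.numel()    = 2 * 8            = 16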
def add_lora_fused_moe(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
config,
mul_routed_weight=False,
):
"""
Performs the fused forward computation for the LoRA adapters of a Mixture-of-Experts (MoE) layer.
"""
fused_moe_lora(
y,
x,
lora_a_stacked,
lora_b_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_lora_rank,
top_k_num,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
config["GROUP_SIZE_M"],
mul_routed_weight,
)
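# Usage sketch (illustrative, not part of this commit): a MoE LoRA layer first
# aligns the routing metadata and then launches the fused kernels. Names such
# as `punica_wrapper`, `hidden_states`, `topk_ids` and `config` are
# placeholders for whatever the calling FusedMoE LoRA layer holds.
#   sorted_ids, expert_ids, num_pad = punica_wrapper.moe_lora_align_block_size(
#       topk_ids,
#       num_tokens=hidden_states.shape[0],
#       block_size=config["BLOCK_SIZE_M"],
#       num_experts=num_experts,
#       max_loras=max_loras,
#   )
#   punica_wrapper.add_lora_fused_moe(
#       output,           # (num_tokens, top_k, out_dim), updated in place
#       hidden_states,    # (num_tokens, hidden_dim)
#       lora_a_stacked,
#       lora_b_stacked,
#       topk_weights,
#       sorted_ids,
#       expert_ids,
#       num_pad,
#       max_lora_rank,
#       top_k,
#       config,
#       mul_routed_weight=False,
#   )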

View File

@@ -23,6 +23,7 @@ from vllm.lora.layers import (
BaseLayerWithLoRA,
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
@@ -35,7 +36,9 @@ from vllm.lora.layers import (
RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
if TYPE_CHECKING:
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -58,9 +61,18 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
}
def is_moe_model(model: nn.Module) -> bool:
"""Checks if the model contains FusedMoE layers and warns the user."""
if any(isinstance(module, FusedMoE) for module in model.modules()):
logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
return True
return False
def from_layer(
layer: nn.Module,
max_loras: int,
@@ -205,6 +217,9 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
if isinstance(module, (LinearBase,)):
supported_lora_modules.add(name.split(".")[-1])
if isinstance(module, (FusedMoE,)):
supported_lora_modules.add(name.split(".")[-1])
return list(supported_lora_modules)
@@ -252,3 +267,27 @@ def get_adapter_absolute_path(lora_path: str) -> str:
return lora_path
return local_snapshot_path
def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
if is_moe_model(model):
if moe_packed_mapping := get_moe_expert_mapping(model):
# This method generates and returns a dictionary mapping packed module
# names to lists of their corresponding submodule names. It includes
# both static mappings and dynamic mappings for expert layers, where
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model)
packed_modules_mapping["experts"] = [
weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping
]
return packed_modules_mapping
else:
raise AttributeError(
"To support LoRA for MoE model, "
"'get_expert_mapping' must be implemented"
)
else:
return get_packed_modules_mapping(model)
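# Illustrative result (not part of this commit): for a typical MoE checkpoint,
# process_packed_modules_mapping() returns the model's usual packed mapping
# plus an "experts" entry expanded per routed expert, roughly:
#   {
#       "qkv_proj": ["q_proj", "k_proj", "v_proj"],
#       "gate_up_proj": ["gate_proj", "up_proj"],
#       "experts": [
#           "experts.0.gate_proj", "experts.0.down_proj", "experts.0.up_proj",
#           "experts.1.gate_proj", "experts.1.down_proj", "experts.1.up_proj",
#           ...
#       ],
#   }
# The exact keys depend on the model's packed_modules_mapping and on
# get_expert_mapping(); the names above are placeholders.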

View File

@@ -94,7 +94,8 @@ class WorkerLoRAManager:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
if module == "experts":
expected_lora_modules.append(module)
expected_lora_modules = list(set(expected_lora_modules))
lora_path = get_adapter_absolute_path(lora_request.lora_path)

View File

@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE utilities for GPTQ."""
from collections.abc import Callable
import torch
import vllm._custom_ops as ops
@@ -11,6 +13,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP,
@@ -24,6 +29,21 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.scalar_type import ScalarType, scalar_types
def default_activation_func(
activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(output, input)
else:
raise ValueError(
f"Unsupported activation: {activation}. "
"Only silu and swigluoai activations are supported."
)
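# Reference semantics (not part of this commit): both fused ops consume the
# concatenated gate/up projection. Given input of shape (num_tokens, 2 * N),
# they write a (num_tokens, N) output equal to act(input[:, :N]) * input[:, N:].
# A pure-PyTorch sketch of the silu case, for illustration only:
def _silu_and_mul_reference(output: torch.Tensor, input: torch.Tensor) -> None:
    import torch.nn.functional as F
    n = output.shape[-1]
    output.copy_(F.silu(input[..., :n]) * input[..., n:])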
def _fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -36,12 +56,15 @@ def _fused_marlin_moe(
num_topk: int,
quant_type: ScalarType,
apply_router_weight_on_input: bool,
expert_map: torch.Tensor | None,
block_size_m: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
activation: str = "silu",
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
@@ -118,20 +141,9 @@ def _fused_marlin_moe(
is_zp_float=False,
)
activation_func(
activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
if output is None:
output = intermediate_cache3
@@ -185,7 +197,11 @@ def fused_marlin_moe(
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
activation: str = "silu",
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
@@ -290,12 +306,13 @@ def fused_marlin_moe(
num_topk=topk,
quant_type=quant_type,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
block_size_m=block_size_m,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
activation=activation,
activation_func=activation_func,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
@@ -317,7 +334,10 @@ def fused_marlin_moe(
else:
output = torch.empty_like(hidden_states)
if moe_sum is None:
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
else:
return moe_sum(moe_output, output)
def batched_fused_marlin_moe(
@@ -600,6 +620,8 @@ class MarlinExperts(MarlinExpertsBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
activation_func=self.activation,
moe_sum=self.moe_sum,
expert_map=expert_map,
output=output,
# Workspaces are swapped in workspace_shapes() to account for proper
@@ -608,6 +630,19 @@ class MarlinExperts(MarlinExpertsBase):
intermediate_cache2=workspace13,
)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
ops.moe_sum(input, output)
def modular_marlin_fused_moe(
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config),
shared_experts,
)
class BatchedMarlinExperts(MarlinExpertsBase):
def __init__(

View File

@@ -2135,13 +2135,18 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
B_bias=self.w2_bias,
)
# separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
ops.moe_sum(input, output)
def modular_triton_fused_moe(
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config),
shared_experts,
)
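# Why moe_sum became a method (illustrative note, not part of this commit): a
# LoRA-aware wrapper can override it to fold the fused LoRA delta into the
# per-expert outputs before the top-k reduction, roughly:
#   class TritonExpertsWithLoRA(TritonExperts):        # hypothetical subclass
#       def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
#           self._apply_fused_moe_lora(input)          # hypothetical LoRA hook
#           ops.moe_sum(input, output)
# The actual wiring is done elsewhere in this change.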

View File

@@ -557,6 +557,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(output, input)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")

View File

@@ -1313,6 +1313,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_redundant_experts=0,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)

View File

@@ -32,7 +32,7 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
WeightsMapper,
@@ -627,7 +627,7 @@ class GptOssModel(nn.Module):
)
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper(
@@ -696,6 +696,17 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, weight scales, activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_local_experts,
num_redundant_experts=0,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,

View File

@@ -49,7 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
@@ -349,8 +349,6 @@ class OlmoeModel(nn.Module):
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
@@ -433,17 +431,13 @@ class OlmoeModel(nn.Module):
return loaded_params
class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
]
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(