add support for --fully-sharded-loras in fused_moe (#28761)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
gnovack
2025-11-19 00:32:00 -08:00
committed by GitHub
parent ae4821a108
commit d69062c67a
6 changed files with 274 additions and 10 deletions

View File

@@ -1,13 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random
import pytest
import torch
from tests.utils import multi_gpu_test
from vllm import _custom_ops as ops
from vllm.distributed import (
init_distributed_environment,
initialize_model_parallel,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.lora.ops.triton_ops import fused_moe_lora
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
@pytest.fixture(autouse=True)
@@ -122,6 +134,8 @@ def use_fused_moe_lora_kernel(
max_loras,
num_experts,
block_size,
fully_sharded=False,
offset=0,
):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
@@ -195,10 +209,10 @@ def use_fused_moe_lora_kernel(
config["NUM_STAGES"],
config["SPLIT_K"],
mul_routed_weight,
fully_sharded=fully_sharded,
offset=offset,
)
return output
def use_torch(
hidden_states,
@@ -317,3 +331,193 @@ def test_fused_moe_lora_kernel(
)
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("num_tokens", [100])
@pytest.mark.parametrize("top_k_num", [6])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("max_loras", [4])
@pytest.mark.parametrize("N", [1408])
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("column_parallel", [True, False])
def test_fused_moe_lora_kernel_fully_sharded(
num_tokens,
top_k_num,
num_experts,
max_loras,
N,
K,
max_lora_rank,
block_size,
dtype,
seed,
column_parallel,
):
current_platform.seed_everything(seed)
# the number of randomly generated sentences.
num_sequences = 10
# generate data
topk_ids, topk_weights, token_lora_mapping = sample_data(
num_tokens, num_sequences, max_loras, num_experts, top_k_num
)
def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(
fn,
args=(
nprocs,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
dtype,
seed,
N,
K,
num_tokens,
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
max_loras,
num_experts,
block_size,
column_parallel,
),
nprocs=nprocs,
)
run_torch_spawn(use_fused_moe_lora_kernel_tensor_parallel, nprocs=2)
def use_fused_moe_lora_kernel_tensor_parallel(
local_rank,
world_size,
init_method,
dtype,
seed,
N,
K,
num_tokens,
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
max_loras,
num_experts,
block_size,
column_parallel,
):
def _get_shard_slice(shard_size):
return slice(local_rank * shard_size, (local_rank + 1) * shard_size)
current_platform.seed_everything(seed)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
init_distributed_environment(
world_size=world_size,
rank=local_rank,
local_rank=local_rank,
distributed_init_method=init_method,
)
initialize_model_parallel(world_size, 1)
tp_size = get_tensor_model_parallel_world_size()
input_dim = K if column_parallel else N
output_dim = N if column_parallel else K
# init lora weights
lora_a = torch.rand(
(
max_loras,
num_experts,
max_lora_rank,
input_dim,
),
dtype=dtype,
)
lora_b = torch.rand(
(
max_loras,
num_experts,
output_dim,
max_lora_rank,
),
dtype=dtype,
)
hidden_states = torch.rand(
(
num_tokens,
input_dim,
),
dtype=dtype,
)
output = torch.zeros((num_tokens, top_k_num, output_dim), dtype=dtype)
topk_ids = topk_ids.to(device)
topk_weights = topk_weights.to(device)
token_lora_mapping = token_lora_mapping.to(device)
ref_output = use_torch(
hidden_states,
token_lora_mapping,
topk_ids,
[lora_a],
[lora_b],
top_k_num,
)
if column_parallel:
# Column parallel (e.g. gate_up_proj): LoRA A is sliced along the rank dim,
# and Lora B is sliced along the output dim
lora_a_shard_size = max_lora_rank // tp_size
lora_a = lora_a[:, :, _get_shard_slice(lora_a_shard_size), :]
max_lora_rank = lora_a_shard_size
offset = 0
lora_b_shard_size = output_dim // tp_size
lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
output = output[:, :, _get_shard_slice(lora_b_shard_size)].contiguous()
else:
# Row parallel (e.g. down proj): LoRA A is sliced along the input dim,
# and LoRA B is sliced along the output dim
lora_a_shard_size = input_dim // tp_size
lora_a = lora_a[:, :, :, _get_shard_slice(lora_a_shard_size)]
hidden_states = hidden_states[:, _get_shard_slice(lora_a_shard_size)]
lora_b_shard_size = output_dim // tp_size
lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
offset = lora_b_shard_size * local_rank
use_fused_moe_lora_kernel(
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
[lora_a],
[lora_b],
hidden_states,
output,
max_loras,
num_experts,
block_size,
fully_sharded=True,
offset=offset,
)
if column_parallel:
output = tensor_model_parallel_all_gather(output)
else:
output = tensor_model_parallel_all_reduce(output)
torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1)