add support for --fully-sharded-loras in fused_moe (#28761)
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
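For context, the existing --fully-sharded-loras engine option shards the LoRA computation across tensor-parallel ranks; this change extends that path to the fused MoE LoRA kernel. Below is a minimal usage sketch, assuming the corresponding fully_sharded_loras engine argument in the offline API; the model name and adapter path are placeholders and are not part of this commit.

# Hypothetical usage sketch: serve a MoE model with LoRA adapters and shard the
# LoRA computation across tensor-parallel ranks (equivalent to --fully-sharded-loras).
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder MoE model
    enable_lora=True,
    max_loras=4,
    max_lora_rank=64,
    tensor_parallel_size=2,
    fully_sharded_loras=True,  # CLI equivalent: --fully-sharded-loras
)

outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    SamplingParams(max_tokens=64),
    lora_request=LoRARequest("my_moe_adapter", 1, "/path/to/lora"),  # placeholder adapter
)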
@@ -1,13 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random

import pytest
import torch

from tests.utils import multi_gpu_test
from vllm import _custom_ops as ops
from vllm.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_world_size,
)
from vllm.lora.ops.triton_ops import fused_moe_lora
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port


@pytest.fixture(autouse=True)
@@ -122,6 +134,8 @@ def use_fused_moe_lora_kernel(
    max_loras,
    num_experts,
    block_size,
    fully_sharded=False,
    offset=0,
):
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
@@ -195,10 +209,10 @@ def use_fused_moe_lora_kernel(
        config["NUM_STAGES"],
        config["SPLIT_K"],
        mul_routed_weight,
        fully_sharded=fully_sharded,
        offset=offset,
    )

    return output


def use_torch(
    hidden_states,
@@ -317,3 +331,193 @@ def test_fused_moe_lora_kernel(
    )

    torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("num_tokens", [100])
@pytest.mark.parametrize("top_k_num", [6])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("max_loras", [4])
@pytest.mark.parametrize("N", [1408])
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("column_parallel", [True, False])
def test_fused_moe_lora_kernel_fully_sharded(
    num_tokens,
    top_k_num,
    num_experts,
    max_loras,
    N,
    K,
    max_lora_rank,
    block_size,
    dtype,
    seed,
    column_parallel,
):
    current_platform.seed_everything(seed)
    # the number of randomly generated sentences.
    num_sequences = 10
    # generate data
    topk_ids, topk_weights, token_lora_mapping = sample_data(
        num_tokens, num_sequences, max_loras, num_experts, top_k_num
    )

    def run_torch_spawn(fn, nprocs):
        torch.multiprocessing.spawn(
            fn,
            args=(
                nprocs,
                f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
                dtype,
                seed,
                N,
                K,
                num_tokens,
                topk_ids,
                topk_weights,
                token_lora_mapping,
                max_lora_rank,
                top_k_num,
                max_loras,
                num_experts,
                block_size,
                column_parallel,
            ),
            nprocs=nprocs,
        )

    run_torch_spawn(use_fused_moe_lora_kernel_tensor_parallel, nprocs=2)


def use_fused_moe_lora_kernel_tensor_parallel(
    local_rank,
    world_size,
    init_method,
    dtype,
    seed,
    N,
    K,
    num_tokens,
    topk_ids,
    topk_weights,
    token_lora_mapping,
    max_lora_rank,
    top_k_num,
    max_loras,
    num_experts,
    block_size,
    column_parallel,
):
    def _get_shard_slice(shard_size):
        return slice(local_rank * shard_size, (local_rank + 1) * shard_size)

    current_platform.seed_everything(seed)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    init_distributed_environment(
        world_size=world_size,
        rank=local_rank,
        local_rank=local_rank,
        distributed_init_method=init_method,
    )
    initialize_model_parallel(world_size, 1)
    tp_size = get_tensor_model_parallel_world_size()

    input_dim = K if column_parallel else N
    output_dim = N if column_parallel else K

    # init lora weights
    lora_a = torch.rand(
        (
            max_loras,
            num_experts,
            max_lora_rank,
            input_dim,
        ),
        dtype=dtype,
    )
    lora_b = torch.rand(
        (
            max_loras,
            num_experts,
            output_dim,
            max_lora_rank,
        ),
        dtype=dtype,
    )

    hidden_states = torch.rand(
        (
            num_tokens,
            input_dim,
        ),
        dtype=dtype,
    )

    output = torch.zeros((num_tokens, top_k_num, output_dim), dtype=dtype)
    topk_ids = topk_ids.to(device)
    topk_weights = topk_weights.to(device)
    token_lora_mapping = token_lora_mapping.to(device)

    ref_output = use_torch(
        hidden_states,
        token_lora_mapping,
        topk_ids,
        [lora_a],
        [lora_b],
        top_k_num,
    )

    if column_parallel:
        # Column parallel (e.g. gate_up_proj): LoRA A is sliced along the rank dim,
        # and LoRA B is sliced along the output dim
        lora_a_shard_size = max_lora_rank // tp_size
        lora_a = lora_a[:, :, _get_shard_slice(lora_a_shard_size), :]
        max_lora_rank = lora_a_shard_size
        offset = 0

        lora_b_shard_size = output_dim // tp_size
        lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
        output = output[:, :, _get_shard_slice(lora_b_shard_size)].contiguous()
    else:
        # Row parallel (e.g. down proj): LoRA A is sliced along the input dim,
        # and LoRA B is sliced along the output dim
        lora_a_shard_size = input_dim // tp_size
        lora_a = lora_a[:, :, :, _get_shard_slice(lora_a_shard_size)]
        hidden_states = hidden_states[:, _get_shard_slice(lora_a_shard_size)]

        lora_b_shard_size = output_dim // tp_size
        lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
        offset = lora_b_shard_size * local_rank

    use_fused_moe_lora_kernel(
        topk_ids,
        topk_weights,
        token_lora_mapping,
        max_lora_rank,
        top_k_num,
        [lora_a],
        [lora_b],
        hidden_states,
        output,
        max_loras,
        num_experts,
        block_size,
        fully_sharded=True,
        offset=offset,
    )

    if column_parallel:
        output = tensor_model_parallel_all_gather(output)
    else:
        output = tensor_model_parallel_all_reduce(output)

    torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1)
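The test above exercises both fully sharded layouts. As an illustrative summary (not part of the commit; the helper name and the hard-coded sizes are made up, using the test's N=1408, K=2048 and one of its parametrized ranks), the per-rank shard shapes and output offset it computes for tp_size = 2, with the leading (max_loras, num_experts) dims omitted:

# Illustrative sketch only: mirrors the slicing done in the test for tp_size = 2.
def lora_shard_layout(column_parallel, local_rank, tp_size=2, N=1408, K=2048, max_lora_rank=32):
    input_dim = K if column_parallel else N
    output_dim = N if column_parallel else K
    if column_parallel:
        # gate_up_proj-style: LoRA A sharded along the LoRA rank dim,
        # LoRA B sharded along the output dim; partial outputs are all-gathered.
        lora_a_shape = (max_lora_rank // tp_size, input_dim)
        lora_b_shape = (output_dim // tp_size, max_lora_rank)
        offset = 0
    else:
        # down_proj-style: LoRA A sharded along the input dim,
        # LoRA B sharded along the output dim; each rank writes its contribution
        # into the full-size output at `offset`, then results are all-reduced.
        lora_a_shape = (max_lora_rank, input_dim // tp_size)
        lora_b_shape = (output_dim // tp_size, max_lora_rank)
        offset = (output_dim // tp_size) * local_rank
    return lora_a_shape, lora_b_shape, offset


print(lora_shard_layout(column_parallel=True, local_rank=1))   # ((16, 2048), (704, 32), 0)
print(lora_shard_layout(column_parallel=False, local_rank=1))  # ((32, 704), (1024, 32), 1024)

In the column-parallel case each rank produces only its slice of the output, so offset stays 0 and the slices are all-gathered afterwards; in the row-parallel case each rank writes its LoRA B shard's contribution at offset into a full-size output, which is then all-reduced.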