Add unpermute-aware fused MoE LoRA path (#32655)
Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
This commit is contained in:
@@ -190,6 +190,7 @@ def use_fused_moe_lora_kernel(
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
token_lora_mapping,
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
lora_ids,
|
||||
@@ -333,6 +334,189 @@ def test_fused_moe_lora_kernel(
|
||||
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
|
||||
|
||||
|
||||
def use_fused_moe_lora_kernel_naive(
    topk_ids,
    topk_weights,
    token_lora_mapping,
    max_lora_rank,
    top_k_num,
    lora_a_stacked,
    lora_b_stacked,
    hidden_states,
    output,
    max_loras,
    block_size,
    fully_sharded=False,
    offset=0,
):
    """
    Test helper for naive_block_assignment path.
    Skips moe_lora_align_block_size and uses flattened topk_ids as expert_ids.
    """
    # One tiling configuration, passed twice: once for the shrink kernel and
    # once for the expand kernel.
    tile_args = (
        block_size,  # BLOCK_SIZE_M
        32,  # BLOCK_SIZE_N
        64,  # BLOCK_SIZE_K
        1,  # GROUP_SIZE_M
        4,  # NUM_WARPS
        3,  # NUM_STAGES
        1,  # SPLIT_K
    )

    # Naive mode: the flattened topk_ids (shape (num_tokens * top_k,)) serve
    # directly as expert_ids, and the block-alignment outputs
    # (sorted_token_ids / num_tokens_post_padded) are simply absent.
    flat_expert_ids = topk_ids.reshape(-1)

    # NOTE(review): lora_ids has max_loras + 2 entries while adapter_enabled
    # has max_loras + 1 — presumably intentional padding expected by the
    # kernel launcher; confirm against fused_moe_lora's contract.
    adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)

    fused_moe_lora(
        output,
        hidden_states,
        lora_a_stacked,
        lora_b_stacked,
        topk_weights,
        None,  # sorted_token_ids (unused on the naive path)
        flat_expert_ids,
        None,  # num_tokens_post_padded (unused on the naive path)
        token_lora_mapping,
        max_lora_rank,
        top_k_num,
        lora_ids,
        adapter_enabled,
        *tile_args,  # shrink kernel tiling
        *tile_args,  # expand kernel tiling
        mul_routed_weight=False,
        fully_sharded=fully_sharded,
        offset=offset,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 2, 4, 8])
@pytest.mark.parametrize("top_k_num", [1, 2])
@pytest.mark.parametrize("num_experts", [64, 128])
@pytest.mark.parametrize("max_loras", [4, 8])
@pytest.mark.parametrize("N", [1408])
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_fused_moe_lora_kernel_naive_block_assignment(
    num_tokens,
    top_k_num,
    num_experts,
    max_loras,
    N,
    K,
    max_lora_rank,
    block_size,
    dtype,
    device,
    seed,
):
    """
    Test the naive_block_assignment path of the fused_moe_lora kernel.

    This path is triggered when batch_size * top_k is much smaller than
    num_experts * max_loras, and skips the moe_lora_align_block_size kernel.
    """
    torch.set_default_device(device)
    set_random_seed(seed)

    # Sanity-check that this parametrization actually lands on the naive
    # path: num_tokens * top_k * SPARSITY_FACTOR <= num_experts * max_loras.
    SPARSITY_FACTOR = 8
    assert num_tokens * top_k_num * SPARSITY_FACTOR <= num_experts * max_loras, (
        f"Test configuration doesn't meet naive_block_assignment condition: "
        f"{num_tokens} * {top_k_num} * {SPARSITY_FACTOR} > {num_experts} * {max_loras}"
    )

    # Number of randomly generated sentences.
    num_sequences = min(num_tokens, 4)
    topk_ids, topk_weights, token_lora_mapping = sample_data(
        num_tokens, num_sequences, max_loras, num_experts, top_k_num
    )

    # Random LoRA weights and activations. Creation order is kept stable so
    # the fixed seed yields reproducible tensors.
    lora_a_stacked = [
        torch.rand((max_loras, num_experts, max_lora_rank, K), dtype=dtype)
    ]
    lora_b_stacked = [
        torch.rand((max_loras, num_experts, N, max_lora_rank), dtype=dtype)
    ]
    hidden_states = torch.rand((num_tokens, K), dtype=dtype)

    # Kernel output via the naive (alignment-free) path.
    output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
    use_fused_moe_lora_kernel_naive(
        topk_ids,
        topk_weights,
        token_lora_mapping,
        max_lora_rank,
        top_k_num,
        lora_a_stacked,
        lora_b_stacked,
        hidden_states,
        output,
        max_loras,
        block_size,
    )

    # Reference result computed in plain PyTorch.
    expected = use_torch(
        hidden_states,
        token_lora_mapping,
        topk_ids,
        lora_a_stacked,
        lora_b_stacked,
        top_k_num,
    )

    torch.testing.assert_close(output, expected, atol=1e-1, rtol=1e-1)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("num_tokens", [100])
|
||||
@pytest.mark.parametrize("top_k_num", [6])
|
||||
|
||||
Reference in New Issue
Block a user