[Bugfix] Fix lora_ids in FusedMoE LoRA test (#35135)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -118,7 +118,10 @@ def sample_data(
|
|||||||
num_tokens, num_experts, top_k_num
|
num_tokens, num_experts, top_k_num
|
||||||
)
|
)
|
||||||
token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras)
|
token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras)
|
||||||
return topk_ids, topk_weights, token_lora_mapping
|
active_lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32)
|
||||||
|
lora_ids = torch.unique(token_lora_mapping, sorted=True)
|
||||||
|
active_lora_ids[: lora_ids.size(0)].copy_(lora_ids, non_blocking=True)
|
||||||
|
return topk_ids, topk_weights, token_lora_mapping, active_lora_ids
|
||||||
|
|
||||||
|
|
||||||
def use_fused_moe_lora_kernel(
|
def use_fused_moe_lora_kernel(
|
||||||
@@ -127,6 +130,7 @@ def use_fused_moe_lora_kernel(
|
|||||||
token_lora_mapping,
|
token_lora_mapping,
|
||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
top_k_num,
|
top_k_num,
|
||||||
|
lora_ids,
|
||||||
lora_a_stacked,
|
lora_a_stacked,
|
||||||
lora_b_stacked,
|
lora_b_stacked,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -149,7 +153,6 @@ def use_fused_moe_lora_kernel(
|
|||||||
expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
|
expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
|
||||||
num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
|
num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
|
||||||
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
|
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
|
||||||
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
|
|
||||||
|
|
||||||
# call kernel
|
# call kernel
|
||||||
ops.moe_lora_align_block_size(
|
ops.moe_lora_align_block_size(
|
||||||
@@ -168,7 +171,7 @@ def use_fused_moe_lora_kernel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": block_size,
|
||||||
"BLOCK_SIZE_N": 32,
|
"BLOCK_SIZE_N": 32,
|
||||||
"BLOCK_SIZE_K": 64,
|
"BLOCK_SIZE_K": 64,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
@@ -275,7 +278,7 @@ def test_fused_moe_lora_kernel(
|
|||||||
# the number of randomly generated sentences.
|
# the number of randomly generated sentences.
|
||||||
num_sequences = 10
|
num_sequences = 10
|
||||||
# generate data
|
# generate data
|
||||||
topk_ids, topk_weights, token_lora_mapping = sample_data(
|
topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data(
|
||||||
num_tokens, num_sequences, max_loras, num_experts, top_k_num
|
num_tokens, num_sequences, max_loras, num_experts, top_k_num
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -318,6 +321,7 @@ def test_fused_moe_lora_kernel(
|
|||||||
token_lora_mapping,
|
token_lora_mapping,
|
||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
top_k_num,
|
top_k_num,
|
||||||
|
lora_ids,
|
||||||
lora_a_stacked,
|
lora_a_stacked,
|
||||||
lora_b_stacked,
|
lora_b_stacked,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -336,7 +340,7 @@ def test_fused_moe_lora_kernel(
|
|||||||
top_k_num,
|
top_k_num,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
|
torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
def use_fused_moe_lora_kernel_naive(
|
def use_fused_moe_lora_kernel_naive(
|
||||||
@@ -345,6 +349,7 @@ def use_fused_moe_lora_kernel_naive(
|
|||||||
token_lora_mapping,
|
token_lora_mapping,
|
||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
top_k_num,
|
top_k_num,
|
||||||
|
lora_ids,
|
||||||
lora_a_stacked,
|
lora_a_stacked,
|
||||||
lora_b_stacked,
|
lora_b_stacked,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -379,7 +384,6 @@ def use_fused_moe_lora_kernel_naive(
|
|||||||
num_tokens_post_padded = None
|
num_tokens_post_padded = None
|
||||||
|
|
||||||
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
|
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
|
||||||
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
|
|
||||||
|
|
||||||
# num_active_loras is the number of active LoRAs
|
# num_active_loras is the number of active LoRAs
|
||||||
# (max_loras + 1 to include no-lora case)
|
# (max_loras + 1 to include no-lora case)
|
||||||
@@ -463,7 +467,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
|
|||||||
# the number of randomly generated sentences.
|
# the number of randomly generated sentences.
|
||||||
num_sequences = min(num_tokens, 4)
|
num_sequences = min(num_tokens, 4)
|
||||||
# generate data
|
# generate data
|
||||||
topk_ids, topk_weights, token_lora_mapping = sample_data(
|
topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data(
|
||||||
num_tokens, num_sequences, max_loras, num_experts, top_k_num
|
num_tokens, num_sequences, max_loras, num_experts, top_k_num
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -506,6 +510,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
|
|||||||
token_lora_mapping,
|
token_lora_mapping,
|
||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
top_k_num,
|
top_k_num,
|
||||||
|
lora_ids,
|
||||||
lora_a_stacked,
|
lora_a_stacked,
|
||||||
lora_b_stacked,
|
lora_b_stacked,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -524,7 +529,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
|
|||||||
top_k_num,
|
top_k_num,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-1)
|
torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@@ -556,7 +561,7 @@ def test_fused_moe_lora_kernel_fully_sharded(
|
|||||||
# the number of randomly generated sentences.
|
# the number of randomly generated sentences.
|
||||||
num_sequences = 10
|
num_sequences = 10
|
||||||
# generate data
|
# generate data
|
||||||
topk_ids, topk_weights, token_lora_mapping = sample_data(
|
topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data(
|
||||||
num_tokens, num_sequences, max_loras, num_experts, top_k_num
|
num_tokens, num_sequences, max_loras, num_experts, top_k_num
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -576,6 +581,7 @@ def test_fused_moe_lora_kernel_fully_sharded(
|
|||||||
token_lora_mapping,
|
token_lora_mapping,
|
||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
top_k_num,
|
top_k_num,
|
||||||
|
lora_ids,
|
||||||
max_loras,
|
max_loras,
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
@@ -601,6 +607,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
|
|||||||
token_lora_mapping,
|
token_lora_mapping,
|
||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
top_k_num,
|
top_k_num,
|
||||||
|
lora_ids,
|
||||||
max_loras,
|
max_loras,
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
@@ -660,6 +667,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
|
|||||||
topk_ids = topk_ids.to(device)
|
topk_ids = topk_ids.to(device)
|
||||||
topk_weights = topk_weights.to(device)
|
topk_weights = topk_weights.to(device)
|
||||||
token_lora_mapping = token_lora_mapping.to(device)
|
token_lora_mapping = token_lora_mapping.to(device)
|
||||||
|
lora_ids = lora_ids.to(device)
|
||||||
|
|
||||||
ref_output = use_torch(
|
ref_output = use_torch(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -698,6 +706,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
|
|||||||
token_lora_mapping,
|
token_lora_mapping,
|
||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
top_k_num,
|
top_k_num,
|
||||||
|
lora_ids,
|
||||||
[lora_a],
|
[lora_a],
|
||||||
[lora_b],
|
[lora_b],
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -714,4 +723,4 @@ def use_fused_moe_lora_kernel_tensor_parallel(
|
|||||||
else:
|
else:
|
||||||
output = tensor_model_parallel_all_reduce(output)
|
output = tensor_model_parallel_all_reduce(output)
|
||||||
|
|
||||||
torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1)
|
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
|
||||||
|
|||||||
Reference in New Issue
Block a user