diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index dc3602007..b79b668f3 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -118,7 +118,10 @@ def sample_data( num_tokens, num_experts, top_k_num ) token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) - return topk_ids, topk_weights, token_lora_mapping + active_lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) + lora_ids = torch.unique(token_lora_mapping, sorted=True) + active_lora_ids[: lora_ids.size(0)].copy_(lora_ids, non_blocking=True) + return topk_ids, topk_weights, token_lora_mapping, active_lora_ids def use_fused_moe_lora_kernel( @@ -127,6 +130,7 @@ def use_fused_moe_lora_kernel( token_lora_mapping, max_lora_rank, top_k_num, + lora_ids, lora_a_stacked, lora_b_stacked, hidden_states, @@ -149,7 +153,6 @@ def use_fused_moe_lora_kernel( expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32) num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) - lora_ids = torch.arange(max_loras + 2, dtype=torch.int32) # call kernel ops.moe_lora_align_block_size( @@ -168,7 +171,7 @@ def use_fused_moe_lora_kernel( ) config = { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": block_size, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, @@ -275,7 +278,7 @@ def test_fused_moe_lora_kernel( # the number of randomly generated sentences. num_sequences = 10 # generate data - topk_ids, topk_weights, token_lora_mapping = sample_data( + topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data( num_tokens, num_sequences, max_loras, num_experts, top_k_num ) @@ -318,6 +321,7 @@ def test_fused_moe_lora_kernel( token_lora_mapping, max_lora_rank, top_k_num, + lora_ids, lora_a_stacked, lora_b_stacked, hidden_states, @@ -336,7 +340,7 @@ def test_fused_moe_lora_kernel( top_k_num, ) - torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2) def use_fused_moe_lora_kernel_naive( @@ -345,6 +349,7 @@ def use_fused_moe_lora_kernel_naive( token_lora_mapping, max_lora_rank, top_k_num, + lora_ids, lora_a_stacked, lora_b_stacked, hidden_states, @@ -379,7 +384,6 @@ def use_fused_moe_lora_kernel_naive( num_tokens_post_padded = None adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) - lora_ids = torch.arange(max_loras + 2, dtype=torch.int32) # num_active_loras is the number of active LoRAs # (max_loras + 1 to include no-lora case) @@ -463,7 +467,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment( # the number of randomly generated sentences. num_sequences = min(num_tokens, 4) # generate data - topk_ids, topk_weights, token_lora_mapping = sample_data( + topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data( num_tokens, num_sequences, max_loras, num_experts, top_k_num ) @@ -506,6 +510,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment( token_lora_mapping, max_lora_rank, top_k_num, + lora_ids, lora_a_stacked, lora_b_stacked, hidden_states, @@ -524,7 +529,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment( top_k_num, ) - torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) @multi_gpu_test(num_gpus=2) @@ -556,7 +561,7 @@ def test_fused_moe_lora_kernel_fully_sharded( # the number of randomly generated sentences. num_sequences = 10 # generate data - topk_ids, topk_weights, token_lora_mapping = sample_data( + topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data( num_tokens, num_sequences, max_loras, num_experts, top_k_num ) @@ -576,6 +581,7 @@ def test_fused_moe_lora_kernel_fully_sharded( token_lora_mapping, max_lora_rank, top_k_num, + lora_ids, max_loras, num_experts, block_size, @@ -601,6 +607,7 @@ def use_fused_moe_lora_kernel_tensor_parallel( token_lora_mapping, max_lora_rank, top_k_num, + lora_ids, max_loras, num_experts, block_size, @@ -660,6 +667,7 @@ def use_fused_moe_lora_kernel_tensor_parallel( topk_ids = topk_ids.to(device) topk_weights = topk_weights.to(device) token_lora_mapping = token_lora_mapping.to(device) + lora_ids = lora_ids.to(device) ref_output = use_torch( hidden_states, @@ -698,6 +706,7 @@ def use_fused_moe_lora_kernel_tensor_parallel( token_lora_mapping, max_lora_rank, top_k_num, + lora_ids, [lora_a], [lora_b], hidden_states, @@ -714,4 +723,4 @@ def use_fused_moe_lora_kernel_tensor_parallel( else: output = tensor_model_parallel_all_reduce(output) - torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)