Reduce kernel overhead when the number of active LoRAs is smaller than max_loras: a separate CUDA graph is captured for each possible number of active LoRAs. (#32005)
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
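The idea in one hedged sketch (illustrative only, not the code in this commit; capture_per_active_lora_graphs and run_fn are hypothetical names): capture a graph for every possible active-LoRA count, so a batch with few active adapters replays a graph whose kernels are sized for that count rather than for max_loras.

import torch

def capture_per_active_lora_graphs(run_fn, max_loras: int):
    """Capture one CUDA graph per possible active-LoRA count.

    run_fn(n) must run the fused MoE LoRA forward on pre-allocated
    (static) buffers, configured as if n LoRAs are active.
    """
    graphs: dict[int, torch.cuda.CUDAGraph] = {}
    side_stream = torch.cuda.Stream()
    # n runs to max_loras + 1 so the no-LoRA case gets its own graph,
    # mirroring the "+ 1" in the diff below.
    for n in range(1, max_loras + 2):
        # PyTorch requires a warmup run on a side stream before capture.
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            run_fn(n)
        torch.cuda.current_stream().wait_stream(side_stream)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            run_fn(n)
        graphs[n] = graph
    return graphs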
@@ -181,6 +181,10 @@ def use_fused_moe_lora_kernel(
     expert_ids = expert_ids.view(max_loras, -1)
     sorted_token_ids = sorted_token_ids.view(max_loras, -1)
 
+    # num_active_loras is the number of active LoRAs
+    # (max_loras + 1 to include no-lora case)
+    num_active_loras = max_loras + 1
+
     fused_moe_lora(
         output,
         hidden_states,
@@ -194,6 +198,7 @@ def use_fused_moe_lora_kernel(
         max_lora_rank,
         top_k_num,
         lora_ids,
+        num_active_loras,
         adapter_enabled,
         config["BLOCK_SIZE_M"],
         config["BLOCK_SIZE_N"],
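For context, a tiny illustration of the (max_loras, -1) views above (hypothetical sizes, not from the test): each row of the reshaped buffer belongs to one adapter slot, so a launch sized by the actual active count only has to walk the leading rows.

import torch

max_loras = 4
blocks_per_lora = 8  # hypothetical block count per adapter slot

# Flat buffer sized for the worst case (every adapter slot populated).
expert_ids = torch.zeros(max_loras * blocks_per_lora, dtype=torch.int32)
# Row i now holds the expert ids scheduled for LoRA slot i.
expert_ids = expert_ids.view(max_loras, -1)   # shape: (4, 8)

# A kernel sized by the actual active count touches only the leading
# rows instead of all max_loras rows.
num_active = 2
active_rows = expert_ids[:num_active]         # shape: (2, 8)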
@@ -376,6 +381,10 @@ def use_fused_moe_lora_kernel_naive(
     adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
     lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
 
+    # num_active_loras is the number of active LoRAs
+    # (max_loras + 1 to include no-lora case)
+    num_active_loras = max_loras + 1
+
    fused_moe_lora(
         output,
         hidden_states,
@@ -389,6 +398,7 @@ def use_fused_moe_lora_kernel_naive(
         max_lora_rank,
         top_k_num,
         lora_ids,
+        num_active_loras,
         adapter_enabled,
         config["BLOCK_SIZE_M"],
         config["BLOCK_SIZE_N"],
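At run time, the counterpart of the capture sketch near the top would simply replay the graph matching the batch's active-LoRA count. The dispatch below is hypothetical; it only illustrates why per-count graphs cut launch overhead when few adapters are active.

import torch

def replay_for_count(
    graphs: dict[int, torch.cuda.CUDAGraph], num_active_loras: int
) -> None:
    # The graph captured for this count launches kernels sized for it,
    # instead of the max_loras worst case.
    graphs[num_active_loras].replay()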