Reduce kernel overhead when the number of active LoRAs is smaller than the maximum number of LoRAs: a separate CUDA graph is captured for each number of active LoRAs. (#32005)

Signed-off-by: Yu Gong <yu3.gong@gmail.com>
Author: yugong333
Date: 2026-02-02 09:30:06 -08:00
Committed by: GitHub
Parent: 8b7346d5f1
Commit: ffe1fc7a28
15 changed files with 323 additions and 66 deletions
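
The idea behind the change, as a minimal sketch: capture one CUDA graph per possible count of active LoRAs, then replay whichever graph matches the live count, so a batch with few active adapters does not pay the `max_loras` worst case. This is an illustration only, not the vLLM implementation; `run_step` is a hypothetical stand-in for a graphable decode step that writes into static buffers.

```python
import torch

def capture_per_count_graphs(max_loras: int, run_step):
    """Capture one CUDA graph per number of active LoRAs.

    run_step(n) is a hypothetical callable launching the decode
    step's kernels against static buffers for n active LoRAs.
    """
    counts = range(1, max_loras + 2)  # 1 .. max_loras + 1 (no-LoRA slot included)
    # PyTorch requires warmup on a side stream before graph capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for n in counts:
            run_step(n)
    torch.cuda.current_stream().wait_stream(s)

    graphs = {}
    for n in counts:
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            run_step(n)
        graphs[n] = g
    return graphs

# At decode time, replay the graph matching the live adapter count:
#   graphs[num_active_loras].replay()
```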


@@ -181,6 +181,10 @@ def use_fused_moe_lora_kernel(
     expert_ids = expert_ids.view(max_loras, -1)
     sorted_token_ids = sorted_token_ids.view(max_loras, -1)
+    # num_active_loras is the number of active LoRAs
+    # (max_loras + 1 to include the no-LoRA case)
+    num_active_loras = max_loras + 1
     fused_moe_lora(
         output,
         hidden_states,
@@ -194,6 +198,7 @@ def use_fused_moe_lora_kernel(
         max_lora_rank,
         top_k_num,
         lora_ids,
+        num_active_loras,
         adapter_enabled,
         config["BLOCK_SIZE_M"],
         config["BLOCK_SIZE_N"],
@@ -376,6 +381,10 @@ def use_fused_moe_lora_kernel_naive(
     adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
     lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
+    # num_active_loras is the number of active LoRAs
+    # (max_loras + 1 to include the no-LoRA case)
+    num_active_loras = max_loras + 1
     fused_moe_lora(
         output,
         hidden_states,
@@ -389,6 +398,7 @@ def use_fused_moe_lora_kernel_naive(
         max_lora_rank,
         top_k_num,
         lora_ids,
+        num_active_loras,
         adapter_enabled,
         config["BLOCK_SIZE_M"],
         config["BLOCK_SIZE_N"],