[Fix Bug] num_active_loras always equals zero (#34119)
Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -187,7 +187,8 @@ def use_fused_moe_lora_kernel(
|
||||
|
||||
# num_active_loras is the number of active LoRAs
|
||||
# (max_loras + 1 to include no-lora case)
|
||||
num_active_loras = max_loras + 1
|
||||
# Stored as CPU tensor to match the kernel API (torch.compile compatibility)
|
||||
num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu")
|
||||
|
||||
fused_moe_lora(
|
||||
output,
|
||||
@@ -399,7 +400,8 @@ def use_fused_moe_lora_kernel_naive(
|
||||
|
||||
# num_active_loras is the number of active LoRAs
|
||||
# (max_loras + 1 to include no-lora case)
|
||||
num_active_loras = max_loras + 1
|
||||
# Stored as CPU tensor to match the kernel API (torch.compile compatibility)
|
||||
num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu")
|
||||
|
||||
fused_moe_lora(
|
||||
output,
|
||||
|
||||
@@ -70,8 +70,12 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
|
||||
@pytest.mark.parametrize("specialize_active_lora", [True, False])
|
||||
def test_gpt_oss_lora(
|
||||
monkeypatch: pytest.MonkeyPatch, gptoss20b_lora_files, mxfp4_use_marlin
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
gptoss20b_lora_files,
|
||||
mxfp4_use_marlin,
|
||||
specialize_active_lora,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
|
||||
@@ -83,6 +87,7 @@ def test_gpt_oss_lora(
|
||||
max_lora_rank=8,
|
||||
max_num_seqs=2,
|
||||
max_num_batched_tokens=2048,
|
||||
specialize_active_lora=specialize_active_lora,
|
||||
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
|
||||
cudagraph_specialize_lora=False,
|
||||
),
|
||||
|
||||
@@ -127,7 +127,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
|
||||
|
||||
|
||||
def _adjust_kernel_inputs(
|
||||
num_active_loras: int,
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
sorted_token_ids: torch.Tensor | None,
|
||||
expert_ids: torch.Tensor,
|
||||
):
|
||||
@@ -141,7 +141,7 @@ def _adjust_kernel_inputs(
|
||||
else:
|
||||
stride_tl = sorted_token_ids.stride(0)
|
||||
stride_el = expert_ids.stride(0)
|
||||
grid_lora_dim = num_active_loras
|
||||
grid_lora_dim = num_active_loras.item()
|
||||
return grid_lora_dim, stride_tl, stride_el
|
||||
|
||||
|
||||
@@ -444,7 +444,7 @@ def _fused_moe_lora_shrink(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
mul_routed_weight: bool = False,
|
||||
use_gdc: bool = False,
|
||||
use_tma: bool = False,
|
||||
@@ -562,7 +562,7 @@ def _fused_moe_lora_expand(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
mul_routed_weight: bool = False,
|
||||
offset: int = 0,
|
||||
use_gdc: bool = False,
|
||||
@@ -683,7 +683,7 @@ def _fused_moe_lora(
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
adapter_enabled: torch.Tensor,
|
||||
shrink_block_size_m: int,
|
||||
shrink_block_size_n: int,
|
||||
@@ -871,7 +871,7 @@ def _fused_moe_lora_fake(
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
adapter_enabled: torch.Tensor,
|
||||
shrink_block_size_m: int,
|
||||
shrink_block_size_n: int,
|
||||
@@ -921,7 +921,7 @@ def _fused_moe_lora_shrink_fake(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
mul_routed_weight: bool = False,
|
||||
use_gdc: bool = False,
|
||||
use_tma: bool = False,
|
||||
@@ -958,7 +958,7 @@ def _fused_moe_lora_expand_fake(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
mul_routed_weight: bool = False,
|
||||
offset: int = 0,
|
||||
use_gdc: bool = False,
|
||||
|
||||
@@ -138,7 +138,7 @@ def _lora_expand(
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
@@ -235,7 +235,7 @@ def _lora_expand(
|
||||
grid = (
|
||||
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
|
||||
NUM_SLICES,
|
||||
num_active_loras,
|
||||
num_active_loras.item(),
|
||||
)
|
||||
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
|
||||
# making PDL invalid and affecting the kernel performance.
|
||||
@@ -289,7 +289,7 @@ def _lora_expand_fake(
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
|
||||
@@ -29,9 +29,16 @@ class LoRAKernelMeta:
|
||||
# to early exit from inside the lora_expand / lora_shrink torch operation.
|
||||
no_lora_flag_cpu: torch.Tensor
|
||||
|
||||
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
|
||||
# Stored as a Python int to avoid GPU->CPU sync during forward pass
|
||||
num_active_loras: int = 0
|
||||
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping).
|
||||
# Stored as a CPU tensor (not a Python int) so that torch.compile treats
|
||||
# it as a dynamic value rather than baking it as a constant at trace time.
|
||||
# This follows the same pattern as no_lora_flag_cpu above.
|
||||
num_active_loras_cpu: torch.Tensor
|
||||
|
||||
# Default num_active_loras value (max_loras + 1) as a CPU tensor,
|
||||
# used when specialize_active_lora is False to avoid allocating a
|
||||
# new tensor on every meta_args() call.
|
||||
default_num_active_loras_cpu: torch.Tensor
|
||||
|
||||
# Captured LoRA counts for cudagraph specialization (sorted list).
|
||||
# When specialize_active_lora is enabled, num_active_loras is rounded up
|
||||
@@ -73,6 +80,11 @@ class LoRAKernelMeta:
|
||||
|
||||
no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")
|
||||
|
||||
num_active_loras_cpu = torch.tensor([0], dtype=torch.int32, device="cpu")
|
||||
default_num_active_loras_cpu = torch.tensor(
|
||||
[max_loras + 1], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
|
||||
return LoRAKernelMeta(
|
||||
token_lora_mapping=token_lora_mapping,
|
||||
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
|
||||
@@ -80,6 +92,8 @@ class LoRAKernelMeta:
|
||||
num_tokens_per_lora=num_tokens_per_lora,
|
||||
lora_token_start_loc=lora_token_start_loc,
|
||||
no_lora_flag_cpu=no_lora_flag_cpu,
|
||||
num_active_loras_cpu=num_active_loras_cpu,
|
||||
default_num_active_loras_cpu=default_num_active_loras_cpu,
|
||||
captured_lora_counts=sorted(captured_lora_counts)
|
||||
if captured_lora_counts
|
||||
else [],
|
||||
@@ -90,8 +104,7 @@ class LoRAKernelMeta:
|
||||
self.num_tokens_per_lora.fill_(0)
|
||||
self.lora_token_start_loc.fill_(0)
|
||||
self.no_lora_flag_cpu.fill_(False)
|
||||
self.num_active_loras = 0
|
||||
self.captured_lora_counts = []
|
||||
self.num_active_loras_cpu.fill_(0)
|
||||
|
||||
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
|
||||
"""
|
||||
@@ -137,14 +150,16 @@ class LoRAKernelMeta:
|
||||
num_tokens_per_lora, non_blocking=True
|
||||
)
|
||||
|
||||
self.num_active_loras = lora_ids.size(0)
|
||||
num_active_loras = lora_ids.size(0)
|
||||
|
||||
# Round up num_active_loras to match cudagraph capture keys.
|
||||
# This ensures the kernel grid dimension matches the captured graph.
|
||||
if self.captured_lora_counts and self.num_active_loras > 0:
|
||||
idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
|
||||
if self.captured_lora_counts and num_active_loras > 0:
|
||||
idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras)
|
||||
if idx < len(self.captured_lora_counts):
|
||||
self.num_active_loras = self.captured_lora_counts[idx]
|
||||
num_active_loras = self.captured_lora_counts[idx]
|
||||
|
||||
self.num_active_loras_cpu[0] = num_active_loras
|
||||
|
||||
# lora_token_start_loc
|
||||
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
|
||||
@@ -163,7 +178,7 @@ class LoRAKernelMeta:
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
int,
|
||||
torch.Tensor,
|
||||
]:
|
||||
"""
|
||||
This function returns the kernel metadata required for the current
|
||||
@@ -175,7 +190,10 @@ class LoRAKernelMeta:
|
||||
token_nums (int): Number of input tokens in the current forward
|
||||
pass of the kernel.
|
||||
"""
|
||||
max_loras = self.active_lora_ids.size(0) - 1
|
||||
if specialize_active_lora:
|
||||
num_active_loras = self.num_active_loras_cpu
|
||||
else:
|
||||
num_active_loras = self.default_num_active_loras_cpu
|
||||
return (
|
||||
self.token_lora_mapping[:token_nums],
|
||||
self.token_indices_sorted_by_lora_ids[:token_nums],
|
||||
@@ -183,5 +201,5 @@ class LoRAKernelMeta:
|
||||
self.lora_token_start_loc,
|
||||
self.active_lora_ids,
|
||||
self.no_lora_flag_cpu,
|
||||
self.num_active_loras if specialize_active_lora else max_loras + 1,
|
||||
num_active_loras,
|
||||
)
|
||||
|
||||
@@ -134,7 +134,7 @@ def _lora_shrink(
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -157,6 +157,9 @@ def _lora_shrink(
|
||||
lora_ids (torch.Tensor): LoRA ids to process.
|
||||
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
|
||||
if there are any requests that require LoRA.
|
||||
num_active_loras (torch.Tensor): A CPU tensor of size 1, containing the
|
||||
number of active LoRAs. Stored as a tensor (not int) so
|
||||
torch.compile treats it as dynamic rather than a constant.
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
|
||||
@@ -215,7 +218,7 @@ def _lora_shrink(
|
||||
grid = (
|
||||
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
NUM_SLICES,
|
||||
num_active_loras,
|
||||
num_active_loras.item(),
|
||||
)
|
||||
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
|
||||
# making PDL invalid and affecting the kernel performance.
|
||||
@@ -267,7 +270,7 @@ def _lora_shrink_fake(
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
|
||||
scaling: float,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@@ -5379,6 +5379,7 @@ class GPUModelRunner(
|
||||
# if we want to warm up attention or not. This is
|
||||
# different from the case where `FULL` implies capture
|
||||
# attention while `PIECEWISE` implies no attention.
|
||||
|
||||
dummy_run(
|
||||
num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
|
||||
Reference in New Issue
Block a user