Add ability to replace oot ops when using lora (#37181)

Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
This commit is contained in:
Kyuyeun Kim
2026-03-16 18:04:15 -07:00
committed by GitHub
parent 6c1cfbad32
commit 0a0a1a198b
6 changed files with 16 additions and 11 deletions

View File

@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.utils import divide
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -155,9 +156,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
if type(source_layer) is ColumnParallelLinear:
if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
return True
if type(source_layer) is MergedColumnParallelLinear:
if type(source_layer) is maybe_get_oot_by_class(MergedColumnParallelLinear):
if len(packed_modules_list) != 1:
return False
# Exclude layers with 3+ output sizes - those are handled by
@@ -606,7 +607,7 @@ class MergedColumnParallelLinearVariableSliceWithLoRA(
) -> bool:
# Support MergedColumnParallelLinear with 3 or more slices
# (2 slices are handled by MergedColumnParallelLinearWithLoRA)
if type(source_layer) is not MergedColumnParallelLinear:
if type(source_layer) is not maybe_get_oot_by_class(MergedColumnParallelLinear):
return False
# If packed_modules_list has 3+ items, use this class

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.linear import ReplicatedLinear
from .base_linear import BaseLinearLayerWithLoRA
@@ -55,7 +56,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is ReplicatedLinear
return type(source_layer) is maybe_get_oot_by_class(ReplicatedLinear)
def slice_lora_a(
self, lora_a: torch.Tensor | list[torch.Tensor | None]

View File

@@ -11,6 +11,7 @@ from vllm.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform
@@ -89,7 +90,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is RowParallelLinear
return type(source_layer) is maybe_get_oot_by_class(RowParallelLinear)
# The following layer is based on the tensor parallelism strategy given in

View File

@@ -7,6 +7,7 @@ import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
@@ -132,7 +133,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is VocabParallelEmbedding
return type(source_layer) is maybe_get_oot_by_class(VocabParallelEmbedding)
@property
def weight(self):

View File

@@ -22,10 +22,11 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
def get_oot_class_by_name(class_name: str) -> type | None:
def maybe_get_oot_by_class(class_type: type) -> type:
class_name = class_type.__name__
if class_name in op_registry_oot:
return op_registry_oot[class_name]
return None
return class_type
class PluggableLayer(nn.Module):

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
from vllm.model_executor.custom_op import CustomOp, maybe_get_oot_by_class
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -125,7 +125,7 @@ class MMEncoderAttention(CustomOp):
cu_seqlens: np.ndarray,
device: torch.device,
) -> torch.Tensor | None:
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined]
if attn_backend != AttentionBackendEnum.FLASHINFER:
@@ -149,7 +149,7 @@ class MMEncoderAttention(CustomOp):
tp_size: int,
device: torch.device,
) -> torch.Tensor:
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined]
attn_backend, cu_seqlens, hidden_size, tp_size, device
)