Add ability to replace out-of-tree (OOT) ops when using LoRA (#37181)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed import tensor_model_parallel_all_gather
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.model_executor.custom_op import maybe_get_oot_by_class
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
@@ -155,9 +156,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
if type(source_layer) is ColumnParallelLinear:
|
||||
if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
|
||||
return True
|
||||
if type(source_layer) is MergedColumnParallelLinear:
|
||||
if type(source_layer) is maybe_get_oot_by_class(MergedColumnParallelLinear):
|
||||
if len(packed_modules_list) != 1:
|
||||
return False
|
||||
# Exclude layers with 3+ output sizes - those are handled by
|
||||
@@ -606,7 +607,7 @@ class MergedColumnParallelLinearVariableSliceWithLoRA(
|
||||
) -> bool:
|
||||
# Support MergedColumnParallelLinear with 3 or more slices
|
||||
# (2 slices are handled by MergedColumnParallelLinearWithLoRA)
|
||||
if type(source_layer) is not MergedColumnParallelLinear:
|
||||
if type(source_layer) is not maybe_get_oot_by_class(MergedColumnParallelLinear):
|
||||
return False
|
||||
|
||||
# If packed_modules_list has 3+ items, use this class
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.model_executor.custom_op import maybe_get_oot_by_class
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
|
||||
from .base_linear import BaseLinearLayerWithLoRA
|
||||
@@ -55,7 +56,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is ReplicatedLinear
|
||||
return type(source_layer) is maybe_get_oot_by_class(ReplicatedLinear)
|
||||
|
||||
def slice_lora_a(
|
||||
self, lora_a: torch.Tensor | list[torch.Tensor | None]
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.distributed import (
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.custom_op import maybe_get_oot_by_class
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -89,7 +90,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is RowParallelLinear
|
||||
return type(source_layer) is maybe_get_oot_by_class(RowParallelLinear)
|
||||
|
||||
|
||||
# The following layer is based on the tensor parallelism strategy given in
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.model_executor.custom_op import maybe_get_oot_by_class
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -132,7 +133,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is VocabParallelEmbedding
|
||||
return type(source_layer) is maybe_get_oot_by_class(VocabParallelEmbedding)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
|
||||
@@ -22,10 +22,11 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
|
||||
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
|
||||
|
||||
|
||||
def get_oot_class_by_name(class_name: str) -> type | None:
|
||||
def maybe_get_oot_by_class(class_type: type) -> type:
|
||||
class_name = class_type.__name__
|
||||
if class_name in op_registry_oot:
|
||||
return op_registry_oot[class_name]
|
||||
return None
|
||||
return class_type
|
||||
|
||||
|
||||
class PluggableLayer(nn.Module):
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
|
||||
from vllm.model_executor.custom_op import CustomOp, maybe_get_oot_by_class
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
@@ -125,7 +125,7 @@ class MMEncoderAttention(CustomOp):
|
||||
cu_seqlens: np.ndarray,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor | None:
|
||||
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
|
||||
if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
|
||||
return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined]
|
||||
|
||||
if attn_backend != AttentionBackendEnum.FLASHINFER:
|
||||
@@ -149,7 +149,7 @@ class MMEncoderAttention(CustomOp):
|
||||
tp_size: int,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
|
||||
if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
|
||||
return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined]
|
||||
attn_backend, cu_seqlens, hidden_size, tp_size, device
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user