[Model] Add base class for LoRA-supported models (#5018)

This commit is contained in:
Cyrus Leung
2024-06-27 16:03:04 +08:00
committed by GitHub
parent d12af207d2
commit 96354d6a29
20 changed files with 270 additions and 75 deletions

View File

@@ -39,7 +39,7 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 from torch import nn
-from transformers import PretrainedConfig
+from transformers import PhiConfig
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
@@ -59,11 +59,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
+from .interfaces import SupportsLoRA
+
 class PhiAttention(nn.Module):
     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -131,7 +133,7 @@ class PhiAttention(nn.Module):
 class PhiMLP(nn.Module):
     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -160,7 +162,7 @@ class PhiMLP(nn.Module):
 class PhiLayer(nn.Module):
     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -192,7 +194,7 @@ class PhiLayer(nn.Module):
 class PhiModel(nn.Module):
     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -229,7 +231,9 @@ class PhiModel(nn.Module):
         return hidden_states
-class PhiForCausalLM(nn.Module):
+class PhiForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -250,14 +254,16 @@ class PhiForCausalLM(nn.Module):
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: PhiConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
    ):
-        del lora_config # Unused.
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.model = PhiModel(config, cache_config, quant_config)