[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
get_compressed_tensors_cache_scale)
|
||||
@@ -47,14 +46,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||
extract_layer_index, is_pp_missing_parameter,
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@@ -511,11 +509,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = self._init_model(vllm_config=vllm_config, prefix=prefix)
|
||||
self.model = self._init_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
@@ -544,13 +543,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.sampler = get_sampler()
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.STEP,
|
||||
normalize=False,
|
||||
softmax=False)
|
||||
|
||||
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
|
||||
@@ -581,14 +576,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
logits = self.compute_logits(hidden_states, None)
|
||||
return self._pooler(logits, pooling_metadata)
|
||||
|
||||
def sample(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
@@ -639,78 +626,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
name = name.replace(item, mapping[item])
|
||||
|
||||
return name, loaded_weight
|
||||
|
||||
|
||||
class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"""
|
||||
A model that uses Llama with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the LlamaModel and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of LlamaModel used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
|
||||
]
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
|
||||
self.model = LlamaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.model(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
weights = hf_to_vllm_mapper.apply(weights)
|
||||
weights = ((name, data) for name, data in weights
|
||||
if not name.startswith("lm_head."))
|
||||
self.model.load_weights(weights)
|
||||
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
self.model.load_kv_cache_scales(quantization_param_path)
|
||||
|
||||
# LRUCacheWorkerLoRAManager instantiation requires model config.
|
||||
@property
|
||||
def config(self):
|
||||
return self.model.config
|
||||
|
||||
Reference in New Issue
Block a user