[Experimental] Add multi-LoRA support (#1804)
Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
This commit is contained in:
@@ -13,8 +13,11 @@ from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
|
||||
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
|
||||
|
||||
def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to a multiple of ``pad_to``.

    Args:
        vocab_size: vocabulary size before padding.
        pad_to: granularity to round up to; defaults to
            ``DEFAULT_VOCAB_PADDING_SIZE``.

    Returns:
        For a positive integer ``pad_to``, the smallest multiple of
        ``pad_to`` that is greater than or equal to ``vocab_size``.
    """
    # Ceiling division written with floor division only: adding
    # ``pad_to - 1`` first bumps any partial chunk up to a whole one.
    num_chunks = (vocab_size + pad_to - 1) // pad_to
    return num_chunks * pad_to
|
||||
|
||||
@@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
params_dtype: type of the parameters.
|
||||
org_num_embeddings: original vocabulary size (without LoRA).
|
||||
padding_size: padding size for the vocabulary.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None):
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
|
||||
super().__init__()
|
||||
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.num_embeddings_padded = pad_vocab_size(num_embeddings)
|
||||
self.org_vocab_size = org_num_embeddings or num_embeddings
|
||||
self.num_embeddings_padded = pad_vocab_size(num_embeddings,
|
||||
padding_size)
|
||||
self.embedding_dim = embedding_dim
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
@@ -77,7 +86,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
parallel_dim = param.parallel_dim
|
||||
assert loaded_weight.shape[parallel_dim] == self.num_embeddings
|
||||
assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
|
||||
loaded_weight = loaded_weight[self.vocab_start_index:self.
|
||||
vocab_end_index]
|
||||
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
|
||||
@@ -114,14 +123,19 @@ class ParallelLMHead(VocabParallelEmbedding):
|
||||
embedding_dim: size of hidden state.
|
||||
bias: whether to use bias.
|
||||
params_dtype: type of the parameters.
|
||||
org_num_embeddings: original vocabulary size (without LoRA).
|
||||
padding_size: padding size for the vocabulary.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
bias: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None):
|
||||
super().__init__(num_embeddings, embedding_dim, params_dtype)
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
|
||||
super().__init__(num_embeddings, embedding_dim, params_dtype,
|
||||
org_num_embeddings, padding_size)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
|
||||
Reference in New Issue
Block a user