[Experimental] Add multi-LoRA support (#1804)

Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
This commit is contained in:
Antoni Baum
2024-01-24 00:26:37 +01:00
committed by GitHub
parent 9c1352eb57
commit 9b945daaf1
52 changed files with 8035 additions and 126 deletions

View File

@@ -13,8 +13,11 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
@@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module):
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None):
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings)
self.org_vocab_size = org_num_embeddings or num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings,
padding_size)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -77,7 +86,7 @@ class VocabParallelEmbedding(torch.nn.Module):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.num_embeddings
assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index]
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
@@ -114,14 +123,19 @@ class ParallelLMHead(VocabParallelEmbedding):
embedding_dim: size of hidden state.
bias: whether to use bias.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None):
super().__init__(num_embeddings, embedding_dim, params_dtype)
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,