[Lora] Support long context lora (#4787)
Currently we need to call the rotary embedding kernel once per LoRA, which makes it hard to serve multiple long-context LoRAs. This change adds a batched rotary embedding kernel and pipes it through, replacing the rotary embedding layer with one that is aware of multiple cos-sin caches, one per scaling factor. Follow-up of https://github.com/vllm-project/vllm/pull/3095/files
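
The idea, roughly: build one long cos-sin cache with a contiguous section per scaling factor, and shift each token's position by the offset of its LoRA's section, so a single batched kernel launch can serve requests that use different scaling factors. The helper below is a minimal sketch of that bookkeeping, not the vLLM kernel; the function name, cache layout, and offsets dict are assumptions made for the example.

import torch

def build_concatenated_cache(head_dim: int, max_position: int,
                             scaling_factors: list[float],
                             base: float = 10000.0):
    # Illustrative sketch (not the vLLM implementation): one cos/sin section
    # per linear-scaling factor, concatenated along the position axis.
    inv_freq = 1.0 / (base**(torch.arange(0, head_dim, 2).float() / head_dim))
    sections, offsets, next_offset = [], {}, 0
    for factor in scaling_factors:
        num_positions = int(max_position * factor)
        # Linear scaling divides positions by the factor.
        t = torch.arange(num_positions).float() / factor
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        sections.append(torch.cat([freqs.cos(), freqs.sin()], dim=-1))
        offsets[factor] = next_offset  # start row of this factor's section
        next_offset += num_positions
    return torch.cat(sections, dim=0), offsets

# One gather then serves a batch that mixes LoRAs with different factors:
# each token's position is shifted into its LoRA's section of the cache.
cache, offsets = build_concatenated_cache(64, 4096, [1.0, 4.0])
positions = torch.tensor([0, 5, 10])
per_token_offset = torch.tensor([offsets[1.0], offsets[4.0], offsets[4.0]])
cos_sin = cache[positions + per_token_offset]
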
@@ -15,6 +15,7 @@ from vllm.lora.fully_sharded_layers import (
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+                              LinearScalingRotaryEmbeddingWithLora,
                               LogitsProcessorWithLoRA, LoRAMapping,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,
@@ -22,13 +23,14 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA,
                               VocabParallelEmbeddingWithLoRA)
 # yapf: enable
-from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights,
-                              convert_mapping)
+from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
+                              PackedLoRALayerWeights, convert_mapping)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.utils import set_random_seed
@@ -771,3 +773,97 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                               expected_result,
                               rtol=rtol,
                               atol=atol)
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("num_loras", [1, 8])
+@pytest.mark.parametrize("device", ["cuda"])
+@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
+                                             (6.0, 1.0)])
+@pytest.mark.parametrize("max_position", [11, 4096, 32768])
+@pytest.mark.parametrize("is_neox_style", [True, False])
+@pytest.mark.parametrize("rotary_dim", [None, 32])
+@pytest.mark.parametrize("head_size", [32, 108])
+@pytest.mark.parametrize("seq_len", [11, 1024])
+def test_rotary_embedding_long_context(dist_init, num_loras, device,
+                                       scaling_factors, max_position,
+                                       is_neox_style, rotary_dim, head_size,
+                                       seq_len) -> None:
+    dtype = torch.float16
+    seed = 0
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+
+    max_loras = 8
+    lora_config = LoRAConfig(max_loras=max_loras,
+                             max_lora_rank=8,
+                             long_lora_scaling_factors=scaling_factors,
+                             lora_dtype=dtype)
+
+    if rotary_dim is None:
+        rotary_dim = head_size
+    base = 10000
+    batch_size = 5 * num_loras
+    num_heads = 7
+
+    # Verify lora is equivalent to linear scaling rotary embedding.
+    rope = get_rope(
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+    )
+    lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
+    lora_rope.create_lora_weights(max_loras, lora_config)
+    linear_rope = get_rope(head_size, rotary_dim, max_position, base,
+                           is_neox_style, {
+                               "type": "linear",
+                               "factor": scaling_factors
+                           })
+    linear_rope = linear_rope.to(dtype=dtype)
+    id_to_index = get_random_id_to_index(num_loras, max_loras)
+    _, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=[0],
+        num_inputs=batch_size,
+        input_size=(1, max_position),
+        input_range=(0, lora_config.lora_extra_vocab_size),
+        input_type=torch.float16,
+    )
+    lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
+    long_lora_context = LongContextLoRAContext(list(scaling_factors),
+                                               rotary_dim)
+
+    next_expected_offset = 0
+    # Make sure the offset is correct.
+    scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
+    for scaling_factor, offset in scaling_factor_to_offset.items():
+        assert offset == next_expected_offset
+        next_expected_offset += scaling_factor * max_position
+
+    for i in range(len(scaling_factors)):
+        long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get(
+            scaling_factors[i], 0)
+    mapping_info = convert_mapping(
+        lora_mapping,
+        id_to_index,
+        max_loras,
+        512,
+        lora_config.lora_extra_vocab_size,
+        long_lora_context=long_lora_context,
+    )
+    lora_rope.set_mapping(*mapping_info)
+
+    positions = torch.randint(0, max_position, (batch_size, seq_len))
+    query = torch.randn(batch_size,
+                        seq_len,
+                        num_heads * head_size,
+                        dtype=dtype)
+    key = torch.randn_like(query)
+    ref_q, ref_k = linear_rope(positions, query, key)
+    actual_q, actual_k = lora_rope(positions, query, key)
+
+    torch.allclose(ref_q, actual_q)
+    torch.allclose(ref_k, actual_k)
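
As a worked example of the offset check in the test above: with scaling_factors = (4.0, 8.0) and max_position = 4096, the offsets asserted against scaling_factor_to_offset are 0 for the 4.0 section and 4.0 * 4096 = 16384 for the 8.0 section. The values below are picked from the parametrization purely for illustration.

# Reproduces the expected-offset arithmetic from the test for one case.
scaling_factors, max_position = (4.0, 8.0), 4096
expected_offsets, next_expected_offset = [], 0
for factor in scaling_factors:
    expected_offsets.append(next_expected_offset)
    next_expected_offset += int(factor * max_position)
print(expected_offsets)  # [0, 16384]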