[Lora] Support long context lora (#4787)
Currently we need to call the rotary embedding kernel once per LoRA, which makes it hard to serve multiple LoRAs with long context lengths. This change adds a batched rotary embedding kernel and pipes it through, replacing the rotary embedding layer with one that maintains a separate cos/sin cache per scaling factor. Follow-up of https://github.com/vllm-project/vllm/pull/3095/files
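The key idea can be sketched in plain PyTorch (a minimal sketch only, not the actual vLLM CUDA kernel; the helper names `build_cos_sin_caches` and `apply_batched_rope` are illustrative, not names from this PR): build one cos/sin cache per linear RoPE scaling factor, concatenate them, and let each token pick its cache through a per-LoRA offset, so tokens from LoRAs with different scaling factors can be rotated in a single batched call.

```python
import torch


def build_cos_sin_caches(head_dim, max_position, scaling_factors, base=10000.0):
    """Build one cos/sin cache per linear RoPE scaling factor, concatenate
    them, and return the cache plus each factor's row offset."""
    inv_freq = 1.0 / base**(torch.arange(0, head_dim, 2).float() / head_dim)
    caches, offsets, next_offset = [], {}, 0
    for factor in scaling_factors:
        # Linear scaling: positions are divided by the factor, and the
        # cache covers max_position * factor positions.
        t = torch.arange(int(max_position * factor)).float() / factor
        freqs = torch.outer(t, inv_freq)
        caches.append(torch.cat([freqs.cos(), freqs.sin()], dim=-1))
        offsets[factor] = next_offset
        next_offset += caches[-1].shape[0]
    return torch.cat(caches, dim=0), offsets


def apply_batched_rope(positions, offsets, x, cos_sin_cache):
    """Rotate x ([num_tokens, head_dim]) where each token's `offsets` entry
    points at the start of the cache for its LoRA's scaling factor."""
    cos, sin = cos_sin_cache[positions + offsets].chunk(2, dim=-1)
    x1, x2 = x.chunk(2, dim=-1)
    # Rotate-half (GPT-NeoX style) formulation of RoPE.
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
```

With, say, scaling factors [1.0, 4.0, 8.0] on a 4k base model, LoRAs adapted to 16k and 32k contexts each get their own region of the concatenated cache, and a batch mixing all of them needs only one kernel launch.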
@@ -21,6 +21,17 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model

LONG_LORA_INFOS = [{
    "lora_id": 1,
    "context_length": "16k",
}, {
    "lora_id": 2,
    "context_length": "16k",
}, {
    "lora_id": 3,
    "context_length": "32k",
}]


def cleanup():
    destroy_model_parallel()
@@ -154,6 +165,45 @@ def tinyllama_lora_files():
    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_2():
    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2")


@pytest.fixture(scope="session")
def long_context_lora_files_32k():
    return snapshot_download(repo_id="SangBinCho/long_context_32k_testing")


# SANG-TODO Download long lora files.
@pytest.fixture(scope="session")
def long_context_infos(long_context_lora_files_16k_1,
                       long_context_lora_files_16k_2,
                       long_context_lora_files_32k):
    cleanup()
    infos = {}
    for lora_checkpoint_info in LONG_LORA_INFOS:
        lora_id = lora_checkpoint_info["lora_id"]
        if lora_id == 1:
            lora = long_context_lora_files_16k_1
        elif lora_id == 2:
            lora = long_context_lora_files_16k_2
        elif lora_id == 3:
            lora = long_context_lora_files_32k
        else:
            raise AssertionError("Unknown lora id")
        infos[lora_id] = {
            "context_length": lora_checkpoint_info["context_length"],
            "lora": lora,
        }
    return infos


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
    cleanup()
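For reference, the "16k" / "32k" strings in `LONG_LORA_INFOS` correspond to the linear scaling factors the batched kernel consumes. A hypothetical helper (not part of this diff; the 4096-token base window is an assumption for a Llama-2-style model) could derive them like this:

```python
BASE_CONTEXT_LEN = 4096  # Assumed base context window (Llama-2 style).


def context_len_to_scaling_factor(context_length: str) -> float:
    # "16k" -> 16384 tokens -> 16384 / 4096 = 4.0; "32k" -> 8.0.
    return int(context_length.rstrip("k")) * 1024 / BASE_CONTEXT_LEN


scaling_factors = sorted({
    context_len_to_scaling_factor(info["context_length"])
    for info in LONG_LORA_INFOS
})  # [4.0, 8.0]
```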