[Core] Change LoRA embedding sharding to support loading methods (#5038)
This commit is contained in:
@@ -2,6 +2,7 @@ import random
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -32,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
from .utils import DummyLoRAManager
|
||||
@@ -427,7 +428,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
|
||||
logits_processor = LogitsProcessor(
|
||||
vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
|
||||
lora_logits_processor = LogitsProcessorWithLoRA(
|
||||
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
|
||||
logits_processor, 1024, linear.weight.dtype, linear.weight.device,
|
||||
None)
|
||||
lora_logits_processor.create_lora_weights(max_loras, lora_config)
|
||||
|
||||
return linear, logits_processor, lora_logits_processor
|
||||
@@ -867,3 +869,216 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
||||
|
||||
torch.allclose(ref_q, actual_q)
|
||||
torch.allclose(ref_k, actual_k)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("seed", list(range(256)))
|
||||
def test_vocab_parallel_embedding_indices(tp_size, seed):
|
||||
random.seed(seed)
|
||||
vocab_size = random.randint(4000, 64000)
|
||||
added_vocab_size = random.randint(0, 1024)
|
||||
org_vocab_size = vocab_size - added_vocab_size
|
||||
last_org_vocab_end_index = 0
|
||||
last_added_vocab_end_index = org_vocab_size
|
||||
computed_vocab_size = 0
|
||||
computed_org_vocab_size = 0
|
||||
computed_added_vocab_size = 0
|
||||
vocab_size_padded = -1
|
||||
|
||||
all_org_tokens = []
|
||||
all_added_tokens = []
|
||||
token_ids = []
|
||||
|
||||
for tp_rank in range(tp_size):
|
||||
with patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
|
||||
return_value=tp_rank
|
||||
), patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
|
||||
return_value=tp_size):
|
||||
vocab_embedding = VocabParallelEmbedding(
|
||||
vocab_size, 1, org_num_embeddings=org_vocab_size)
|
||||
vocab_size_padded = vocab_embedding.num_embeddings_padded
|
||||
shard_indices = vocab_embedding.shard_indices
|
||||
# Assert that the ranges are contiguous
|
||||
assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
|
||||
assert (shard_indices.added_vocab_start_index ==
|
||||
last_added_vocab_end_index)
|
||||
|
||||
# Ensure that we are not exceeding the vocab size
|
||||
computed_vocab_size += shard_indices.num_elements_padded
|
||||
computed_org_vocab_size += shard_indices.num_org_elements
|
||||
computed_added_vocab_size += shard_indices.num_added_elements
|
||||
|
||||
# Ensure that the ranges are not overlapping
|
||||
all_org_tokens.extend(
|
||||
range(shard_indices.org_vocab_start_index,
|
||||
shard_indices.org_vocab_end_index))
|
||||
all_added_tokens.extend(
|
||||
range(shard_indices.added_vocab_start_index,
|
||||
shard_indices.added_vocab_end_index))
|
||||
|
||||
token_ids.extend(
|
||||
range(shard_indices.org_vocab_start_index,
|
||||
shard_indices.org_vocab_end_index))
|
||||
token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
|
||||
shard_indices.num_org_elements))
|
||||
token_ids.extend(
|
||||
range(shard_indices.added_vocab_start_index,
|
||||
shard_indices.added_vocab_end_index))
|
||||
token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
|
||||
shard_indices.num_added_elements))
|
||||
|
||||
last_org_vocab_end_index = shard_indices.org_vocab_end_index
|
||||
last_added_vocab_end_index = shard_indices.added_vocab_end_index
|
||||
|
||||
assert computed_vocab_size == vocab_size_padded
|
||||
assert computed_org_vocab_size == org_vocab_size
|
||||
assert computed_added_vocab_size == added_vocab_size
|
||||
|
||||
# Ensure that the ranges are not overlapping
|
||||
assert len(all_org_tokens) == len(set(all_org_tokens))
|
||||
assert len(all_added_tokens) == len(set(all_added_tokens))
|
||||
assert not set(all_org_tokens).intersection(set(all_added_tokens))
|
||||
|
||||
token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
|
||||
reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
|
||||
assert reindex_mapping is not None or tp_size == 1
|
||||
if reindex_mapping is not None:
|
||||
reindexed_token_ids = token_ids_tensor[reindex_mapping]
|
||||
expected = torch.tensor(list(range(0, vocab_size)))
|
||||
assert reindexed_token_ids[:vocab_size].equal(expected)
|
||||
assert torch.all(reindexed_token_ids[vocab_size:] == -1)
|
||||
|
||||
|
||||
def test_get_masked_input_and_mask():
|
||||
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
|
||||
|
||||
# base tp 1 case, no padding
|
||||
modified_x, _ = get_masked_input_and_mask(x,
|
||||
org_vocab_start_index=0,
|
||||
org_vocab_end_index=8,
|
||||
added_vocab_start_index=8,
|
||||
added_vocab_end_index=12,
|
||||
num_org_vocab_padding=0)
|
||||
assert torch.equal(x, modified_x)
|
||||
|
||||
# tp 2 case, no padding
|
||||
modified_x_rank_0, _ = get_masked_input_and_mask(x,
|
||||
org_vocab_start_index=0,
|
||||
org_vocab_end_index=4,
|
||||
added_vocab_start_index=8,
|
||||
added_vocab_end_index=10,
|
||||
num_org_vocab_padding=0)
|
||||
modified_x_rank_1, _ = get_masked_input_and_mask(
|
||||
x,
|
||||
org_vocab_start_index=4,
|
||||
org_vocab_end_index=8,
|
||||
added_vocab_start_index=10,
|
||||
added_vocab_end_index=12,
|
||||
num_org_vocab_padding=0)
|
||||
assert torch.equal(modified_x_rank_0,
|
||||
torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
|
||||
assert torch.equal(modified_x_rank_1,
|
||||
torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))
|
||||
|
||||
# tp 4 case, no padding
|
||||
modified_x_rank_0, _ = get_masked_input_and_mask(x,
|
||||
org_vocab_start_index=0,
|
||||
org_vocab_end_index=2,
|
||||
added_vocab_start_index=8,
|
||||
added_vocab_end_index=9,
|
||||
num_org_vocab_padding=0)
|
||||
modified_x_rank_1, _ = get_masked_input_and_mask(x,
|
||||
org_vocab_start_index=2,
|
||||
org_vocab_end_index=4,
|
||||
added_vocab_start_index=9,
|
||||
added_vocab_end_index=10,
|
||||
num_org_vocab_padding=0)
|
||||
modified_x_rank_2, _ = get_masked_input_and_mask(
|
||||
x,
|
||||
org_vocab_start_index=4,
|
||||
org_vocab_end_index=6,
|
||||
added_vocab_start_index=10,
|
||||
added_vocab_end_index=11,
|
||||
num_org_vocab_padding=0)
|
||||
modified_x_rank_3, _ = get_masked_input_and_mask(
|
||||
x,
|
||||
org_vocab_start_index=6,
|
||||
org_vocab_end_index=8,
|
||||
added_vocab_start_index=11,
|
||||
added_vocab_end_index=12,
|
||||
num_org_vocab_padding=0)
|
||||
assert torch.equal(modified_x_rank_0,
|
||||
torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
|
||||
assert torch.equal(modified_x_rank_1,
|
||||
torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
|
||||
assert torch.equal(modified_x_rank_2,
|
||||
torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
|
||||
assert torch.equal(modified_x_rank_3,
|
||||
torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))
|
||||
|
||||
# base tp 1 case, with padding
|
||||
modified_x, _ = get_masked_input_and_mask(x,
|
||||
org_vocab_start_index=0,
|
||||
org_vocab_end_index=8,
|
||||
added_vocab_start_index=8,
|
||||
added_vocab_end_index=12,
|
||||
num_org_vocab_padding=2)
|
||||
assert torch.equal(modified_x,
|
||||
torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))
|
||||
|
||||
# tp 2 case, with padding
|
||||
modified_x_rank_0, _ = get_masked_input_and_mask(x,
|
||||
org_vocab_start_index=0,
|
||||
org_vocab_end_index=4,
|
||||
added_vocab_start_index=8,
|
||||
added_vocab_end_index=10,
|
||||
num_org_vocab_padding=2)
|
||||
modified_x_rank_1, _ = get_masked_input_and_mask(
|
||||
x,
|
||||
org_vocab_start_index=4,
|
||||
org_vocab_end_index=8,
|
||||
added_vocab_start_index=10,
|
||||
added_vocab_end_index=12,
|
||||
num_org_vocab_padding=2)
|
||||
assert torch.equal(modified_x_rank_0,
|
||||
torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
|
||||
assert torch.equal(modified_x_rank_1,
|
||||
torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))
|
||||
|
||||
# tp 4 case, with padding
|
||||
modified_x_rank_0, _ = get_masked_input_and_mask(x,
|
||||
org_vocab_start_index=0,
|
||||
org_vocab_end_index=2,
|
||||
added_vocab_start_index=8,
|
||||
added_vocab_end_index=9,
|
||||
num_org_vocab_padding=2)
|
||||
modified_x_rank_1, _ = get_masked_input_and_mask(x,
|
||||
org_vocab_start_index=2,
|
||||
org_vocab_end_index=4,
|
||||
added_vocab_start_index=9,
|
||||
added_vocab_end_index=10,
|
||||
num_org_vocab_padding=2)
|
||||
modified_x_rank_2, _ = get_masked_input_and_mask(
|
||||
x,
|
||||
org_vocab_start_index=4,
|
||||
org_vocab_end_index=6,
|
||||
added_vocab_start_index=10,
|
||||
added_vocab_end_index=11,
|
||||
num_org_vocab_padding=2)
|
||||
modified_x_rank_3, _ = get_masked_input_and_mask(
|
||||
x,
|
||||
org_vocab_start_index=6,
|
||||
org_vocab_end_index=8,
|
||||
added_vocab_start_index=11,
|
||||
added_vocab_end_index=12,
|
||||
num_org_vocab_padding=2)
|
||||
assert torch.equal(modified_x_rank_0,
|
||||
torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
|
||||
assert torch.equal(modified_x_rank_1,
|
||||
torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
|
||||
assert torch.equal(modified_x_rank_2,
|
||||
torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
|
||||
assert torch.equal(modified_x_rank_3,
|
||||
torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))
|
||||
|
||||
Reference in New Issue
Block a user