Add extra punica sizes to support bigger vocabs (#4015)
This commit is contained in:
@@ -170,7 +170,8 @@ def create_random_inputs(
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
lora_dtype=torch.float16)
|
||||
|
||||
def create_random_embedding_layer():
|
||||
embedding = VocabParallelEmbedding(512, 256)
|
||||
embedding = VocabParallelEmbedding(vocab_size, 256)
|
||||
embedding.weight.data = torch.rand_like(embedding.weight.data)
|
||||
embedding.weight.data[512:, :] = 0
|
||||
embedding.weight.data[vocab_size:, :] = 0
|
||||
lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
|
||||
lora_embedding.create_lora_weights(max_loras, lora_config)
|
||||
|
||||
@@ -203,12 +204,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=list(lora_dict.keys()),
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info)
|
||||
|
||||
lora_result = lora_embedding(torch.cat(inputs))
|
||||
@@ -240,12 +242,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=[0],
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_embedding(torch.cat(inputs))
|
||||
@@ -263,7 +266,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
# reason="Fails when loras are in any slot other than the first.")
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
vocab_size) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
lora_dtype=torch.float16)
|
||||
|
||||
def create_random_embedding_layer():
|
||||
embedding = VocabParallelEmbedding(512, 256)
|
||||
embedding = VocabParallelEmbedding(vocab_size, 256)
|
||||
embedding_data = torch.rand_like(embedding.weight.data)
|
||||
embedding.weight.data = embedding_data
|
||||
embedding.weight.data[512:, :] = 0
|
||||
embedding.weight.data[vocab_size:, :] = 0
|
||||
expanded_embedding = VocabParallelEmbedding(
|
||||
512 + lora_config.lora_extra_vocab_size * max_loras,
|
||||
vocab_size + lora_config.lora_extra_vocab_size * max_loras,
|
||||
256,
|
||||
org_num_embeddings=512)
|
||||
expanded_embedding.weight.data[:512, :] = embedding_data
|
||||
org_num_embeddings=vocab_size)
|
||||
expanded_embedding.weight.data[:vocab_size, :] = embedding_data
|
||||
# We need to deepcopy the embedding as it will be modified
|
||||
# in place
|
||||
lora_embedding = VocabParallelEmbeddingWithLoRA(
|
||||
@@ -298,7 +303,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
id_to_index,
|
||||
layer=lora_embedding,
|
||||
layer_weights=torch.zeros(
|
||||
(256, 512 + lora_config.lora_extra_vocab_size)),
|
||||
(256, vocab_size + lora_config.lora_extra_vocab_size)),
|
||||
generate_embeddings_tensor=256,
|
||||
)
|
||||
|
||||
@@ -316,7 +321,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=list(lora_dict.keys()),
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
@@ -327,16 +332,18 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
for input_, original_input_, lora_id in zip(inputs, original_inputs,
|
||||
prompt_mapping):
|
||||
embedding_id = lora_id - 1
|
||||
input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
|
||||
original_input_[-1] = 512
|
||||
input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
|
||||
original_input_[-2] = 512 + embeddings_tensor_len - 1
|
||||
input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
|
||||
original_input_[-1] = vocab_size
|
||||
input_[-2] = vocab_size + (
|
||||
(embedding_id + 1) * embeddings_tensor_len - 1)
|
||||
original_input_[-2] = vocab_size + embeddings_tensor_len - 1
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
expanded_embedding.weight[512:512 +
|
||||
expanded_embedding.weight[vocab_size:vocab_size +
|
||||
(embeddings_tensor_len *
|
||||
max_loras)] = torch.cat(embeddings_tensors)
|
||||
|
||||
@@ -370,14 +377,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=[0],
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
original_inputs = deepcopy(inputs)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_embedding(torch.cat(original_inputs))
|
||||
@@ -393,7 +401,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device,
|
||||
vocab_size) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
lora_dtype=torch.float16)
|
||||
|
||||
def _pretest():
|
||||
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
|
||||
1024, 32000)
|
||||
linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
|
||||
1024, vocab_size)
|
||||
linear.weight.data = torch.rand_like(linear.weight.data)
|
||||
linear.weight.data[:, 32000:] = 0
|
||||
linear.weight.data[:, vocab_size:] = 0
|
||||
logits_processor = LogitsProcessor(
|
||||
32000 + lora_config.lora_extra_vocab_size, 32000)
|
||||
vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
|
||||
lora_logits_processor = LogitsProcessorWithLoRA(
|
||||
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
|
||||
lora_logits_processor.create_lora_weights(max_loras, lora_config)
|
||||
@@ -444,7 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
32000,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
lora_logits_processor.set_mapping(*mapping_info, )
|
||||
@@ -460,7 +470,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
org_vocab_size:logits_processor.org_vocab_size +
|
||||
embeddings_tensor_len] = embeddings_tensor
|
||||
|
||||
logits_processor.org_vocab_size = (32000 +
|
||||
logits_processor.org_vocab_size = (vocab_size +
|
||||
lora_config.lora_extra_vocab_size)
|
||||
expected_results = []
|
||||
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||
@@ -468,11 +478,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
result = logits_processor._get_logits(hidden_states=input_,
|
||||
embedding=linear.weight,
|
||||
embedding_bias=None)
|
||||
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
|
||||
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
|
||||
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
logits_processor.org_vocab_size = 32000
|
||||
logits_processor.org_vocab_size = vocab_size
|
||||
|
||||
# Check that resetting the lora weights succeeds
|
||||
|
||||
@@ -489,14 +499,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
32000,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_logits_processor.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_logits_processor._get_logits(
|
||||
hidden_states=torch.cat(inputs),
|
||||
embedding=original_weight,
|
||||
embedding_bias=None)[:, :32000]
|
||||
embedding_bias=None)[:, :vocab_size]
|
||||
expected_result = logits_processor._get_logits(
|
||||
hidden_states=torch.cat(inputs),
|
||||
embedding=original_weight,
|
||||
|
||||
Reference in New Issue
Block a user