[Misc][LoRA] Increase max vocab size limit to 258048 in logits processor (#34773)
Signed-off-by: Bhoomit Vasani <vbhoomit@amazon.com>
This commit is contained in:
@@ -103,14 +103,14 @@ def dummy_model(default_vllm_config) -> nn.Module:
|
||||
("output", ColumnParallelLinear(50, 10)),
|
||||
("outact", nn.Sigmoid()),
|
||||
# Special handling for lm_head & sampler
|
||||
("lm_head", ParallelLMHead(512, 10)),
|
||||
("logits_processor", LogitsProcessor(512)),
|
||||
("lm_head", ParallelLMHead(32064, 10)),
|
||||
("logits_processor", LogitsProcessor(32064)),
|
||||
]
|
||||
)
|
||||
)
|
||||
model.config = MagicMock()
|
||||
model.embedding_modules = {"lm_head": "lm_head"}
|
||||
model.unpadded_vocab_size = 32000
|
||||
model.unpadded_vocab_size = 32064
|
||||
return model
|
||||
|
||||
|
||||
@@ -136,8 +136,8 @@ def dummy_model_gate_up(default_vllm_config) -> nn.Module:
|
||||
("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])),
|
||||
("outact", nn.Sigmoid()),
|
||||
# Special handling for lm_head & sampler
|
||||
("lm_head", ParallelLMHead(512, 10)),
|
||||
("logits_processor", LogitsProcessor(512)),
|
||||
("lm_head", ParallelLMHead(32064, 10)),
|
||||
("logits_processor", LogitsProcessor(32064)),
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -149,7 +149,7 @@ def dummy_model_gate_up(default_vllm_config) -> nn.Module:
|
||||
],
|
||||
}
|
||||
model.embedding_modules = {"lm_head": "lm_head"}
|
||||
model.unpadded_vocab_size = 32000
|
||||
model.unpadded_vocab_size = 32064
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -353,7 +353,7 @@ def test_embeddings(
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4])
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
|
||||
@pytest.mark.parametrize("vocab_size", [64000, 256512, 258048])
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_lm_head_logits_processor(
|
||||
default_vllm_config, dist_init, num_loras, device, vocab_size, stage
|
||||
@@ -468,6 +468,31 @@ def test_lm_head_logits_processor(
|
||||
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@torch.inference_mode()
@pytest.mark.parametrize("vocab_size", [512, 32000, 258049, 300000])
@pytest.mark.parametrize("device", DEVICES)
def test_lm_head_logits_processor_invalid_vocab_size(
    default_vllm_config, dist_init, vocab_size, device
) -> None:
    """Check that unsupported vocab sizes are rejected when LoRA weights are built.

    Each parametrized ``vocab_size`` is expected to make
    ``LogitsProcessorWithLoRA.create_lora_weights`` raise ``ValueError`` with
    the "vocab size must be > 32000 and <= 258048" message.
    """
    # Select the target device before any tensors are allocated.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)
    torch.set_default_device(device)

    lora_slots = 8
    config = LoRAConfig(
        max_loras=lora_slots,
        max_lora_rank=8,
        lora_dtype=torch.float16,
    )

    # Wrap a plain logits processor built with the vocab size under test.
    wrapped_processor = LogitsProcessorWithLoRA(
        LogitsProcessor(vocab_size), 1024, torch.float16, device, None
    )

    # The error is expected from create_lora_weights, not from construction.
    with pytest.raises(
        ValueError, match="vocab size must be > 32000 and <= 258048"
    ):
        wrapped_processor.create_lora_weights(lora_slots, config)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4])
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
|
||||
@@ -88,9 +88,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
# TODO: Verify if this condition can be further relaxed
|
||||
if 32000 < self.base_layer.vocab_size > 257024:
|
||||
if self.base_layer.vocab_size <= 32000 or self.base_layer.vocab_size > 258048:
|
||||
raise ValueError(
|
||||
"When using LoRA, vocab size must be 32000 >= vocab_size <= 257024"
|
||||
"When using LoRA, vocab size must be > 32000 and <= 258048"
|
||||
)
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user