From 42489e43c2718674828ece00eefc0f11088e801d Mon Sep 17 00:00:00 2001 From: Bhoomit Date: Wed, 25 Feb 2026 07:30:55 -0800 Subject: [PATCH] [Misc][LoRA] Increase max vocab size limit to 258048 in logits processor (#34773) Signed-off-by: Bhoomit Vasani --- tests/lora/conftest.py | 12 ++++++------ tests/lora/test_layers.py | 27 ++++++++++++++++++++++++++- vllm/lora/layers/logits_processor.py | 4 ++-- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index deb1ab92d..d0d8382ac 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -103,14 +103,14 @@ def dummy_model(default_vllm_config) -> nn.Module: ("output", ColumnParallelLinear(50, 10)), ("outact", nn.Sigmoid()), # Special handling for lm_head & sampler - ("lm_head", ParallelLMHead(512, 10)), - ("logits_processor", LogitsProcessor(512)), + ("lm_head", ParallelLMHead(32064, 10)), + ("logits_processor", LogitsProcessor(32064)), ] ) ) model.config = MagicMock() model.embedding_modules = {"lm_head": "lm_head"} - model.unpadded_vocab_size = 32000 + model.unpadded_vocab_size = 32064 return model @@ -136,8 +136,8 @@ def dummy_model_gate_up(default_vllm_config) -> nn.Module: ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), ("outact", nn.Sigmoid()), # Special handling for lm_head & sampler - ("lm_head", ParallelLMHead(512, 10)), - ("logits_processor", LogitsProcessor(512)), + ("lm_head", ParallelLMHead(32064, 10)), + ("logits_processor", LogitsProcessor(32064)), ] ) ) @@ -149,7 +149,7 @@ def dummy_model_gate_up(default_vllm_config) -> nn.Module: ], } model.embedding_modules = {"lm_head": "lm_head"} - model.unpadded_vocab_size = 32000 + model.unpadded_vocab_size = 32064 return model diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 2a96529d8..c9c551143 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -353,7 +353,7 @@ def test_embeddings( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) +@pytest.mark.parametrize("vocab_size", [64000, 256512, 258048]) @pytest.mark.parametrize("stage", STAGES) def test_lm_head_logits_processor( default_vllm_config, dist_init, num_loras, device, vocab_size, stage @@ -468,6 +468,31 @@ def test_lm_head_logits_processor( torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) +@torch.inference_mode() +@pytest.mark.parametrize("vocab_size", [512, 32000, 258049, 300000]) +@pytest.mark.parametrize("device", DEVICES) +def test_lm_head_logits_processor_invalid_vocab_size( + default_vllm_config, dist_init, vocab_size, device +) -> None: + """Test that LogitsProcessorWithLoRA raises ValueError for invalid vocab sizes.""" + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + + torch.set_default_device(device) + max_loras = 8 + lora_config = LoRAConfig( + max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 + ) + + logits_processor = LogitsProcessor(vocab_size) + lora_logits_processor = LogitsProcessorWithLoRA( + logits_processor, 1024, torch.float16, device, None + ) + + with pytest.raises(ValueError, match="vocab size must be > 32000 and <= 258048"): + lora_logits_processor.create_lora_weights(max_loras, lora_config) + + @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py index d7b02ec96..217c46fbe 100644 --- a/vllm/lora/layers/logits_processor.py +++ b/vllm/lora/layers/logits_processor.py @@ -88,9 +88,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): model_config: PretrainedConfig | None = None, ) -> None: # TODO: Verify if this condition can be further relaxed - if 32000 < self.base_layer.vocab_size > 257024: + if self.base_layer.vocab_size <= 32000 or self.base_layer.vocab_size > 258048: raise ValueError( - "When using LoRA, vocab size must be 32000 >= vocab_size <= 257024" + "When using LoRA, vocab size must be > 32000 and <= 258048" ) self.lora_a_stacked = torch.zeros( (