diff --git a/tests/model_executor/test_eagle_quantization.py b/tests/model_executor/test_eagle_quantization.py
index 1203aef6a..519a48cae 100644
--- a/tests/model_executor/test_eagle_quantization.py
+++ b/tests/model_executor/test_eagle_quantization.py
@@ -139,6 +139,51 @@ def test_maybe_remap_kv_scale_name():
     assert remapped in params_dict or remapped == name or remapped is None
 
 
+def test_eagle3_lm_head_receives_quant_config():
+    """Eagle3LlamaForCausalLM must pass quant_config to ParallelLMHead.
+
+    Without quant_config, quantized lm_head weights (e.g. INT8 per-channel)
+    in Eagle3 drafter checkpoints fail to load because ParallelLMHead doesn't
+    expect weight_packed tensors.
+    """
+    from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+
+    mock_quant_config = Mock()
+
+    mock_hf_config = Mock()
+    mock_hf_config.draft_vocab_size = 1000
+    mock_hf_config.hidden_size = 256
+    mock_hf_config.vocab_size = 32000
+    mock_hf_config.logit_scale = 1.0
+
+    mock_vllm_config = Mock()
+    mock_vllm_config.speculative_config.draft_model_config.hf_config = mock_hf_config
+    mock_vllm_config.model_config.get_num_layers.return_value = 32
+    mock_vllm_config.speculative_config.parallel_drafting = False
+
+    with (
+        patch("vllm.model_executor.models.llama_eagle3.LlamaModel") as MockModel,
+        patch("vllm.model_executor.models.llama_eagle3.ParallelLMHead") as MockLMHead,
+        patch("vllm.model_executor.models.llama_eagle3.LogitsProcessor"),
+        patch(
+            "vllm.model_executor.models.llama_eagle3.get_draft_quant_config",
+            return_value=mock_quant_config,
+        ),
+    ):
+        MockModel.return_value.use_aux_hidden_state = True
+
+        Eagle3LlamaForCausalLM(vllm_config=mock_vllm_config)
+
+    MockLMHead.assert_called_once()
+    call_kwargs = MockLMHead.call_args.kwargs
+    assert "quant_config" in call_kwargs, (
+        "ParallelLMHead must receive quant_config for quantized lm_head weights"
+    )
+    assert call_kwargs["quant_config"] is mock_quant_config, (
+        "ParallelLMHead must receive the draft model's quant_config"
+    )
+
+
 def test_load_weights_kv_scale_handling():
     kv_scale_param = Mock()
     kv_scale_param.weight_loader = Mock()
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 462d18c98..fcec4a4d8 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -294,6 +294,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         self.lm_head = ParallelLMHead(
             self.config.draft_vocab_size,
             self.config.hidden_size,
+            quant_config=get_draft_quant_config(vllm_config),
             prefix=maybe_prefix(prefix, "lm_head"),
         )
         self.logits_processor = LogitsProcessor(
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 4b20413ca..f331e68fb 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1390,6 +1390,8 @@ class SpecDecodeBaseProposer:
             )
         elif (
             hasattr(target_language_model, "lm_head")
+            and hasattr(target_language_model.lm_head, "weight")
+            and hasattr(self.model.lm_head, "weight")
             and isinstance(target_language_model.lm_head.weight, torch.Tensor)
             and isinstance(self.model.lm_head.weight, torch.Tensor)
             # TODO: Offload to CPU for comparison to avoid extra GPU memory