[Bugfix] Pass drafter quant_config to ParallelLMHead in Eagle3 (#37280)
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
This commit is contained in:
@@ -139,6 +139,51 @@ def test_maybe_remap_kv_scale_name():
|
||||
assert remapped in params_dict or remapped == name or remapped is None
|
||||
|
||||
|
||||
def test_eagle3_lm_head_receives_quant_config():
|
||||
"""Eagle3LlamaForCausalLM must pass quant_config to ParallelLMHead.
|
||||
|
||||
Without quant_config, quantized lm_head weights (e.g. INT8 per-channel)
|
||||
in Eagle3 drafter checkpoints fail to load because ParallelLMHead doesn't
|
||||
expect weight_packed tensors.
|
||||
"""
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
|
||||
mock_quant_config = Mock()
|
||||
|
||||
mock_hf_config = Mock()
|
||||
mock_hf_config.draft_vocab_size = 1000
|
||||
mock_hf_config.hidden_size = 256
|
||||
mock_hf_config.vocab_size = 32000
|
||||
mock_hf_config.logit_scale = 1.0
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.speculative_config.draft_model_config.hf_config = mock_hf_config
|
||||
mock_vllm_config.model_config.get_num_layers.return_value = 32
|
||||
mock_vllm_config.speculative_config.parallel_drafting = False
|
||||
|
||||
with (
|
||||
patch("vllm.model_executor.models.llama_eagle3.LlamaModel") as MockModel,
|
||||
patch("vllm.model_executor.models.llama_eagle3.ParallelLMHead") as MockLMHead,
|
||||
patch("vllm.model_executor.models.llama_eagle3.LogitsProcessor"),
|
||||
patch(
|
||||
"vllm.model_executor.models.llama_eagle3.get_draft_quant_config",
|
||||
return_value=mock_quant_config,
|
||||
),
|
||||
):
|
||||
MockModel.return_value.use_aux_hidden_state = True
|
||||
|
||||
Eagle3LlamaForCausalLM(vllm_config=mock_vllm_config)
|
||||
|
||||
MockLMHead.assert_called_once()
|
||||
call_kwargs = MockLMHead.call_args.kwargs
|
||||
assert "quant_config" in call_kwargs, (
|
||||
"ParallelLMHead must receive quant_config for quantized lm_head weights"
|
||||
)
|
||||
assert call_kwargs["quant_config"] is mock_quant_config, (
|
||||
"ParallelLMHead must receive the draft model's quant_config"
|
||||
)
|
||||
|
||||
|
||||
def test_load_weights_kv_scale_handling():
|
||||
kv_scale_param = Mock()
|
||||
kv_scale_param.weight_loader = Mock()
|
||||
|
||||
@@ -294,6 +294,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.config.draft_vocab_size,
|
||||
self.config.hidden_size,
|
||||
quant_config=get_draft_quant_config(vllm_config),
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
|
||||
@@ -1390,6 +1390,8 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
elif (
|
||||
hasattr(target_language_model, "lm_head")
|
||||
and hasattr(target_language_model.lm_head, "weight")
|
||||
and hasattr(self.model.lm_head, "weight")
|
||||
and isinstance(target_language_model.lm_head.weight, torch.Tensor)
|
||||
and isinstance(self.model.lm_head.weight, torch.Tensor)
|
||||
# TODO: Offload to CPU for comparison to avoid extra GPU memory
|
||||
|
||||
Reference in New Issue
Block a user