[Bugfix][Speculative Decoding] Fix Eagle3 quantization config issue (#25883)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
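Before this change, the Eagle3 drafter layers were built with the verifier's quantization config (vllm_config.quant_config), which is wrong when the drafter checkpoint's quantization differs from the verifier's. The patch adds a get_quant_config helper that resolves the drafter's own config from speculative_config.draft_model_config. A minimal sketch of the scenario this targets; the model names and speculative_config keys below are illustrative, not taken from this patch:

from vllm import LLM

# Illustrative setup: an unquantized verifier paired with a quantized Eagle3
# drafter. Before this fix the drafter layers inherited the verifier's
# quantization config; now they resolve their own from the draft model config.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",         # verifier (illustrative)
    speculative_config={
        "model": "path/to/quantized-eagle3-drafter",  # drafter (illustrative)
        "method": "eagle3",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate("Hello, my name is")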
@@ -13,6 +13,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -33,7 +35,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         super().__init__(vllm_config, prefix=prefix, config=config)
 
         config = config or vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
+        quant_config = self.get_quant_config(vllm_config)
 
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
@@ -53,6 +55,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         else:
             self._residual_norm = self._norm_after_residual
 
+    def get_quant_config(
+            self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
+        """Use drafter's quantization config instead of verifier's."""
+        draft_model_config = vllm_config.speculative_config.draft_model_config
+        draft_load_config = vllm_config.load_config
+
+        return VllmConfig.get_quantization_config(
+            draft_model_config,
+            draft_load_config) if draft_model_config else None
+
     def _norm_before_residual(
             self,
             hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: