[Quantization] [Eagle] Add complete quantization support to the draft model in Eagle (#28435)

Signed-off-by: Shreyas Kulkarni <shreyas.gp269@gmail.com>
This commit is contained in:
Shreyas Kulkarni
2025-11-17 18:01:34 -05:00
committed by GitHub
parent 7765e5ba75
commit 95ae50b7d1
4 changed files with 282 additions and 29 deletions

View File

@@ -11,19 +11,27 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
from .utils import (
AutoWeightsLoader,
get_draft_quant_config,
maybe_prefix,
process_eagle_weight,
)
logger = init_logger(__name__)
@@ -66,14 +74,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
"""Use drafter's quantization config instead of verifier's."""
draft_model_config = vllm_config.speculative_config.draft_model_config
draft_load_config = vllm_config.load_config
return (
VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
if draft_model_config
else None
)
return get_draft_quant_config(vllm_config)
def _norm_before_residual(
self, hidden_states: torch.Tensor
@@ -140,6 +141,9 @@ class LlamaModel(nn.Module):
self.config = vllm_config.speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
# Get drafter's quantization config
self.quant_config = get_draft_quant_config(vllm_config)
current_vllm_config = get_current_vllm_config()
self.embed_tokens = VocabParallelEmbedding(
@@ -160,13 +164,19 @@ class LlamaModel(nn.Module):
]
)
if hasattr(self.config, "target_hidden_size"):
self.fc = torch.nn.Linear(
self.config.target_hidden_size * 3, self.config.hidden_size, bias=False
)
fc_input_size = self.config.target_hidden_size * 3
else:
self.fc = torch.nn.Linear(
self.config.hidden_size * 3, self.config.hidden_size, bias=False
)
fc_input_size = self.config.hidden_size * 3
self.fc = ReplicatedLinear(
input_size=fc_input_size,
output_size=self.config.hidden_size,
bias=False,
params_dtype=vllm_config.model_config.dtype,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "fc"),
return_bias=False,
)
self.norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
@@ -211,6 +221,24 @@ class LlamaModel(nn.Module):
for name, loaded_weight in weights:
if "midlayer." in name:
name = name.replace("midlayer.", "layers.0.")
# Handle kv cache quantization scales
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remapping the name FP8 kv-scale
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue