[Quantization] [Eagle] Add complete quantization support to the draft model in Eagle (#28435)

Signed-off-by: Shreyas Kulkarni <shreyas.gp269@gmail.com>
Shreyas Kulkarni
2025-11-17 18:01:34 -05:00
committed by GitHub
parent 7765e5ba75
commit 95ae50b7d1
4 changed files with 282 additions and 29 deletions


@@ -11,13 +11,22 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
from .utils import (
AutoWeightsLoader,
get_draft_quant_config,
maybe_prefix,
process_eagle_weight,
)
logger = init_logger(__name__)
@@ -40,14 +49,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
"""Use drafter's quantization config instead of verifier's."""
draft_model_config = vllm_config.speculative_config.draft_model_config
draft_load_config = vllm_config.load_config
return (
VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
if draft_model_config
else None
)
return get_draft_quant_config(vllm_config)
@support_torch_compile
@@ -63,6 +65,9 @@ class LlamaModel(nn.Module):
self.config = vllm_config.speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
# Get drafter's quantization config
self.quant_config = get_draft_quant_config(vllm_config)
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
@@ -80,8 +85,14 @@ class LlamaModel(nn.Module):
for i in range(self.config.num_hidden_layers)
]
)
self.fc = torch.nn.Linear(
self.config.hidden_size * 2, self.config.hidden_size, bias=False
self.fc = ReplicatedLinear(
input_size=self.config.hidden_size * 2,
output_size=self.config.hidden_size,
bias=False,
params_dtype=vllm_config.model_config.dtype,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "fc"),
return_bias=False,
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
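The hunk above replaces the drafter's plain torch.nn.Linear fc projection with ReplicatedLinear so the layer is built through the drafter's quantization config (and gets the matching weight loader) instead of always holding unquantized weights. A minimal standalone sketch of the new layer, assuming it can be constructed outside a full vLLM model; the sizes and dtype below are illustrative, not taken from this diff:

import torch
from vllm.model_executor.layers.linear import ReplicatedLinear

hidden_size = 4096  # illustrative; the real value comes from the drafter's hf_config
fc = ReplicatedLinear(
    input_size=hidden_size * 2,   # concatenated token embeddings + target hidden states
    output_size=hidden_size,
    bias=False,
    params_dtype=torch.bfloat16,
    quant_config=None,            # pass the drafter's QuantizationConfig to quantize fc
    return_bias=False,            # forward() returns a tensor, not a (tensor, bias) tuple
)
x = torch.randn(2, hidden_size * 2, dtype=torch.bfloat16)
y = fc(x)                         # shape: (2, hidden_size)

With quant_config=None the layer falls back to the unquantized path, matching the old torch.nn.Linear behavior; with a real config it creates the quantized parameters that a quantized drafter checkpoint's fc weights expect.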
@@ -117,6 +128,24 @@ class LlamaModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Handle kv cache quantization scales
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remap FP8 kv-scale names to the model's parameter names
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
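Both additions in this load_weights hunk handle kv-cache quantization scales from the drafter checkpoint: scale entries recognized by the drafter's quant config are loaded directly onto the attention layer's scale parameter, and any remaining "scale" names are remapped (or skipped) by maybe_remap_kv_scale_name before the usual stacked-parameter handling. A hedged sketch of the first branch factored into a helper, assuming the same params_dict/loaded_params structures as in the loop above:

from vllm.model_executor.model_loader.weight_utils import default_weight_loader

def try_load_kv_cache_scale(quant_config, params_dict, name, loaded_weight, loaded_params):
    # Returns True if `name` was a kv-cache scale entry handled by the quant config.
    if quant_config is None:
        return False
    scale_name = quant_config.get_cache_scale(name)  # maps a checkpoint scale key to the
    if scale_name is None:                           # attention layer's parameter, or None
        return False
    param = params_dict[scale_name]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    # Per-tensor scales may be serialized as 0-d scalars or 1-element tensors.
    weight_loader(param, loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0])
    loaded_params.add(scale_name)
    return True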


@@ -11,19 +11,27 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
from .utils import (
AutoWeightsLoader,
get_draft_quant_config,
maybe_prefix,
process_eagle_weight,
)
logger = init_logger(__name__)
@@ -66,14 +74,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
"""Use drafter's quantization config instead of verifier's."""
draft_model_config = vllm_config.speculative_config.draft_model_config
draft_load_config = vllm_config.load_config
return (
VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
if draft_model_config
else None
)
return get_draft_quant_config(vllm_config)
def _norm_before_residual(
self, hidden_states: torch.Tensor
@@ -140,6 +141,9 @@ class LlamaModel(nn.Module):
self.config = vllm_config.speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
# Get drafter's quantization config
self.quant_config = get_draft_quant_config(vllm_config)
current_vllm_config = get_current_vllm_config()
self.embed_tokens = VocabParallelEmbedding(
@@ -160,13 +164,19 @@ class LlamaModel(nn.Module):
]
)
if hasattr(self.config, "target_hidden_size"):
self.fc = torch.nn.Linear(
self.config.target_hidden_size * 3, self.config.hidden_size, bias=False
)
fc_input_size = self.config.target_hidden_size * 3
else:
self.fc = torch.nn.Linear(
self.config.hidden_size * 3, self.config.hidden_size, bias=False
)
fc_input_size = self.config.hidden_size * 3
self.fc = ReplicatedLinear(
input_size=fc_input_size,
output_size=self.config.hidden_size,
bias=False,
params_dtype=vllm_config.model_config.dtype,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "fc"),
return_bias=False,
)
self.norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
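As in the first file, the Eagle3 fc now goes through ReplicatedLinear with the drafter's quant config; the if/else collapses to computing fc_input_size, which is three hidden widths because Eagle3 fuses hidden states from three target-model layers, using target_hidden_size when the verifier's width differs from the drafter's. A tiny illustration of the sizing branch with made-up values:

class Cfg:                      # stand-in for the drafter's hf_config
    hidden_size = 4096          # drafter width
    target_hidden_size = 5120   # only present when the target model's width differs

cfg = Cfg()
fc_input_size = (
    cfg.target_hidden_size * 3 if hasattr(cfg, "target_hidden_size") else cfg.hidden_size * 3
)
assert fc_input_size == 3 * 5120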
@@ -211,6 +221,24 @@ class LlamaModel(nn.Module):
for name, loaded_weight in weights:
if "midlayer." in name:
name = name.replace("midlayer.", "layers.0.")
# Handle kv cache quantization scales
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remap FP8 kv-scale names to the model's parameter names
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
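The dim() check in both load_weights hunks exists because per-tensor kv-cache scales show up in checkpoints either as 0-d scalars or as 1-element tensors, and the weight loader should receive a scalar in both cases. A quick illustration (the value is made up):

import torch

for loaded_weight in (torch.tensor(0.023), torch.tensor([0.023])):
    scalar = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
    assert scalar.dim() == 0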


@@ -18,6 +18,9 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.multimodal import NestedTensors
@@ -715,6 +718,30 @@ def maybe_prefix(prefix: str, name: str) -> str:
return name if not prefix else f"{prefix}.{name}"
def get_draft_quant_config(
vllm_config: VllmConfig,
) -> QuantizationConfig | None:
"""Get quantization config for Draft models.
Draft models should use their own quantization config instead of the verifier/target
model's config. This helper retrieves the draft model's quantization config.
Args:
vllm_config: The vLLM configuration object.
Returns:
The draft model's quantization config if available, None otherwise.
"""
draft_model_config = vllm_config.speculative_config.draft_model_config
draft_load_config = vllm_config.load_config
return (
VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
if draft_model_config
else None
)
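With this helper, an EAGLE drafter that was quantized differently from the target model resolves its own quantization config at load time. A hedged end-to-end sketch using vLLM's offline API; the speculative_config keys follow vLLM's public API rather than this diff, and both model names are placeholders:

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",   # target/verifier model
    speculative_config={
        "model": "my-org/eagle-draft-fp8",      # hypothetical quantized EAGLE drafter
        "method": "eagle",
        "num_speculative_tokens": 3,
    },
)
print(llm.generate("Speculative decoding works by")[0].outputs[0].text)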
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
"""
Extract the layer index from the module name.