[Bugfix][Speculative Decoding] Fix Eagle3 quantization config issue (#25883)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
This commit is contained in:
@@ -248,7 +248,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
quant_config = self.get_quant_config(vllm_config)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@@ -328,6 +328,11 @@ class LlamaDecoderLayer(nn.Module):
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
def get_quant_config(
        self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
    """Return the quantization config that applies to this layer.

    The base implementation hands back ``vllm_config.quant_config``
    unchanged. It exists as a hook: subclasses override it when the layer
    should be built with a different quantization config.
    """
    layer_quant_config = vllm_config.quant_config
    return layer_quant_config
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user