[Speculative Decoding] Add speculators config support (#21345)

Dipika Sikka
2025-08-01 08:25:18 -04:00
committed by GitHub
parent 87c94bc879
commit dfbc1f8880
9 changed files with 232 additions and 11 deletions

@@ -51,6 +51,25 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        if getattr(config, "norm_before_residual", False):
+            self._residual_norm = self._norm_before_residual
+        else:
+            self._residual_norm = self._norm_after_residual
+
+    def _norm_before_residual(
+            self,
+            hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        hidden_states = self.hidden_norm(hidden_states)
+        residual = hidden_states
+        return hidden_states, residual
+
+    def _norm_after_residual(
+            self,
+            hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        residual = hidden_states
+        hidden_states = self.hidden_norm(hidden_states)
+        return hidden_states, residual
+
     def forward(
         self,
         positions: torch.Tensor,
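
Review note: the two helpers differ only in when the residual branch is captured relative to the norm. A minimal, self-contained sketch of that difference (the weightless rms_norm below is an illustrative stand-in for vLLM's RMSNorm module, not code from this commit):

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Weightless RMSNorm, standing in for the layer's hidden_norm.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def norm_before_residual(hidden_states: torch.Tensor):
    # Residual branch is captured *after* normalization: both outputs
    # carry the normalized scale.
    hidden_states = rms_norm(hidden_states)
    residual = hidden_states
    return hidden_states, residual

def norm_after_residual(hidden_states: torch.Tensor):
    # Residual branch is captured *before* normalization: the residual
    # keeps the raw input scale.
    residual = hidden_states
    hidden_states = rms_norm(hidden_states)
    return hidden_states, residual

h = 10 * torch.randn(2, 4)
_, r_before = norm_before_residual(h)
_, r_after = norm_after_residual(h)
print(r_before.abs().mean(), r_after.abs().mean())  # normalized scale vs. raw scale
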
@@ -59,9 +78,10 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
 
-        residual = hidden_states
         embeds = self.input_layernorm(embeds)
-        hidden_states = self.hidden_norm(hidden_states)
+        hidden_states, residual = self._residual_norm(
+            hidden_states=hidden_states)
+
         hidden_states = torch.cat([embeds, hidden_states], dim=-1)
 
         # Self Attention
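
Review note: forward() stays branch-free because the norm ordering is bound once in __init__ rather than checked per call. A toy illustration of that dispatch pattern in plain PyTorch (TinyLayer and its LayerNorm stand-in are hypothetical, only the pattern mirrors the diff):

import torch

class TinyLayer(torch.nn.Module):
    # The norm order is chosen once at construction time, so forward()
    # has no per-call conditional.
    def __init__(self, hidden: int, norm_before_residual: bool):
        super().__init__()
        self.hidden_norm = torch.nn.LayerNorm(hidden)  # stand-in for RMSNorm
        self._residual_norm = (self._norm_before if norm_before_residual
                               else self._norm_after)

    def _norm_before(self, h):
        h = self.hidden_norm(h)
        return h, h

    def _norm_after(self, h):
        return self.hidden_norm(h), h

    def forward(self, h):
        h, residual = self._residual_norm(h)
        return h + residual

layer = TinyLayer(8, norm_before_residual=True)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
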
@@ -102,7 +122,7 @@ class LlamaModel(nn.Module):
 
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
-                self.config,
+                config=self.config,
                 prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
             )
         ])