[Gemma4] Enable Fast Prefill Optimization (#38879)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson
2026-04-06 11:19:39 -04:00
committed by GitHub
parent e69a265135
commit 47e605092b


@@ -19,6 +19,7 @@
"""Gemma 4 model implementation for vLLM."""
from collections.abc import Iterable
from dataclasses import replace
from itertools import islice
import regex as re
@@ -32,6 +33,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention
@@ -56,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (
@@ -636,7 +639,205 @@ class Gemma4DecoderLayer(nn.Module):
return hidden_states, None
@support_torch_compile
def _run_decoder_layers(
decoder_layers: list[Gemma4DecoderLayer],
layer_idx_start: int,
positions: torch.Tensor,
hidden_states: torch.Tensor,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
"""Run a slice of decoder layers with PLE extraction."""
residual = None
for idx, layer in enumerate(decoder_layers):
layer_idx = idx + layer_idx_start
layer_per_input = (
per_layer_inputs[:, layer_idx, :] if per_layer_inputs is not None else None
)
hidden_states, residual = layer(
positions,
hidden_states,
residual,
per_layer_input=layer_per_input,
**kwargs,
)
return hidden_states
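# Shape note: per_layer_inputs is laid out as
# (num_tokens, num_layers, hidden_size_per_layer_input), so a compiled slice
# of layers starting at layer_idx_start picks its own row with
# per_layer_inputs[:, layer_idx_start + idx, :], i.e. each layer receives a
# (num_tokens, hidden_size_per_layer_input) PLE slice.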
@support_torch_compile(
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4SelfDecoderLayers(nn.Module):
"""Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).
Owns the embedding and PLE modules so they are inside the compiled
graph. Gemma4Model delegates embedding methods here.
"""
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma4DecoderLayer],
layer_idx_start: int,
embed_tokens: VocabParallelEmbedding,
normalizer: torch.Tensor,
embed_tokens_per_layer: VocabParallelEmbedding | None,
embed_scale_per_layer: torch.Tensor | None,
per_layer_model_projection: ColumnParallelLinear | None,
per_layer_projection_norm: RMSNorm | None,
per_layer_input_scale: torch.Tensor | None,
per_layer_projection_scale: torch.Tensor | None,
):
super().__init__()
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
config = _get_text_config(vllm_config.model_config.hf_config)
self.config = config
self.hidden_size_per_layer_input = getattr(
config, "hidden_size_per_layer_input", 0
)
self.vocab_size_per_layer_input = getattr(
config, "vocab_size_per_layer_input", config.vocab_size
)
# Shared references to modules owned by Gemma4Model — must be
# inside this nn.Module so torch.compile captures them.
self.embed_tokens = embed_tokens
self.normalizer = normalizer
self.embed_tokens_per_layer = embed_tokens_per_layer
self.embed_scale_per_layer = embed_scale_per_layer
self.per_layer_model_projection = per_layer_model_projection
self.per_layer_projection_norm = per_layer_projection_norm
self.per_layer_input_scale = per_layer_input_scale
self.per_layer_projection_scale = per_layer_projection_scale
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.normalizer
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
if self.embed_tokens_per_layer is None:
return None
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0,
input_ids < self.vocab_size_per_layer_input,
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
return per_layer_embeds.reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
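# Worked example (hypothetical sizes): with input_ids of shape (16,),
# num_hidden_layers = 48 and hidden_size_per_layer_input = 256, the packed
# lookup gives (16, 48 * 256) and the reshape yields (16, 48, 256); token ids
# outside [0, vocab_size_per_layer_input) were already remapped to 0 by the
# mask above.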
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor | None:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
1. Project inputs_embeds: hidden_size → total_ple_dim
2. Scale by hidden_size^{-0.5}
3. Reshape to (num_tokens, num_layers, per_layer_dim)
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
if self.per_layer_model_projection is None:
return None
per_layer_projection = self.per_layer_model_projection(inputs_embeds)
per_layer_projection = per_layer_projection * self.per_layer_projection_scale
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
if per_layer_inputs is None:
return per_layer_projection
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
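# Worked example (same hypothetical sizes): inputs_embeds (16, hidden_size)
# -> projection (16, 48 * 256), scaled by per_layer_projection_scale
# (hidden_size ** -0.5 per the docstring), reshaped to (16, 48, 256),
# RMS-normalized, then combined with the PLE lookup and scaled by
# per_layer_input_scale (1 / sqrt(2) per the docstring).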
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if inputs_embeds is not None:
hidden_states = inputs_embeds
per_layer_inputs = self.project_per_layer_inputs(
hidden_states, per_layer_inputs
)
else:
hidden_states = self.embed_input_ids(input_ids)
per_layer_embeds = self.get_per_layer_inputs(input_ids)
per_layer_inputs = self.project_per_layer_inputs(
hidden_states, per_layer_embeds
)
hidden_states = _run_decoder_layers(
self.decoder_layers,
self.layer_idx_start,
positions,
hidden_states,
per_layer_inputs,
**kwargs,
)
return hidden_states, per_layer_inputs
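# NOTE: forward also returns the projected per_layer_inputs so the caller
# (Gemma4Model.fast_prefill_forward) can slice both tensors down to the
# logits indices before running the cross-decoder.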
@support_torch_compile(
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4CrossDecoderLayers(nn.Module):
"""Cross-decoder layers (YOCO second half, KV-shared)."""
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma4DecoderLayer],
layer_idx_start: int,
):
super().__init__()
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
return _run_decoder_layers(
self.decoder_layers,
self.layer_idx_start,
positions,
hidden_states,
per_layer_inputs,
**kwargs,
)
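# Design note: keeping the cross-decoder in its own compiled nn.Module lets
# it be captured as a separate torch.compile / CUDAGraph unit and dispatched
# at the reduced fast-prefill token count via the batch_descriptor update in
# Gemma4Model.fast_prefill_forward.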
@support_torch_compile(
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -740,6 +941,75 @@ class Gemma4Model(nn.Module):
torch.tensor(config.hidden_size**0.5),
persistent=False,
)
# --- You Only Cache Once (YOCO) split for fast prefill ---
first_kv_shared_layer_idx = config.num_hidden_layers - getattr(
config, "num_kv_shared_layers", 0
)
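# Worked example (hypothetical config): num_hidden_layers = 48 and
# num_kv_shared_layers = 16 give first_kv_shared_layer_idx = 32, so layers
# 0..31 form the self-decoder (each writing its own KV cache) and layers
# 32..47 form the KV-sharing cross-decoder.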
from vllm.compilation.backends import set_model_tag
# Layers 0..(K-1) are self-decoder layers in YOCO
with set_model_tag("self_decoder"):
self.self_decoder = Gemma4SelfDecoderLayers(
vllm_config=vllm_config,
prefix=f"{prefix}.self_decoder",
decoder_layers=self.layers[:first_kv_shared_layer_idx],
layer_idx_start=0,
embed_tokens=self.embed_tokens,
normalizer=self.normalizer,
embed_tokens_per_layer=getattr(self, "embed_tokens_per_layer", None),
embed_scale_per_layer=getattr(self, "embed_scale_per_layer", None),
per_layer_model_projection=getattr(
self, "per_layer_model_projection", None
),
per_layer_projection_norm=getattr(
self, "per_layer_projection_norm", None
),
per_layer_input_scale=getattr(self, "per_layer_input_scale", None),
per_layer_projection_scale=getattr(
self, "per_layer_projection_scale", None
),
)
# Layers K..(N-1) are cross-decoder layers in YOCO
with set_model_tag("cross_decoder"):
self.cross_decoder = Gemma4CrossDecoderLayers(
vllm_config=vllm_config,
prefix=f"{prefix}.cross_decoder",
decoder_layers=self.layers[first_kv_shared_layer_idx:],
layer_idx_start=first_kv_shared_layer_idx,
)
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
if self.fast_prefill_enabled:
# Allocate static buffers for CUDAGraph
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
device = next(self.parameters()).device
self.positions = torch.zeros(
max_num_tokens, dtype=torch.int64, device=device
)
self.hidden_states = torch.zeros(
(max_num_tokens, config.hidden_size),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
if (
self.hidden_size_per_layer_input
and self.hidden_size_per_layer_input > 0
):
self.per_layer_inputs = torch.zeros(
(
max_num_tokens,
config.num_hidden_layers,
self.hidden_size_per_layer_input,
),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
else:
self.per_layer_inputs = None
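# Sizing sketch (illustrative numbers): with max_num_batched_tokens = 8192
# and hidden_size = 2560, self.hidden_states is a fixed (8192, 2560) buffer;
# fast_prefill_forward copies each live batch into a prefix of these static
# buffers so CUDAGraph replays always see the same tensor addresses.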
# Custom factory that includes per_layer_inputs for PLE-enabled PP.
# per_layer_inputs has shape (batch, num_layers, per_layer_dim),
# which differs from the standard (batch, hidden_size) shape,
@@ -776,47 +1046,22 @@ class Gemma4Model(nn.Module):
self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.normalizer
return self.self_decoder.embed_input_ids(input_ids)
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor:
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
if self.embed_tokens_per_layer is None:
return None
# Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
# be smaller than the main vocab_size).
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0,
input_ids < self.vocab_size_per_layer_input,
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
# Get packed per-layer embeddings: (num_tokens, total_ple_dim)
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
# Apply embed_scale (sqrt of per-layer hidden dim)
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_embeds = per_layer_embeds.reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
return per_layer_embeds
return self.self_decoder.get_per_layer_inputs(input_ids)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor:
) -> torch.Tensor | None:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
@@ -826,29 +1071,94 @@ class Gemma4Model(nn.Module):
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
if self.per_layer_model_projection is None:
return None
# Project from hidden_size to total_ple_dim
# Scaled projection: output = linear(input, weight) * scale
per_layer_projection = self.per_layer_model_projection(inputs_embeds)
per_layer_projection = per_layer_projection * self.per_layer_projection_scale
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
return self.self_decoder.project_per_layer_inputs(
inputs_embeds, per_layer_inputs
)
# Normalize
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
def fast_prefill_forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
logits_indices_padded, num_logits_indices = None, None
attn_metadata = get_forward_context().attn_metadata
if per_layer_inputs is None:
return per_layer_projection
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
layer_attn_metadata = attn_metadata[
self.layers[-1].self_attn.attn.layer_name
]
if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata):
logits_indices_padded = layer_attn_metadata.logits_indices_padded
num_logits_indices = layer_attn_metadata.num_logits_indices
# Combine: (projection + per_layer_inputs) * scale
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
batch_size = positions.size(0)
self.positions[:batch_size].copy_(positions)
self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
input_ids=input_ids,
positions=self.positions[:batch_size],
inputs_embeds=inputs_embeds,
per_layer_inputs=per_layer_inputs,
**kwargs,
)
if logits_indices_padded is None:
logits_indices_padded = torch.arange(
batch_size,
dtype=positions.dtype,
device=positions.device,
)
# NOTE: Keep .clone() until fix in
# https://github.com/vllm-project/vllm/pull/22282
hidden_states = self_decoder_hidden_states.clone()
num_padded = logits_indices_padded.size(0)
self.positions[:num_padded].copy_(positions[logits_indices_padded])
self.hidden_states[:num_padded].copy_(
self_decoder_hidden_states[logits_indices_padded]
)
if self.per_layer_inputs is not None and per_layer_inputs is not None:
self.per_layer_inputs[:num_padded].copy_(
per_layer_inputs[logits_indices_padded]
)
# Update batch_descriptor so the cross-decoder's piecewise
# CUDAGraphWrapper dispatches to the correct (reduced) batch size.
forward_context = get_forward_context()
orig_batch_desc = forward_context.batch_descriptor
if orig_batch_desc is not None:
forward_context.batch_descriptor = replace(
orig_batch_desc, num_tokens=num_padded
)
cross_per_layer = (
self.per_layer_inputs[:num_padded]
if self.per_layer_inputs is not None
else None
)
cross_hidden_states = self.cross_decoder(
self.positions[:num_padded],
self.hidden_states[:num_padded],
cross_per_layer,
**kwargs,
)
# Restore the original batch_descriptor
forward_context.batch_descriptor = orig_batch_desc
if num_logits_indices is not None:
assert num_logits_indices > 0
hidden_states[logits_indices_padded[:num_logits_indices]] = (
cross_hidden_states[:num_logits_indices]
)
else:
hidden_states = cross_hidden_states
return hidden_states
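# End-to-end sketch (hypothetical numbers): for a 1024-token prefill where
# only 4 positions need logits, the self-decoder runs on all 1024 tokens
# (writing the KV that the shared layers reuse), the cross-decoder runs only
# on the padded logits indices, and its outputs are scattered back into the
# full-size hidden_states at those indices before the final norm.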
def forward(
self,
@@ -859,6 +1169,18 @@ class Gemma4Model(nn.Module):
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
if self.fast_prefill_enabled:
hidden_states = self.fast_prefill_forward(
input_ids,
positions,
inputs_embeds,
per_layer_inputs,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return hidden_states
# Normal (non-fast-prefill) path with PP support
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds