[Gemma4] Enable Fast Prefill Optimization (#38879)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
"""Gemma 4 model implementation for vLLM."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import replace
|
||||
from itertools import islice
|
||||
|
||||
import regex as re
|
||||
@@ -32,6 +33,7 @@ from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import GeluAndMul
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
@@ -56,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .utils import (
|
||||
@@ -636,7 +639,205 @@ class Gemma4DecoderLayer(nn.Module):
|
||||
return hidden_states, None
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
def _run_decoder_layers(
|
||||
decoder_layers: list[Gemma4DecoderLayer],
|
||||
layer_idx_start: int,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
per_layer_inputs: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Run a slice of decoder layers with PLE extraction."""
|
||||
residual = None
|
||||
for idx, layer in enumerate(decoder_layers):
|
||||
layer_idx = idx + layer_idx_start
|
||||
layer_per_input = (
|
||||
per_layer_inputs[:, layer_idx, :] if per_layer_inputs is not None else None
|
||||
)
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
per_layer_input=layer_per_input,
|
||||
**kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
|
||||
)
|
||||
class Gemma4SelfDecoderLayers(nn.Module):
|
||||
"""Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).
|
||||
|
||||
Owns the embedding and PLE modules so they are inside the compiled
|
||||
graph. Gemma4Model delegates embedding methods here.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layers: list[Gemma4DecoderLayer],
|
||||
layer_idx_start: int,
|
||||
embed_tokens: VocabParallelEmbedding,
|
||||
normalizer: torch.Tensor,
|
||||
embed_tokens_per_layer: VocabParallelEmbedding | None,
|
||||
embed_scale_per_layer: torch.Tensor | None,
|
||||
per_layer_model_projection: ColumnParallelLinear | None,
|
||||
per_layer_projection_norm: RMSNorm | None,
|
||||
per_layer_input_scale: torch.Tensor | None,
|
||||
per_layer_projection_scale: torch.Tensor | None,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder_layers = decoder_layers
|
||||
self.layer_idx_start = layer_idx_start
|
||||
|
||||
config = _get_text_config(vllm_config.model_config.hf_config)
|
||||
self.config = config
|
||||
self.hidden_size_per_layer_input = getattr(
|
||||
config, "hidden_size_per_layer_input", 0
|
||||
)
|
||||
self.vocab_size_per_layer_input = getattr(
|
||||
config, "vocab_size_per_layer_input", config.vocab_size
|
||||
)
|
||||
|
||||
# Shared references to modules owned by Gemma4Model — must be
|
||||
# inside this nn.Module so torch.compile captures them.
|
||||
self.embed_tokens = embed_tokens
|
||||
self.normalizer = normalizer
|
||||
self.embed_tokens_per_layer = embed_tokens_per_layer
|
||||
self.embed_scale_per_layer = embed_scale_per_layer
|
||||
self.per_layer_model_projection = per_layer_model_projection
|
||||
self.per_layer_projection_norm = per_layer_projection_norm
|
||||
self.per_layer_input_scale = per_layer_input_scale
|
||||
self.per_layer_projection_scale = per_layer_projection_scale
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids) * self.normalizer
|
||||
|
||||
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Get per-layer embeddings from embed_tokens_per_layer.
|
||||
|
||||
Returns:
|
||||
Per-layer embeddings (num_tokens, num_layers,
|
||||
hidden_size_per_layer_input)
|
||||
"""
|
||||
if self.embed_tokens_per_layer is None:
|
||||
return None
|
||||
per_layer_inputs_mask = torch.logical_and(
|
||||
input_ids >= 0,
|
||||
input_ids < self.vocab_size_per_layer_input,
|
||||
)
|
||||
per_layer_inputs_tokens = torch.where(
|
||||
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
|
||||
)
|
||||
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
|
||||
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
|
||||
return per_layer_embeds.reshape(
|
||||
*input_ids.shape,
|
||||
self.config.num_hidden_layers,
|
||||
self.hidden_size_per_layer_input,
|
||||
)
|
||||
|
||||
def project_per_layer_inputs(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
per_layer_inputs: torch.Tensor | None,
|
||||
) -> torch.Tensor | None:
|
||||
"""Project inputs_embeds and combine with per_layer_inputs.
|
||||
|
||||
Steps:
|
||||
1. Project inputs_embeds: hidden_size → total_ple_dim
|
||||
2. Scale by hidden_size^{-0.5}
|
||||
3. Reshape to (num_tokens, num_layers, per_layer_dim)
|
||||
4. Normalize with per_layer_projection_norm
|
||||
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
|
||||
"""
|
||||
if self.per_layer_model_projection is None:
|
||||
return None
|
||||
per_layer_projection = self.per_layer_model_projection(inputs_embeds)
|
||||
per_layer_projection = per_layer_projection * self.per_layer_projection_scale
|
||||
per_layer_projection = per_layer_projection.reshape(
|
||||
*inputs_embeds.shape[:-1],
|
||||
self.config.num_hidden_layers,
|
||||
self.hidden_size_per_layer_input,
|
||||
)
|
||||
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
|
||||
if per_layer_inputs is None:
|
||||
return per_layer_projection
|
||||
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
per_layer_inputs: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
per_layer_inputs = self.project_per_layer_inputs(
|
||||
hidden_states, per_layer_inputs
|
||||
)
|
||||
else:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
per_layer_embeds = self.get_per_layer_inputs(input_ids)
|
||||
per_layer_inputs = self.project_per_layer_inputs(
|
||||
hidden_states, per_layer_embeds
|
||||
)
|
||||
|
||||
hidden_states = _run_decoder_layers(
|
||||
self.decoder_layers,
|
||||
self.layer_idx_start,
|
||||
positions,
|
||||
hidden_states,
|
||||
per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
return hidden_states, per_layer_inputs
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
|
||||
)
|
||||
class Gemma4CrossDecoderLayers(nn.Module):
|
||||
"""Cross-decoder layers (YOCO second half, KV-shared)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layers: list[Gemma4DecoderLayer],
|
||||
layer_idx_start: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder_layers = decoder_layers
|
||||
self.layer_idx_start = layer_idx_start
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
per_layer_inputs: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return _run_decoder_layers(
|
||||
self.decoder_layers,
|
||||
self.layer_idx_start,
|
||||
positions,
|
||||
hidden_states,
|
||||
per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
|
||||
)
|
||||
class Gemma4Model(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@@ -740,6 +941,75 @@ class Gemma4Model(nn.Module):
|
||||
torch.tensor(config.hidden_size**0.5),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
# --- You Only Cache Once (YOCO) split for fast prefill ---
|
||||
first_kv_shared_layer_idx = config.num_hidden_layers - getattr(
|
||||
config, "num_kv_shared_layers", 0
|
||||
)
|
||||
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
# Layers 0..(K-1) are self-decoder layers in YOCO
|
||||
with set_model_tag("self_decoder"):
|
||||
self.self_decoder = Gemma4SelfDecoderLayers(
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.self_decoder",
|
||||
decoder_layers=self.layers[:first_kv_shared_layer_idx],
|
||||
layer_idx_start=0,
|
||||
embed_tokens=self.embed_tokens,
|
||||
normalizer=self.normalizer,
|
||||
embed_tokens_per_layer=getattr(self, "embed_tokens_per_layer", None),
|
||||
embed_scale_per_layer=getattr(self, "embed_scale_per_layer", None),
|
||||
per_layer_model_projection=getattr(
|
||||
self, "per_layer_model_projection", None
|
||||
),
|
||||
per_layer_projection_norm=getattr(
|
||||
self, "per_layer_projection_norm", None
|
||||
),
|
||||
per_layer_input_scale=getattr(self, "per_layer_input_scale", None),
|
||||
per_layer_projection_scale=getattr(
|
||||
self, "per_layer_projection_scale", None
|
||||
),
|
||||
)
|
||||
# Layers K..(N-1) are cross-decoder layers in YOCO
|
||||
with set_model_tag("cross_decoder"):
|
||||
self.cross_decoder = Gemma4CrossDecoderLayers(
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.cross_decoder",
|
||||
decoder_layers=self.layers[first_kv_shared_layer_idx:],
|
||||
layer_idx_start=first_kv_shared_layer_idx,
|
||||
)
|
||||
|
||||
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
|
||||
|
||||
if self.fast_prefill_enabled:
|
||||
# Allocate static buffers for CUDAGraph
|
||||
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
device = next(self.parameters()).device
|
||||
self.positions = torch.zeros(
|
||||
max_num_tokens, dtype=torch.int64, device=device
|
||||
)
|
||||
self.hidden_states = torch.zeros(
|
||||
(max_num_tokens, config.hidden_size),
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
device=device,
|
||||
)
|
||||
if (
|
||||
self.hidden_size_per_layer_input
|
||||
and self.hidden_size_per_layer_input > 0
|
||||
):
|
||||
self.per_layer_inputs = torch.zeros(
|
||||
(
|
||||
max_num_tokens,
|
||||
config.num_hidden_layers,
|
||||
self.hidden_size_per_layer_input,
|
||||
),
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.per_layer_inputs = None
|
||||
|
||||
# Custom factory that includes per_layer_inputs for PLE-enabled PP.
|
||||
# per_layer_inputs has shape (batch, num_layers, per_layer_dim),
|
||||
# which differs from the standard (batch, hidden_size) shape,
|
||||
@@ -776,47 +1046,22 @@ class Gemma4Model(nn.Module):
|
||||
self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids) * self.normalizer
|
||||
return self.self_decoder.embed_input_ids(input_ids)
|
||||
|
||||
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Get per-layer embeddings from embed_tokens_per_layer.
|
||||
|
||||
Returns:
|
||||
Per-layer embeddings (num_tokens, num_layers,
|
||||
hidden_size_per_layer_input)
|
||||
"""
|
||||
if self.embed_tokens_per_layer is None:
|
||||
return None
|
||||
|
||||
# Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
|
||||
# be smaller than the main vocab_size).
|
||||
per_layer_inputs_mask = torch.logical_and(
|
||||
input_ids >= 0,
|
||||
input_ids < self.vocab_size_per_layer_input,
|
||||
)
|
||||
per_layer_inputs_tokens = torch.where(
|
||||
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
|
||||
)
|
||||
|
||||
# Get packed per-layer embeddings: (num_tokens, total_ple_dim)
|
||||
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
|
||||
|
||||
# Apply embed_scale (sqrt of per-layer hidden dim)
|
||||
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
|
||||
|
||||
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
|
||||
per_layer_embeds = per_layer_embeds.reshape(
|
||||
*input_ids.shape,
|
||||
self.config.num_hidden_layers,
|
||||
self.hidden_size_per_layer_input,
|
||||
)
|
||||
return per_layer_embeds
|
||||
return self.self_decoder.get_per_layer_inputs(input_ids)
|
||||
|
||||
def project_per_layer_inputs(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
per_layer_inputs: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
) -> torch.Tensor | None:
|
||||
"""Project inputs_embeds and combine with per_layer_inputs.
|
||||
|
||||
Steps:
|
||||
@@ -826,29 +1071,94 @@ class Gemma4Model(nn.Module):
|
||||
4. Normalize with per_layer_projection_norm
|
||||
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
|
||||
"""
|
||||
if self.per_layer_model_projection is None:
|
||||
return None
|
||||
|
||||
# Project from hidden_size to total_ple_dim
|
||||
# Scaled projection: output = linear(input, weight) * scale
|
||||
per_layer_projection = self.per_layer_model_projection(inputs_embeds)
|
||||
per_layer_projection = per_layer_projection * self.per_layer_projection_scale
|
||||
|
||||
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
|
||||
per_layer_projection = per_layer_projection.reshape(
|
||||
*inputs_embeds.shape[:-1],
|
||||
self.config.num_hidden_layers,
|
||||
self.hidden_size_per_layer_input,
|
||||
return self.self_decoder.project_per_layer_inputs(
|
||||
inputs_embeds, per_layer_inputs
|
||||
)
|
||||
|
||||
# Normalize
|
||||
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
|
||||
def fast_prefill_forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
per_layer_inputs: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
logits_indices_padded, num_logits_indices = None, None
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
if per_layer_inputs is None:
|
||||
return per_layer_projection
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
layer_attn_metadata = attn_metadata[
|
||||
self.layers[-1].self_attn.attn.layer_name
|
||||
]
|
||||
if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata):
|
||||
logits_indices_padded = layer_attn_metadata.logits_indices_padded
|
||||
num_logits_indices = layer_attn_metadata.num_logits_indices
|
||||
|
||||
# Combine: (projection + per_layer_inputs) * scale
|
||||
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
|
||||
batch_size = positions.size(0)
|
||||
self.positions[:batch_size].copy_(positions)
|
||||
self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
|
||||
input_ids=input_ids,
|
||||
positions=self.positions[:batch_size],
|
||||
inputs_embeds=inputs_embeds,
|
||||
per_layer_inputs=per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if logits_indices_padded is None:
|
||||
logits_indices_padded = torch.arange(
|
||||
batch_size,
|
||||
dtype=positions.dtype,
|
||||
device=positions.device,
|
||||
)
|
||||
|
||||
# NOTE: Keep .clone() until fix in
|
||||
# https://github.com/vllm-project/vllm/pull/22282
|
||||
hidden_states = self_decoder_hidden_states.clone()
|
||||
|
||||
num_padded = logits_indices_padded.size(0)
|
||||
self.positions[:num_padded].copy_(positions[logits_indices_padded])
|
||||
self.hidden_states[:num_padded].copy_(
|
||||
self_decoder_hidden_states[logits_indices_padded]
|
||||
)
|
||||
if self.per_layer_inputs is not None and per_layer_inputs is not None:
|
||||
self.per_layer_inputs[:num_padded].copy_(
|
||||
per_layer_inputs[logits_indices_padded]
|
||||
)
|
||||
|
||||
# Update batch_descriptor so the cross-decoder's piecewise
|
||||
# CUDAGraphWrapper dispatches to the correct (reduced) batch size.
|
||||
forward_context = get_forward_context()
|
||||
orig_batch_desc = forward_context.batch_descriptor
|
||||
if orig_batch_desc is not None:
|
||||
forward_context.batch_descriptor = replace(
|
||||
orig_batch_desc, num_tokens=num_padded
|
||||
)
|
||||
|
||||
cross_per_layer = (
|
||||
self.per_layer_inputs[:num_padded]
|
||||
if self.per_layer_inputs is not None
|
||||
else None
|
||||
)
|
||||
cross_hidden_states = self.cross_decoder(
|
||||
self.positions[:num_padded],
|
||||
self.hidden_states[:num_padded],
|
||||
cross_per_layer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Restore the original batch_descriptor
|
||||
forward_context.batch_descriptor = orig_batch_desc
|
||||
|
||||
if num_logits_indices is not None:
|
||||
assert num_logits_indices > 0
|
||||
hidden_states[logits_indices_padded[:num_logits_indices]] = (
|
||||
cross_hidden_states[:num_logits_indices]
|
||||
)
|
||||
else:
|
||||
hidden_states = cross_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -859,6 +1169,18 @@ class Gemma4Model(nn.Module):
|
||||
per_layer_inputs: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if self.fast_prefill_enabled:
|
||||
hidden_states = self.fast_prefill_forward(
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
# Normal (non-fast-prefill) path with PP support
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
Reference in New Issue
Block a user