diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index edb533134..2e9fc6819 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -19,6 +19,7 @@ """Gemma 4 model implementation for vLLM.""" from collections.abc import Iterable +from dataclasses import replace from itertools import islice import regex as re @@ -32,6 +33,7 @@ from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.attention import Attention @@ -56,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import ( maybe_remap_kv_scale_name, ) from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import ( @@ -636,7 +639,205 @@ class Gemma4DecoderLayer(nn.Module): return hidden_states, None -@support_torch_compile +def _run_decoder_layers( + decoder_layers: list[Gemma4DecoderLayer], + layer_idx_start: int, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_inputs: torch.Tensor | None = None, + **kwargs, +) -> torch.Tensor: + """Run a slice of decoder layers with PLE extraction.""" + residual = None + for idx, layer in enumerate(decoder_layers): + layer_idx = idx + layer_idx_start + layer_per_input = ( + per_layer_inputs[:, layer_idx, :] if per_layer_inputs is not None else None + ) + hidden_states, residual = layer( + positions, + hidden_states, + residual, + per_layer_input=layer_per_input, + **kwargs, + ) + return hidden_states + + +@support_torch_compile( + enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill +) +class Gemma4SelfDecoderLayers(nn.Module): + """Compiled wrapper: embedding + non-KV-shared layers (YOCO first half). + + Owns the embedding and PLE modules so they are inside the compiled + graph. Gemma4Model delegates embedding methods here. + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma4DecoderLayer], + layer_idx_start: int, + embed_tokens: VocabParallelEmbedding, + normalizer: torch.Tensor, + embed_tokens_per_layer: VocabParallelEmbedding | None, + embed_scale_per_layer: torch.Tensor | None, + per_layer_model_projection: ColumnParallelLinear | None, + per_layer_projection_norm: RMSNorm | None, + per_layer_input_scale: torch.Tensor | None, + per_layer_projection_scale: torch.Tensor | None, + ): + super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + + config = _get_text_config(vllm_config.model_config.hf_config) + self.config = config + self.hidden_size_per_layer_input = getattr( + config, "hidden_size_per_layer_input", 0 + ) + self.vocab_size_per_layer_input = getattr( + config, "vocab_size_per_layer_input", config.vocab_size + ) + + # Shared references to modules owned by Gemma4Model — must be + # inside this nn.Module so torch.compile captures them. + self.embed_tokens = embed_tokens + self.normalizer = normalizer + self.embed_tokens_per_layer = embed_tokens_per_layer + self.embed_scale_per_layer = embed_scale_per_layer + self.per_layer_model_projection = per_layer_model_projection + self.per_layer_projection_norm = per_layer_projection_norm + self.per_layer_input_scale = per_layer_input_scale + self.per_layer_projection_scale = per_layer_projection_scale + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) * self.normalizer + + def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None: + """Get per-layer embeddings from embed_tokens_per_layer. + + Returns: + Per-layer embeddings (num_tokens, num_layers, + hidden_size_per_layer_input) + """ + if self.embed_tokens_per_layer is None: + return None + per_layer_inputs_mask = torch.logical_and( + input_ids >= 0, + input_ids < self.vocab_size_per_layer_input, + ) + per_layer_inputs_tokens = torch.where( + per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids) + ) + per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens) + per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer + return per_layer_embeds.reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: torch.Tensor | None, + ) -> torch.Tensor | None: + """Project inputs_embeds and combine with per_layer_inputs. + + Steps: + 1. Project inputs_embeds: hidden_size → total_ple_dim + 2. Scale by hidden_size^{-0.5} + 3. Reshape to (num_tokens, num_layers, per_layer_dim) + 4. Normalize with per_layer_projection_norm + 5. Combine: (projection + per_layer_inputs) * 1/sqrt(2) + """ + if self.per_layer_model_projection is None: + return None + per_layer_projection = self.per_layer_model_projection(inputs_embeds) + per_layer_projection = per_layer_projection * self.per_layer_projection_scale + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + if per_layer_inputs is None: + return per_layer_projection + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + per_layer_inputs: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if inputs_embeds is not None: + hidden_states = inputs_embeds + per_layer_inputs = self.project_per_layer_inputs( + hidden_states, per_layer_inputs + ) + else: + hidden_states = self.embed_input_ids(input_ids) + per_layer_embeds = self.get_per_layer_inputs(input_ids) + per_layer_inputs = self.project_per_layer_inputs( + hidden_states, per_layer_embeds + ) + + hidden_states = _run_decoder_layers( + self.decoder_layers, + self.layer_idx_start, + positions, + hidden_states, + per_layer_inputs, + **kwargs, + ) + return hidden_states, per_layer_inputs + + +@support_torch_compile( + enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill +) +class Gemma4CrossDecoderLayers(nn.Module): + """Cross-decoder layers (YOCO second half, KV-shared).""" + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma4DecoderLayer], + layer_idx_start: int, + ): + super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_inputs: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + return _run_decoder_layers( + self.decoder_layers, + self.layer_idx_start, + positions, + hidden_states, + per_layer_inputs, + **kwargs, + ) + + +@support_torch_compile( + enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill +) class Gemma4Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -740,6 +941,75 @@ class Gemma4Model(nn.Module): torch.tensor(config.hidden_size**0.5), persistent=False, ) + + # --- You Only Cache Once (YOCO) split for fast prefill --- + first_kv_shared_layer_idx = config.num_hidden_layers - getattr( + config, "num_kv_shared_layers", 0 + ) + + from vllm.compilation.backends import set_model_tag + + # Layers 0..(K-1) are self-decoder layers in YOCO + with set_model_tag("self_decoder"): + self.self_decoder = Gemma4SelfDecoderLayers( + vllm_config=vllm_config, + prefix=f"{prefix}.self_decoder", + decoder_layers=self.layers[:first_kv_shared_layer_idx], + layer_idx_start=0, + embed_tokens=self.embed_tokens, + normalizer=self.normalizer, + embed_tokens_per_layer=getattr(self, "embed_tokens_per_layer", None), + embed_scale_per_layer=getattr(self, "embed_scale_per_layer", None), + per_layer_model_projection=getattr( + self, "per_layer_model_projection", None + ), + per_layer_projection_norm=getattr( + self, "per_layer_projection_norm", None + ), + per_layer_input_scale=getattr(self, "per_layer_input_scale", None), + per_layer_projection_scale=getattr( + self, "per_layer_projection_scale", None + ), + ) + # Layers K..(N-1) are cross-decoder layers in YOCO + with set_model_tag("cross_decoder"): + self.cross_decoder = Gemma4CrossDecoderLayers( + vllm_config=vllm_config, + prefix=f"{prefix}.cross_decoder", + decoder_layers=self.layers[first_kv_shared_layer_idx:], + layer_idx_start=first_kv_shared_layer_idx, + ) + + self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill + + if self.fast_prefill_enabled: + # Allocate static buffers for CUDAGraph + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + device = next(self.parameters()).device + self.positions = torch.zeros( + max_num_tokens, dtype=torch.int64, device=device + ) + self.hidden_states = torch.zeros( + (max_num_tokens, config.hidden_size), + dtype=self.embed_tokens.weight.dtype, + device=device, + ) + if ( + self.hidden_size_per_layer_input + and self.hidden_size_per_layer_input > 0 + ): + self.per_layer_inputs = torch.zeros( + ( + max_num_tokens, + config.num_hidden_layers, + self.hidden_size_per_layer_input, + ), + dtype=self.embed_tokens.weight.dtype, + device=device, + ) + else: + self.per_layer_inputs = None + # Custom factory that includes per_layer_inputs for PLE-enabled PP. # per_layer_inputs has shape (batch, num_layers, per_layer_dim), # which differs from the standard (batch, hidden_size) shape, @@ -776,47 +1046,22 @@ class Gemma4Model(nn.Module): self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) * self.normalizer + return self.self_decoder.embed_input_ids(input_ids) - def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor: + def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None: """Get per-layer embeddings from embed_tokens_per_layer. Returns: Per-layer embeddings (num_tokens, num_layers, hidden_size_per_layer_input) """ - if self.embed_tokens_per_layer is None: - return None - - # Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may - # be smaller than the main vocab_size). - per_layer_inputs_mask = torch.logical_and( - input_ids >= 0, - input_ids < self.vocab_size_per_layer_input, - ) - per_layer_inputs_tokens = torch.where( - per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids) - ) - - # Get packed per-layer embeddings: (num_tokens, total_ple_dim) - per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens) - - # Apply embed_scale (sqrt of per-layer hidden dim) - per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer - - # Reshape to (num_tokens, num_layers, hidden_size_per_layer_input) - per_layer_embeds = per_layer_embeds.reshape( - *input_ids.shape, - self.config.num_hidden_layers, - self.hidden_size_per_layer_input, - ) - return per_layer_embeds + return self.self_decoder.get_per_layer_inputs(input_ids) def project_per_layer_inputs( self, inputs_embeds: torch.Tensor, per_layer_inputs: torch.Tensor | None, - ) -> torch.Tensor: + ) -> torch.Tensor | None: """Project inputs_embeds and combine with per_layer_inputs. Steps: @@ -826,29 +1071,94 @@ class Gemma4Model(nn.Module): 4. Normalize with per_layer_projection_norm 5. Combine: (projection + per_layer_inputs) * 1/sqrt(2) """ - if self.per_layer_model_projection is None: - return None - - # Project from hidden_size to total_ple_dim - # Scaled projection: output = linear(input, weight) * scale - per_layer_projection = self.per_layer_model_projection(inputs_embeds) - per_layer_projection = per_layer_projection * self.per_layer_projection_scale - - # Reshape to (num_tokens, num_layers, hidden_size_per_layer_input) - per_layer_projection = per_layer_projection.reshape( - *inputs_embeds.shape[:-1], - self.config.num_hidden_layers, - self.hidden_size_per_layer_input, + return self.self_decoder.project_per_layer_inputs( + inputs_embeds, per_layer_inputs ) - # Normalize - per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + def fast_prefill_forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + per_layer_inputs: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + logits_indices_padded, num_logits_indices = None, None + attn_metadata = get_forward_context().attn_metadata - if per_layer_inputs is None: - return per_layer_projection + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + layer_attn_metadata = attn_metadata[ + self.layers[-1].self_attn.attn.layer_name + ] + if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata): + logits_indices_padded = layer_attn_metadata.logits_indices_padded + num_logits_indices = layer_attn_metadata.num_logits_indices - # Combine: (projection + per_layer_inputs) * scale - return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + batch_size = positions.size(0) + self.positions[:batch_size].copy_(positions) + self_decoder_hidden_states, per_layer_inputs = self.self_decoder( + input_ids=input_ids, + positions=self.positions[:batch_size], + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + + if logits_indices_padded is None: + logits_indices_padded = torch.arange( + batch_size, + dtype=positions.dtype, + device=positions.device, + ) + + # NOTE: Keep .clone() until fix in + # https://github.com/vllm-project/vllm/pull/22282 + hidden_states = self_decoder_hidden_states.clone() + + num_padded = logits_indices_padded.size(0) + self.positions[:num_padded].copy_(positions[logits_indices_padded]) + self.hidden_states[:num_padded].copy_( + self_decoder_hidden_states[logits_indices_padded] + ) + if self.per_layer_inputs is not None and per_layer_inputs is not None: + self.per_layer_inputs[:num_padded].copy_( + per_layer_inputs[logits_indices_padded] + ) + + # Update batch_descriptor so the cross-decoder's piecewise + # CUDAGraphWrapper dispatches to the correct (reduced) batch size. + forward_context = get_forward_context() + orig_batch_desc = forward_context.batch_descriptor + if orig_batch_desc is not None: + forward_context.batch_descriptor = replace( + orig_batch_desc, num_tokens=num_padded + ) + + cross_per_layer = ( + self.per_layer_inputs[:num_padded] + if self.per_layer_inputs is not None + else None + ) + cross_hidden_states = self.cross_decoder( + self.positions[:num_padded], + self.hidden_states[:num_padded], + cross_per_layer, + **kwargs, + ) + + # Restore the original batch_descriptor + forward_context.batch_descriptor = orig_batch_desc + + if num_logits_indices is not None: + assert num_logits_indices > 0 + hidden_states[logits_indices_padded[:num_logits_indices]] = ( + cross_hidden_states[:num_logits_indices] + ) + else: + hidden_states = cross_hidden_states + + return hidden_states def forward( self, @@ -859,6 +1169,18 @@ class Gemma4Model(nn.Module): per_layer_inputs: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor | IntermediateTensors: + if self.fast_prefill_enabled: + hidden_states = self.fast_prefill_forward( + input_ids, + positions, + inputs_embeds, + per_layer_inputs, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + # Normal (non-fast-prefill) path with PP support if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds