From 8c3e199998cc5b1225328f2de01a7443fbb4f3cd Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Date: Fri, 29 Aug 2025 12:16:57 -0700 Subject: [PATCH] Revert gemma3n fast prefill changes (#23897) Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_fast_prefill.py | 1 + vllm/model_executor/models/gemma3n.py | 433 +++---------------- vllm/model_executor/models/gemma3n_mm.py | 2 +- 3 files changed, 74 insertions(+), 362 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index 7bc7f44dd..6bc9b2b1d 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -64,6 +64,7 @@ def cleanup(llm: LLM, compilation_config: CompilationConfig): @fork_new_process_for_each_test @pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill") def test_kv_sharing_fast_prefill( monkeypatch: pytest.MonkeyPatch, enforce_eager: bool, diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index 0e0e191e7..ffec34087 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -23,11 +23,9 @@ from torch import nn from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig from vllm.attention import Attention -from vllm.compilation.backends import set_model_tag from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, GeluAndMul, @@ -47,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata from .interfaces import SupportsQuant from .utils import (AutoWeightsLoader, extract_layer_index, @@ -536,178 +533,7 @@ class Gemma3nDecoderLayer(nn.Module): return corrected_predictions -# This enables torch.compile if --kv-sharing-fast-prefill passed -@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. 
- kv_sharing_fast_prefill) -class Gemma3nSelfDecoder(nn.Module): - """ - Includes altup embedding and self decoder layers - """ - - def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", - decoder_layers: list[Gemma3nDecoderLayer], - layer_idx_start: int, - per_layer_model_projection: ColumnParallelLinear, - embed_scale_per_layer: torch.Tensor, - embed_tokens_per_layer: VocabParallelEmbedding, - per_layer_projection_norm: RMSNorm, - per_layer_input_scale: torch.Tensor, - altup_projections: nn.ModuleList, - eps: torch.Tensor, - embed_tokens: VocabParallelEmbedding, - embed_scale: torch.Tensor, - ): - super().__init__() - self.decoder_layers = decoder_layers - self.layer_idx_start = layer_idx_start - self.per_layer_model_projection = per_layer_model_projection - self.config = vllm_config.model_config.hf_config - self.embed_scale_per_layer = embed_scale_per_layer - self.embed_tokens_per_layer = embed_tokens_per_layer - self.per_layer_projection_norm = per_layer_projection_norm - self.per_layer_input_scale = per_layer_input_scale - self.altup_projections = altup_projections - self.eps = eps - self.embed_tokens = embed_tokens - self.embed_scale = embed_scale - - def get_per_layer_input_embeddings( - self, input_ids: torch.Tensor) -> torch.Tensor: - # Deal with the fact that vocab_size_per_layer_input < vocab_size - # which causes us to have some out of vocab tokens by setting - # those token ids to 0. This matches the HF implementation. - per_layer_inputs_mask = torch.logical_and( - input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input) - per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, - torch.zeros_like(input_ids)) - return self.embed_tokens_per_layer( - per_layer_inputs_tokens) * self.embed_scale_per_layer - - def get_per_layer_inputs( - self, - hidden_states_0: torch.Tensor, - per_layer_inputs: Optional[torch.Tensor], - ) -> torch.Tensor: - per_layer_projection = self.per_layer_model_projection(hidden_states_0) - per_layer_projection = per_layer_projection.reshape( - *hidden_states_0.shape[:-1], - self.config.num_hidden_layers, - self.config.hidden_size_per_layer_input, - ) - per_layer_projection = self.per_layer_projection_norm( - per_layer_projection) - if per_layer_inputs is not None: - # Profiling run does not compute per_layer_inputs - per_layer_inputs = per_layer_projection + per_layer_inputs - per_layer_inputs *= self.per_layer_input_scale - else: - per_layer_inputs = per_layer_projection - return per_layer_inputs - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) * self.embed_scale - - def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor: - # Altup embed. 
- hidden_states = [hidden_states_0] * self.config.altup_num_inputs - target_magnitude = torch.mean(hidden_states_0**2, dim=-1, - keepdim=True)**0.5 - for i in range(1, self.config.altup_num_inputs): - hidden_states[i] = self.altup_projections[i - 1](hidden_states[i]) - new_magnitude = torch.mean(hidden_states[i]**2, - dim=-1, - keepdim=True)**0.5 - hidden_states[i] *= target_magnitude / torch.maximum( - new_magnitude, self.eps) - hidden_states = torch.stack(hidden_states, dim=-1) - return hidden_states - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - per_layer_inputs: Optional[torch.Tensor] = None, - **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor]: - if inputs_embeds is not None: - hidden_states_0 = inputs_embeds - else: - hidden_states_0 = self.get_input_embeddings(input_ids) - - adjusted_per_layer_inputs = self.get_per_layer_inputs( - hidden_states_0, per_layer_inputs) - hidden_states = self.altup_embed(hidden_states_0) - - # [altnum_inputs, num_tokens, hidden_size] - hidden_states = hidden_states.permute(2, 0, 1) - - for idx, layer in enumerate(self.decoder_layers): - layer_idx = idx + self.layer_idx_start - # [altup_num_inputs, num_tokens, hidden_size] - hidden_states = layer( - positions=positions, - hidden_states=hidden_states, - per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :], - **kwargs, - ) - - # [num_tokens, hidden_size, altnum_inputs] - hidden_states = hidden_states.permute(1, 2, 0) - - return hidden_states, adjusted_per_layer_inputs - - -# This enables torch.compile if --kv-sharing-fast-prefill passed -@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. - kv_sharing_fast_prefill) -class Gemma3nCrossDecoder(nn.Module): - """ - Cross-decoder layers - """ - - def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", - decoder_layers: list[Gemma3nDecoderLayer], - layer_idx_start: int, - ): - super().__init__() - self.decoder_layers = decoder_layers - self.layer_idx_start = layer_idx_start - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - per_layer_inputs: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - # [altnum_inputs, num_tokens, hidden_size] - hidden_states = hidden_states.permute(2, 0, 1) - for idx, layer in enumerate(self.decoder_layers): - layer_idx = idx + self.layer_idx_start - # [altup_num_inputs, num_tokens, hidden_size] - hidden_states = layer( - positions=positions, - hidden_states=hidden_states, - per_layer_input=per_layer_inputs[:, layer_idx, :], - **kwargs, - ) - # [num_tokens, hidden_size, altnum_inputs] - hidden_states = hidden_states.permute(1, 2, 0) - return hidden_states - - -# This disables torch.compile if --kv-sharing-fast-prefill passed -@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. 
- cache_config.kv_sharing_fast_prefill) +@support_torch_compile class Gemma3nTextModel(nn.Module, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -717,6 +543,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -786,211 +613,95 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): lambda prefix: Gemma3nDecoderLayer( config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") - - self.eps = torch.tensor(torch.finfo().min) - - first_kv_shared_layer_idx = (config.num_hidden_layers - - config.num_kv_shared_layers) - # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO) - with set_model_tag("self_decoder"): - self.self_decoder = Gemma3nSelfDecoder( - vllm_config=vllm_config, - prefix=f"{prefix}.self_decoder", - decoder_layers=self.layers[:first_kv_shared_layer_idx], - layer_idx_start=0, - per_layer_model_projection=self.per_layer_model_projection, - embed_scale_per_layer=self.embed_scale_per_layer, - embed_tokens_per_layer=self.embed_tokens_per_layer, - per_layer_projection_norm=self.per_layer_projection_norm, - per_layer_input_scale=self.per_layer_input_scale, - altup_projections=self.altup_projections, - eps=self.eps, - embed_tokens=self.embed_tokens, - embed_scale=self.embed_scale, - ) - # Layer idx 20-30 are cross-decoder layers in YOCO - with set_model_tag("cross_decoder"): - self.cross_decoder = Gemma3nCrossDecoder( - vllm_config=vllm_config, - prefix=f"{prefix}.cross_decoder", - decoder_layers=self.layers[first_kv_shared_layer_idx:], - layer_idx_start=first_kv_shared_layer_idx, - ) - self.norm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, ) - - self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill - - if self.fast_prefill_enabled: - # Allocate static buffers for CUDAGraph - # TODO(sarckk): Extract this functionality to interface - max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens - device = next(self.parameters()).device - self.positions = torch.zeros(max_num_tokens, - dtype=torch.int64, - device=device) - self.hidden_states = torch.zeros( - (max_num_tokens, config.hidden_size, - self.config.altup_num_inputs), - dtype=self.embed_tokens.weight.dtype, - device=device, - ) - self.per_layer_inputs = torch.zeros( - (max_num_tokens, self.config.num_hidden_layers, - self.config.hidden_size_per_layer_input), - dtype=self.embed_tokens.weight.dtype, - device=device, - ) + self.eps = torch.tensor(torch.finfo().min) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.self_decoder.get_input_embeddings(input_ids) + return self.embed_tokens(input_ids) * self.embed_scale - def fast_prefill_forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - per_layer_inputs: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - logits_indices_padded, num_logits_indices = None, None - attn_metadata = get_forward_context().attn_metadata - - # attn_metadata is None during dummy runs - if (self.fast_prefill_enabled and attn_metadata is not None): - assert isinstance(attn_metadata, dict) - # Last layer is a KV sharing layer - layer_attn_metadata = attn_metadata[ - self.layers[-1].self_attn.attn.layer_name] - if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)): - logits_indices_padded = ( - 
layer_attn_metadata.logits_indices_padded) - num_logits_indices = layer_attn_metadata.num_logits_indices - - # Copy inputs for cudagraph - batch_size = positions.size(0) - self.positions[:batch_size].copy_(positions) - self_decoder_hidden_states, per_layer_inputs_adjusted = \ - self.self_decoder( - input_ids=input_ids, - positions=self.positions[:batch_size], - inputs_embeds=inputs_embeds, - per_layer_inputs=per_layer_inputs, - **kwargs, - ) - - if logits_indices_padded is None: - logits_indices_padded = torch.arange( - positions.size(0), - dtype=positions.dtype, - device=positions.device, - ) - - # NOTE(sarckk): There is currently a bug caused by - # vLLM converting output of last piecewise CUDA graph - # to weakref, causing memory to be prematurely freed - # when there are multiple compilation units - # Keep .clone() until fix in - # https://github.com/vllm-project/vllm/pull/22282 - hidden_states = self_decoder_hidden_states.clone() - - # Copy inputs for cudagraph - num_padded_logits_indices = logits_indices_padded.size(0) - self.positions[:num_padded_logits_indices].copy_( - positions[logits_indices_padded]) - self.hidden_states[:num_padded_logits_indices].copy_( - self_decoder_hidden_states[logits_indices_padded]) - self.per_layer_inputs[:num_padded_logits_indices].copy_( - per_layer_inputs_adjusted[logits_indices_padded]) - cross_decoder_hidden_states = self.cross_decoder( - positions=self.positions[:num_padded_logits_indices], - hidden_states=self.hidden_states[:num_padded_logits_indices], - per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices], - **kwargs, - ) - - if num_logits_indices is not None: - assert num_logits_indices > 0 - # Merge cross-decoder and self-decoder hidden states - hidden_states[logits_indices_padded[:num_logits_indices]] = ( - cross_decoder_hidden_states[:num_logits_indices]) - else: - hidden_states = cross_decoder_hidden_states - - return hidden_states - - def normal_forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - per_layer_inputs: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - hidden_states, per_layer_inputs = self.self_decoder( - input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - per_layer_inputs=per_layer_inputs, - **kwargs, - ) - hidden_states = self.cross_decoder( - positions=positions, - hidden_states=hidden_states, - per_layer_inputs=per_layer_inputs, - **kwargs, - ) - return hidden_states - - def altup_unembed( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - # Altup unembed. - target_magnitude = torch.mean(hidden_states[..., 0]**2, - dim=-1, - keepdim=True)**0.5 - for i in range(1, self.config.altup_num_inputs): - hidden_states[..., i] = self.altup_unembed_projections[i - 1]( - hidden_states[..., i]) - new_magnitude = torch.mean(hidden_states[..., i]**2, - dim=-1, - keepdim=True)**0.5 - hidden_states[..., i] *= target_magnitude / torch.maximum( - new_magnitude, self.eps) - # [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size] - hidden_states = torch.mean(hidden_states, dim=-1) - return hidden_states + def get_per_layer_input_embeddings( + self, input_ids: torch.Tensor) -> torch.Tensor: + # Deal with the fact that vocab_size_per_layer_input < vocab_size + # which causes us to have some out of vocab tokens by setting + # those token ids to 0. This matches the HF implementation. 
+ per_layer_inputs_mask = torch.logical_and( + input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input) + per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, + torch.zeros_like(input_ids)) + return self.embed_tokens_per_layer( + per_layer_inputs_tokens) * self.embed_scale_per_layer def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - per_layer_inputs: Optional[torch.Tensor] = None, + per_layer_inputs: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - if self.fast_prefill_enabled: - hidden_states = self.fast_prefill_forward( - input_ids, - positions, - inputs_embeds, - per_layer_inputs, - **kwargs, - ) + if inputs_embeds is not None: + hidden_states_0 = inputs_embeds else: - hidden_states = self.normal_forward( - input_ids, - positions, - inputs_embeds, - per_layer_inputs, + hidden_states_0 = self.get_input_embeddings(input_ids) + + per_layer_projection = self.per_layer_model_projection(hidden_states_0) + per_layer_projection = per_layer_projection.reshape( + *hidden_states_0.shape[:-1], + self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm( + per_layer_projection) + + if per_layer_inputs is not None: + # Profiling run does not compute per_layer_inputs + per_layer_inputs = per_layer_projection + per_layer_inputs + per_layer_inputs *= self.per_layer_input_scale + else: + per_layer_inputs = per_layer_projection + + # Altup embed. + hidden_states = [hidden_states_0] * self.config.altup_num_inputs + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, + keepdim=True)**0.5 + for i in range(1, self.config.altup_num_inputs): + hidden_states[i] = self.altup_projections[i - 1](hidden_states[i]) + new_magnitude = torch.mean(hidden_states[i]**2, + dim=-1, + keepdim=True)**0.5 + hidden_states[i] *= target_magnitude / torch.maximum( + new_magnitude, self.eps) + hidden_states = torch.stack(hidden_states, dim=0) + + # Transformer blocks. + for layer_idx, layer in enumerate(self.layers): + # [altup_num_inputs, num_tokens, hidden_size] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=per_layer_inputs[:, layer_idx, :], **kwargs, ) - hidden_states = self.altup_unembed(hidden_states) + + # Altup unembed. + target_magnitude = torch.mean(hidden_states[0]**2, + dim=-1, + keepdim=True)**0.5 + for i in range(1, self.config.altup_num_inputs): + hidden_states[i] = self.altup_unembed_projections[i - 1]( + hidden_states[i]) + new_magnitude = torch.mean(hidden_states[i]**2, + dim=-1, + keepdim=True)**0.5 + hidden_states[i] *= target_magnitude / torch.maximum( + new_magnitude, self.eps) + # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size] + hidden_states = torch.mean(hidden_states, dim=0) + return self.norm(hidden_states) def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index aba4f98ea..d59dde156 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -620,7 +620,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache # them here, as the model forward has only access to the input_embeds. 
if input_ids is not None: - per_layer_inputs = self.language_model.model.self_decoder.get_per_layer_input_embeddings( + per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings( input_ids) per_layer_inputs = per_layer_inputs.reshape( -1, self.config.text_config.num_hidden_layers,
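
For context, this revert moves `get_per_layer_input_embeddings` back onto `Gemma3nTextModel`, which is why `gemma3n_mm.py` now calls `self.language_model.model.get_per_layer_input_embeddings(input_ids)` directly. Below is a minimal, self-contained sketch of that masking logic; the vocab sizes, embedding dimension, and scale value are illustrative assumptions, not the model's actual configuration.

```python
# Standalone sketch (not vLLM code) of the per-layer input embedding lookup
# restored by this revert: token ids outside [0, vocab_size_per_layer_input)
# are remapped to 0 before the lookup, matching the HF behavior noted in the
# diff. All sizes below are made up for illustration.
import torch
import torch.nn as nn

vocab_size_per_layer_input = 16       # hypothetical per-layer vocab size
hidden_size_per_layer_input = 8
embed_scale_per_layer = torch.tensor(hidden_size_per_layer_input**0.5)

embed_tokens_per_layer = nn.Embedding(vocab_size_per_layer_input,
                                      hidden_size_per_layer_input)

def get_per_layer_input_embeddings(input_ids: torch.Tensor) -> torch.Tensor:
    # Ids at or above vocab_size_per_layer_input (or negative) would index out
    # of bounds, so they are masked to token id 0 first.
    mask = torch.logical_and(input_ids >= 0,
                             input_ids < vocab_size_per_layer_input)
    safe_ids = torch.where(mask, input_ids, torch.zeros_like(input_ids))
    return embed_tokens_per_layer(safe_ids) * embed_scale_per_layer

ids = torch.tensor([3, 15, 20, 31])   # last two ids exceed the per-layer vocab
print(get_per_layer_input_embeddings(ids).shape)  # torch.Size([4, 8])
```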
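The revert also folds the altup embed step back into `Gemma3nTextModel.forward`. The sketch below illustrates that step under assumed sizes: each extra altup input is a projection of the base embedding, rescaled so its RMS magnitude matches the base embedding's. The `altup_num_inputs`, `hidden_size`, and token count are hypothetical, and the `eps` value simply mirrors the `torch.finfo().min` used in the patched file.

```python
# Illustrative sketch (not vLLM code) of the restored altup embed step.
import torch
import torch.nn as nn

altup_num_inputs = 4                  # illustrative value
hidden_size = 8
num_tokens = 3
eps = torch.tensor(torch.finfo().min)  # mirrors self.eps in the diff

altup_projections = nn.ModuleList(
    nn.Linear(hidden_size, hidden_size, bias=False)
    for _ in range(altup_num_inputs - 1))

hidden_states_0 = torch.randn(num_tokens, hidden_size)

hidden_states = [hidden_states_0] * altup_num_inputs
target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True)**0.5
for i in range(1, altup_num_inputs):
    hidden_states[i] = altup_projections[i - 1](hidden_states[i])
    new_magnitude = torch.mean(hidden_states[i]**2, dim=-1, keepdim=True)**0.5
    # Rescale each projected copy to the base embedding's RMS magnitude.
    hidden_states[i] *= target_magnitude / torch.maximum(new_magnitude, eps)

# [altup_num_inputs, num_tokens, hidden_size], the layout the decoder layers
# consume in the restored forward (stacked on dim=0, not permuted).
hidden_states = torch.stack(hidden_states, dim=0)
print(hidden_states.shape)  # torch.Size([4, 3, 8])
```

The altup unembed step at the end of the restored forward applies the same magnitude-matching in reverse and then averages over the altup dimension to return to `[num_tokens, hidden_size]`.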