[V1] Enable prefill optimization for Gemma3n (#22628)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
Yong Hoon Shin
2025-08-28 14:54:30 -07:00
committed by GitHub
parent 7ffbf27239
commit cb293f6a79
9 changed files with 591 additions and 236 deletions

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Optional, Union
import pytest
import torch
@@ -10,12 +9,6 @@ import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n_mm import (
Gemma3nForConditionalGeneration)
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors
from ...utils import fork_new_process_for_each_test
@@ -23,54 +16,6 @@ from ...utils import fork_new_process_for_each_test
SEED = 42
class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = super().forward(input_ids, positions,
intermediate_tensors, inputs_embeds,
**kwargs)
attn_metadata = get_forward_context().attn_metadata
# attn_metadata is None during dummy runs
if (attn_metadata is not None
and self.language_model.cache_config.kv_sharing_fast_prefill):
assert isinstance(attn_metadata, dict) # true in V1
# Gemma3n-E2B has 30 layers, with last 20 layers being
# cross-decoder layers. Check attention metadata is correct
for layer_name, metadata in attn_metadata.items():
layer_idx = extract_layer_index(layer_name)
if layer_idx >= 20:
assert hasattr(metadata, 'logits_indices_padded')
assert hasattr(metadata, 'num_logits_indices')
else:
assert not hasattr(metadata, 'logits_indices_padded')
assert not hasattr(metadata, 'num_logits_indices')
# Last layer will be a KV sharing layer
layer_attn_metadata = attn_metadata[
self.language_model.model.layers[-1].self_attn.attn.layer_name]
logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
assert logits_indices_padded is not None
num_logits_indices = layer_attn_metadata.num_logits_indices
assert num_logits_indices > 0
# Reset hidden states to random values and
# only set logits at logits_indices to valid values
# Because logits_indices are the only positions that are used
# for output token sampling, this still produces same outputs
logits_hs = hidden_states[logits_indices_padded]
hidden_states = torch.randn_like(hidden_states)
gen_indices = logits_indices_padded[:num_logits_indices]
hidden_states[gen_indices] = logits_hs[:num_logits_indices]
return hidden_states
@pytest.fixture
def test_prompts():
"""
@@ -124,8 +69,6 @@ def test_kv_sharing_fast_prefill(
enforce_eager: bool,
test_prompts: list[str],
):
ModelRegistry.register_model("Gemma3nForConditionalGeneration",
TestGemma3nForConditionalGeneration)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
compilation_config = CompilationConfig(
# This allows vLLM compilation backend to handle allocating and