[V1] Enable prefill optimization for Gemma3n (#22628)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
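Context for the change: with KV sharing, the last 20 of Gemma3n-E2B's 30 decoder layers (the cross-decoder layers) reuse a KV cache written by earlier layers, so during prefill they only need to produce hidden states for the positions whose logits are actually sampled. A minimal sketch of that gather, with made-up tensor names and shapes (illustrative only, not vLLM's implementation):

# Sketch of the fast-prefill idea: cross-decoder layers only need the
# rows whose next-token logits will actually be sampled.
import torch

num_tokens, hidden_size = 512, 2048           # hypothetical prefill batch
hidden_states = torch.randn(num_tokens, hidden_size)
logits_indices = torch.tensor([255, 511])     # e.g. last token of each prompt

# Only these rows feed output token sampling, so the cross-decoder
# layers can skip computing everything else during prefill:
pruned = hidden_states[logits_indices]        # [num_logits, hidden_size]
assert pruned.shape == (2, hidden_size)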
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import random
-from typing import Optional, Union
 
 import pytest
 import torch
@@ -10,12 +9,6 @@ import torch
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationLevel
 from vllm.distributed import cleanup_dist_env_and_memory
-from vllm.forward_context import get_forward_context
-from vllm.model_executor.models.gemma3n_mm import (
-    Gemma3nForConditionalGeneration)
-from vllm.model_executor.models.registry import ModelRegistry
-from vllm.model_executor.models.utils import extract_layer_index
-from vllm.sequence import IntermediateTensors
 
 from ...utils import fork_new_process_for_each_test
 
@@ -23,54 +16,6 @@ from ...utils import fork_new_process_for_each_test
 SEED = 42
 
 
-class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = super().forward(input_ids, positions,
-                                        intermediate_tensors, inputs_embeds,
-                                        **kwargs)
-        attn_metadata = get_forward_context().attn_metadata
-        # attn_metadata is None during dummy runs
-        if (attn_metadata is not None
-                and self.language_model.cache_config.kv_sharing_fast_prefill):
-            assert isinstance(attn_metadata, dict)  # true in V1
-            # Gemma3n-E2B has 30 layers, with the last 20 layers being
-            # cross-decoder layers. Check the attention metadata is correct.
-            for layer_name, metadata in attn_metadata.items():
-                layer_idx = extract_layer_index(layer_name)
-                if layer_idx >= 20:
-                    assert hasattr(metadata, 'logits_indices_padded')
-                    assert hasattr(metadata, 'num_logits_indices')
-                else:
-                    assert not hasattr(metadata, 'logits_indices_padded')
-                    assert not hasattr(metadata, 'num_logits_indices')
-
-            # The last layer will be a KV sharing layer
-            layer_attn_metadata = attn_metadata[
-                self.language_model.model.layers[-1].self_attn.attn.layer_name]
-            logits_indices_padded = layer_attn_metadata.logits_indices_padded
-            assert logits_indices_padded is not None
-            num_logits_indices = layer_attn_metadata.num_logits_indices
-            assert num_logits_indices > 0
-            # Reset hidden states to random values and set only the rows at
-            # logits_indices to valid values. Because logits_indices are the
-            # only positions used for output token sampling, this still
-            # produces the same outputs.
-            logits_hs = hidden_states[logits_indices_padded]
-            hidden_states = torch.randn_like(hidden_states)
-            gen_indices = logits_indices_padded[:num_logits_indices]
-            hidden_states[gen_indices] = logits_hs[:num_logits_indices]
-
-        return hidden_states
-
-
 @pytest.fixture
 def test_prompts():
     """
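The randomize-and-scatter trick in the removed forward() above is easy to reproduce in isolation: only the rows at the logits indices are preserved, and those are the only rows greedy sampling reads. A self-contained sketch with made-up shapes:

# Standalone sketch of the verification trick from the removed test.
import torch

torch.manual_seed(0)
hidden_states = torch.randn(8, 4)                # [num_tokens, hidden]
logits_indices_padded = torch.tensor([3, 7, 7])  # padded; 2 real entries
num_logits_indices = 2

# Save the rows sampling will read, trash everything else, then
# scatter the saved rows back into place.
logits_hs = hidden_states[logits_indices_padded]
hidden_states = torch.randn_like(hidden_states)
gen_indices = logits_indices_padded[:num_logits_indices]
hidden_states[gen_indices] = logits_hs[:num_logits_indices]

# Rows 3 and 7 are intact, so sampling sees identical logits.
assert torch.equal(hidden_states[gen_indices],
                   logits_hs[:num_logits_indices])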
@@ -124,8 +69,6 @@ def test_kv_sharing_fast_prefill(
     enforce_eager: bool,
     test_prompts: list[str],
 ):
-    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
-                                 TestGemma3nForConditionalGeneration)
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     compilation_config = CompilationConfig(
         # This allows vLLM compilation backend to handle allocating and
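The removed test relied on re-registering the architecture name so that vLLM would instantiate the instrumented subclass in place of the stock model. A minimal sketch of that pattern, reusing the names from the diff above (the class body is elided and the model id is an assumption):

# Sketch: swap in a subclass for an existing architecture name so that
# LLM(...) builds the instrumented model instead of the stock one.
from vllm import LLM
from vllm.model_executor.models.gemma3n_mm import (
    Gemma3nForConditionalGeneration)
from vllm.model_executor.models.registry import ModelRegistry


class CheckedGemma3n(Gemma3nForConditionalGeneration):
    pass  # override forward() here to inspect internals


ModelRegistry.register_model("Gemma3nForConditionalGeneration",
                             CheckedGemma3n)
llm = LLM(model="google/gemma-3n-E2B-it")  # hypothetical model id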