[V1] Enable prefill optimization for Gemma3n (#22628)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Author: Yong Hoon Shin
Date: 2025-08-28 14:54:30 -07:00
Committed by: GitHub
Parent: 7ffbf27239
Commit: cb293f6a79
9 changed files with 591 additions and 236 deletions

vllm/v1/attention/backends/utils.py

@@ -4,11 +4,13 @@ import abc
 import enum
 import functools
 from abc import abstractmethod
-from dataclasses import dataclass, make_dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
+from dataclasses import dataclass, fields, make_dataclass
+from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
+                    TypeVar)
 
 import numpy as np
 import torch
+from typing_extensions import runtime_checkable
 
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.utils import cdiv
@@ -19,7 +21,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
 import vllm.envs as envs
-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadata)
 from vllm.attention.layer import Attention
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     get_kv_connector_cache_layout)
@@ -65,6 +68,10 @@ class CommonAttentionMetadata:
 
     causal: bool = True
 
+    # Needed by FastPrefillAttentionBuilder
+    logits_indices_padded: Optional[torch.Tensor] = None
+    num_logits_indices: Optional[int] = None
+
 
 @dataclass
 class UbatchSlice:
@@ -542,6 +549,69 @@ def make_local_attention_virtual_batches(
     )
 
 
+def make_kv_sharing_fast_prefill_common_attn_metadata(
+    common_attn_metadata: CommonAttentionMetadata,
+) -> CommonAttentionMetadata:
+    if common_attn_metadata.max_query_len == 1:
+        # All requests are decode (assume 1 token for now)
+        # Skip computing fast prefill path
+        return common_attn_metadata
+
+    assert common_attn_metadata.logits_indices_padded is not None
+    assert common_attn_metadata.num_logits_indices is not None
+
+    logits_indices_padded = common_attn_metadata.logits_indices_padded
+    num_logits_indices = common_attn_metadata.num_logits_indices
+    # Get rid of CUDAGraph padding, if any
+    logits_indices = logits_indices_padded[:num_logits_indices]
+    num_reqs = common_attn_metadata.num_reqs
+    query_start_loc = common_attn_metadata.query_start_loc
+    seq_lens = common_attn_metadata.seq_lens
+
+    # Example inputs
+    # num_reqs: 3
+    # logits_indices: [14, 18, 19, 27]
+    # query_start_loc: [0, 15, 20, 28]
+    # seq_lens: [41, 31, 40]
+
+    # Find which request each logits index belongs to
+    # request_ids: [0, 1, 1, 2]
+    request_ids = torch.bucketize(logits_indices,
+                                  query_start_loc[1:],
+                                  right=True)
+
+    # Figure out how many logits tokens are in each request
+    # num_decode_tokens: [1, 2, 1]
+    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
+
+    # Calculate new query_start_loc over only the logits indices
+    # decode_query_start_loc: [0, 1, 3, 4]
+    decode_query_start_loc = torch.empty(num_reqs + 1,
+                                         device=query_start_loc.device,
+                                         dtype=query_start_loc.dtype)
+    decode_query_start_loc[0] = 0
+    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
+    decode_max_query_len = int(num_decode_tokens.max().item())
+    total_num_decode_tokens = int(num_decode_tokens.sum().item())
+
+    common_attn_metadata = CommonAttentionMetadata(
+        query_start_loc=decode_query_start_loc,
+        query_start_loc_cpu=decode_query_start_loc.to("cpu",
+                                                      non_blocking=True),
+        seq_lens=seq_lens,
+        seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
+        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+        num_reqs=num_reqs,
+        num_actual_tokens=total_num_decode_tokens,
+        max_query_len=decode_max_query_len,
+        max_seq_len=common_attn_metadata.max_seq_len,
+        block_table_tensor=common_attn_metadata.block_table_tensor,
+        slot_mapping=common_attn_metadata.slot_mapping,
+        causal=True,
+    )
+    return common_attn_metadata
+
+
 def subclass_attention_backend(
         name_prefix: str, attention_backend_cls: type[AttentionBackend],
         builder_cls: type[AttentionMetadataBuilder[M]]
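
The bucketize/bincount/cumsum arithmetic in the hunk above can be sanity-checked in isolation. The following standalone snippet (not part of the diff; the tensors are hard-coded to the example values from the comments) reproduces the expected intermediate results:

import torch

# Example values from the comments in the diff above
query_start_loc = torch.tensor([0, 15, 20, 28])
logits_indices = torch.tensor([14, 18, 19, 27])
num_reqs = 3

# Map each logits index to the request whose query span contains it
request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
assert request_ids.tolist() == [0, 1, 1, 2]

# Count how many logits (decode) tokens each request owns
num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
assert num_decode_tokens.tolist() == [1, 2, 1]

# Rebuild a compact query_start_loc over only those tokens
decode_query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.long)
decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
assert decode_query_start_loc.tolist() == [0, 1, 3, 4]
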
@@ -679,13 +749,56 @@ def subclass_attention_metadata(
     return Wrapped
 
 
-def make_kv_sharing_fast_prefill_attention_metadata(
-    metadata_cls: Any, ) -> Any:
-    """
-    Return a new subclass of `metadata_cls` for fast prefill
-    """
-    return subclass_attention_metadata(
-        name_prefix="KVSharingFastPrefill",
-        metadata_cls=metadata_cls,
-        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
-    )
+@runtime_checkable
+class KVSharingFastPrefillMetadata(Protocol):
+    logits_indices_padded: torch.Tensor
+    num_logits_indices: int
+
+
+def create_fast_prefill_custom_backend(
+    prefix: str,
+    underlying_attn_backend: AttentionBackend,
+) -> type[AttentionBackend]:
+    underlying_builder = underlying_attn_backend.get_builder_cls()
+
+    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore
+
+        def build(self,
+                  common_prefix_len: int,
+                  common_attn_metadata: CommonAttentionMetadata,
+                  fast_build: bool = False) -> AttentionMetadata:
+            new_common_attn_metadata = \
+                make_kv_sharing_fast_prefill_common_attn_metadata(
+                    common_attn_metadata)
+            metadata = super().build(common_prefix_len,
+                                     new_common_attn_metadata, fast_build)
+
+            class KVSharingFastPrefillAttentionMetadata(
+                    metadata.__class__,  # type: ignore
+                    KVSharingFastPrefillMetadata):
+
+                def __init__(self, metadata, common_attn_metadata):
+                    # Shallow copy all fields in metadata cls
+                    for field in fields(metadata.__class__):
+                        setattr(self, field.name,
+                                getattr(metadata, field.name))
+
+                    # Set additional fields that will be used in model code
+                    assert (common_attn_metadata.logits_indices_padded
+                            is not None
+                            and common_attn_metadata.num_logits_indices
+                            is not None)
+                    self.logits_indices_padded = \
+                        common_attn_metadata.logits_indices_padded
+                    self.num_logits_indices = \
+                        common_attn_metadata.num_logits_indices
+
+            return KVSharingFastPrefillAttentionMetadata(
+                metadata, common_attn_metadata)
+
+    attn_backend = subclass_attention_backend(
+        name_prefix=prefix,
+        attention_backend_cls=underlying_attn_backend,
+        builder_cls=FastPrefillAttentionBuilder)
+
+    return attn_backend
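
Downstream model code does not need to know the concrete, per-backend metadata class that FastPrefillAttentionBuilder creates: because KVSharingFastPrefillMetadata is a runtime-checkable Protocol, an isinstance() check is enough to detect the extra fields. A hypothetical consumer could look like the sketch below (the helper name and slicing logic are illustrative, not part of this commit; the import path assumes the Protocol lives in vllm/v1/attention/backends/utils.py, the file shown above):

from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata

def maybe_select_logits_tokens(hidden_states, attn_metadata):
    # If the builder attached fast-prefill metadata, only the (padded)
    # logits positions need to flow through the KV-sharing layers; the
    # remaining prefill tokens can be skipped.
    if isinstance(attn_metadata, KVSharingFastPrefillMetadata):
        return hidden_states[attn_metadata.logits_indices_padded]
    return hidden_states
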