[V1] Enable prefill optimization for Gemma3n (#22628)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
@@ -4,11 +4,13 @@ import abc
 import enum
 import functools
 from abc import abstractmethod
-from dataclasses import dataclass, make_dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
+from dataclasses import dataclass, fields, make_dataclass
+from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
+                    TypeVar)

 import numpy as np
 import torch
+from typing_extensions import runtime_checkable

 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.utils import cdiv
@@ -19,7 +21,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_input_batch import InputBatch

 import vllm.envs as envs
-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadata)
 from vllm.attention.layer import Attention
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     get_kv_connector_cache_layout)
@@ -65,6 +68,10 @@ class CommonAttentionMetadata:

     causal: bool = True

+    # Needed by FastPrefillAttentionBuilder
+    logits_indices_padded: Optional[torch.Tensor] = None
+    num_logits_indices: Optional[int] = None
+

 @dataclass
 class UbatchSlice:
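These two fields exist so that downstream builders can strip CUDA graph padding: logits_indices_padded may be padded out to a graph capture size, and num_logits_indices records how many leading entries are real. A minimal sketch of the relationship (the values here are hypothetical, not taken from the commit):

    import torch

    # Hypothetical values: indices padded out to a CUDA graph capture size of 8
    logits_indices_padded = torch.tensor([14, 18, 19, 27, 0, 0, 0, 0])
    num_logits_indices = 4

    # Consumers recover the real indices by slicing off the padding,
    # as the fast-prefill helper added below does
    logits_indices = logits_indices_padded[:num_logits_indices]
    assert logits_indices.tolist() == [14, 18, 19, 27]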
@@ -542,6 +549,69 @@ def make_local_attention_virtual_batches(
     )


+def make_kv_sharing_fast_prefill_common_attn_metadata(
+    common_attn_metadata: CommonAttentionMetadata,
+) -> CommonAttentionMetadata:
+    if common_attn_metadata.max_query_len == 1:
+        # All requests are decode (assume 1 token for now)
+        # Skip computing fast prefill path
+        return common_attn_metadata
+
+    assert common_attn_metadata.logits_indices_padded is not None
+    assert common_attn_metadata.num_logits_indices is not None
+
+    logits_indices_padded = common_attn_metadata.logits_indices_padded
+    num_logits_indices = common_attn_metadata.num_logits_indices
+    # Get rid of CUDA graph padding, if any
+    logits_indices = logits_indices_padded[:num_logits_indices]
+    num_reqs = common_attn_metadata.num_reqs
+    query_start_loc = common_attn_metadata.query_start_loc
+    seq_lens = common_attn_metadata.seq_lens
+    # Example inputs
+    # num_reqs: 3
+    # logits_indices: [14, 18, 19, 27]
+    # query_start_loc: [0, 15, 20, 28]
+    # seq_lens: [41, 31, 40]
+
+    # Find which request each logits index belongs to
+    # request_ids: [0, 1, 1, 2]
+    request_ids = torch.bucketize(logits_indices,
+                                  query_start_loc[1:],
+                                  right=True)
+
+    # Figure out how many decode tokens are in each request
+    # num_decode_tokens: [1, 2, 1]
+    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
+
+    # Calculate new query_start_loc over the logits indices only
+    # decode_query_start_loc: [0, 1, 3, 4]
+    decode_query_start_loc = torch.empty(num_reqs + 1,
+                                         device=query_start_loc.device,
+                                         dtype=query_start_loc.dtype)
+
+    decode_query_start_loc[0] = 0
+    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
+    decode_max_query_len = int(num_decode_tokens.max().item())
+    total_num_decode_tokens = int(num_decode_tokens.sum().item())
+
+    common_attn_metadata = CommonAttentionMetadata(
+        query_start_loc=decode_query_start_loc,
+        query_start_loc_cpu=decode_query_start_loc.to("cpu",
+                                                      non_blocking=True),
+        seq_lens=seq_lens,
+        seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
+        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+        num_reqs=num_reqs,
+        num_actual_tokens=total_num_decode_tokens,
+        max_query_len=decode_max_query_len,
+        max_seq_len=common_attn_metadata.max_seq_len,
+        block_table_tensor=common_attn_metadata.block_table_tensor,
+        slot_mapping=common_attn_metadata.slot_mapping,
+        causal=True,
+    )
+    return common_attn_metadata
+
+
 def subclass_attention_backend(
         name_prefix: str, attention_backend_cls: type[AttentionBackend],
         builder_cls: type[AttentionMetadataBuilder[M]]
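As a sanity check on the worked example in the comments above, a small self-contained snippet (not part of the commit) reproduces the bucketing of logits indices into requests, the per-request decode-token counts, and the compacted query_start_loc with plain torch:

    import torch

    # Example inputs from the comments in the function above
    query_start_loc = torch.tensor([0, 15, 20, 28])
    logits_indices = torch.tensor([14, 18, 19, 27])
    num_reqs = 3

    # Each boundary in query_start_loc[1:] ends a request, so bucketize
    # maps every logits index to the request that contains it
    request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
    assert request_ids.tolist() == [0, 1, 1, 2]

    # Count how many logits (decode) tokens fall in each request
    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
    assert num_decode_tokens.tolist() == [1, 2, 1]

    # Prefix-sum the counts into the new, much shorter query_start_loc
    decode_query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.long)
    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
    assert decode_query_start_loc.tolist() == [0, 1, 3, 4]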
@@ -679,13 +749,56 @@ def subclass_attention_metadata(
     return Wrapped


-def make_kv_sharing_fast_prefill_attention_metadata(
-        metadata_cls: Any, ) -> Any:
-    """
-    Return a new subclass of `metadata_cls` for fast prefill
-    """
-    return subclass_attention_metadata(
-        name_prefix="KVSharingFastPrefill",
-        metadata_cls=metadata_cls,
-        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
-    )
+@runtime_checkable
+class KVSharingFastPrefillMetadata(Protocol):
+    logits_indices_padded: torch.Tensor
+    num_logits_indices: int
+
+
+def create_fast_prefill_custom_backend(
+    prefix: str,
+    underlying_attn_backend: AttentionBackend,
+) -> type[AttentionBackend]:
+
+    underlying_builder = underlying_attn_backend.get_builder_cls()
+
+    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore
+
+        def build(self,
+                  common_prefix_len: int,
+                  common_attn_metadata: CommonAttentionMetadata,
+                  fast_build: bool = False) -> AttentionMetadata:
+            new_common_attn_metadata = \
+                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
+            metadata = super().build(common_prefix_len,
+                                     new_common_attn_metadata, fast_build)
+
+            class KVSharingFastPrefillAttentionMetadata(
+                    metadata.__class__,  # type: ignore
+                    KVSharingFastPrefillMetadata):
+
+                def __init__(self, metadata, common_attn_metadata):
+                    # Shallow copy all fields in metadata cls
+                    for field in fields(metadata.__class__):
+                        setattr(self, field.name,
+                                getattr(metadata, field.name))
+
+                    # Set additional fields that will be used in model code
+                    assert (common_attn_metadata.logits_indices_padded
+                            is not None
+                            and common_attn_metadata.num_logits_indices
+                            is not None)
+                    self.logits_indices_padded = \
+                        common_attn_metadata.logits_indices_padded
+                    self.num_logits_indices = \
+                        common_attn_metadata.num_logits_indices
+
+            return KVSharingFastPrefillAttentionMetadata(
+                metadata, common_attn_metadata)
+
+    attn_backend = subclass_attention_backend(
+        name_prefix=prefix,
+        attention_backend_cls=underlying_attn_backend,
+        builder_cls=FastPrefillAttentionBuilder)
+
+    return attn_backend
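Because KVSharingFastPrefillMetadata is a runtime_checkable Protocol, model code can detect the dynamically created metadata subclass with a plain isinstance check, without importing any backend-specific class. A hedged sketch of a consumer (the import path and helper function are illustrative assumptions, not part of this commit):

    import torch

    # Assumed import location for illustration; the Protocol lives in the
    # module modified by this commit
    from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata

    def select_logits_hidden_states(attn_metadata: object,
                                    hidden_states: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper: when fast-prefill metadata is present, keep
        # only the positions whose logits will actually be computed
        if isinstance(attn_metadata, KVSharingFastPrefillMetadata):
            logits_indices = attn_metadata.logits_indices_padded[
                :attn_metadata.num_logits_indices]
            return hidden_states[logits_indices]
        return hidden_states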