Override attention metadata for fast prefill in some KV sharing setups (#21590)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
Yong Hoon Shin
2025-07-30 08:54:15 -07:00
committed by GitHub
parent 366f6b3a4d
commit ad510309ee
6 changed files with 287 additions and 26 deletions

View File

@@ -3,8 +3,8 @@
import abc
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
import numpy as np
import torch
@@ -508,3 +508,34 @@ def reorder_batch_to_split_decodes_and_prefills(
modified_batch = True
return modified_batch
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
('logits_indices_padded', Optional[torch.Tensor], None),
('num_logits_indices', int, 0),
]
def subclass_attention_metadata(
name_prefix: str,
metadata_cls: Any,
fields: list[tuple[str, Any, Any]],
) -> Any:
"""
Return a new subclass of `metadata_cls` with additional fields
"""
name: str = name_prefix + metadata_cls.__name__ # type: ignore
Wrapped = make_dataclass(name, fields, bases=(metadata_cls, ))
return Wrapped
def make_kv_sharing_fast_prefill_attention_metadata(
metadata_cls: Any, ) -> Any:
"""
Return a new subclass of `metadata_cls` for fast prefill
"""
return subclass_attention_metadata(
name_prefix="KVSharingFastPrefill",
metadata_cls=metadata_cls,
fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
)