Override attention metadata for fast prefill in some KV sharing setups (#21590)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@@ -3,8 +3,8 @@
|
||||
import abc
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
|
||||
from dataclasses import dataclass, make_dataclass
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -508,3 +508,34 @@ def reorder_batch_to_split_decodes_and_prefills(
|
||||
modified_batch = True
|
||||
|
||||
return modified_batch
|
||||
|
||||
|
||||
# Extra dataclass fields appended to an attention metadata class for the
# KV-sharing fast-prefill path. Each entry is a (name, type, default) triple
# in the form accepted by `dataclasses.make_dataclass`.
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
    ("logits_indices_padded", Optional[torch.Tensor], None),
    ("num_logits_indices", int, 0),
]
|
||||
|
||||
|
||||
def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Return a new subclass of `metadata_cls` with additional fields.

    Args:
        name_prefix: Prepended to `metadata_cls.__name__` to form the name
            of the generated class.
        metadata_cls: Base class to extend.
        fields: Extra fields, each a 3-tuple in the form accepted by
            `dataclasses.make_dataclass`.

    Returns:
        A dataclass named `name_prefix + metadata_cls.__name__` that
        inherits from `metadata_cls` and declares `fields` on top of it.
        The generated class is frozen if and only if `metadata_cls` is a
        frozen dataclass.
    """
    name: str = name_prefix + metadata_cls.__name__  # type: ignore
    # A dataclass derived from a frozen dataclass must itself be frozen,
    # otherwise `make_dataclass` raises TypeError. Mirror the base's setting;
    # bases that are not dataclasses carry no params and stay non-frozen.
    base_params = getattr(metadata_cls, "__dataclass_params__", None)
    frozen = bool(getattr(base_params, "frozen", False))
    return make_dataclass(name,
                          fields,
                          bases=(metadata_cls, ),
                          frozen=frozen)
|
||||
|
||||
|
||||
def make_kv_sharing_fast_prefill_attention_metadata(
    metadata_cls: Any, ) -> Any:
    """
    Build the fast-prefill variant of `metadata_cls`.

    Delegates to `subclass_attention_metadata`, using the name prefix
    "KVSharingFastPrefill" and the extra fields listed in
    `KV_SHARING_FAST_PREFILL_METADATA_FIELDS`, and returns the class it
    produces.
    """
    fast_prefill_cls = subclass_attention_metadata(
        "KVSharingFastPrefill",
        metadata_cls,
        KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
    )
    return fast_prefill_cls
|
||||
|
||||
Reference in New Issue
Block a user