[6/N][Attention] Move utils to more appropriate locations (#32215)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -734,3 +734,27 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
||||
|
||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
||||
return kv_cache_dtype != "auto"
|
||||
|
||||
|
||||
def subclass_attention_backend(
|
||||
name_prefix: str,
|
||||
attention_backend_cls: type[AttentionBackend],
|
||||
builder_cls: type[AttentionMetadataBuilder[M]],
|
||||
) -> type[AttentionBackend]:
|
||||
"""
|
||||
Return a new subclass where `get_builder_cls` returns `builder_cls`.
|
||||
"""
|
||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||
|
||||
return type(
|
||||
name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls}
|
||||
)
|
||||
|
||||
|
||||
def subclass_attention_backend_with_overrides(
|
||||
name_prefix: str,
|
||||
attention_backend_cls: type[AttentionBackend],
|
||||
overrides: dict[str, Any],
|
||||
) -> type[AttentionBackend]:
|
||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||
return type(name, (attention_backend_cls,), overrides)
|
||||
|
||||
Reference in New Issue
Block a user