[Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-08-21 22:05:59 -07:00
committed by GitHub
parent 5964069367
commit 17373dcd93
12 changed files with 226 additions and 214 deletions

View File

@@ -5,8 +5,7 @@ import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
TypeVar)
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
import numpy as np
import torch
@@ -543,35 +542,6 @@ def make_local_attention_virtual_batches(
)
def subclass_attention_metadata_builder(
name_prefix: str,
builder_cls: type[AttentionMetadataBuilder[M]],
build_preprocess_fn: Callable[[CommonAttentionMetadata],
CommonAttentionMetadata],
) -> type[AttentionMetadataBuilder[M]]:
"""
Return a new subclass of `builder_cls` whose .build(...) method
first calls build_preprocess_fn(common_attn_metadata) on the metadata.
"""
name: str = name_prefix + builder_cls.__name__ # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False):
return builder_cls.build(self, common_prefix_len,
build_preprocess_fn(common_attn_metadata),
fast_build)
Wrapped = type(
name,
(builder_cls, ), # inherit from the original
{
"build": build,
})
return Wrapped # type: ignore
def subclass_attention_backend(
name_prefix: str, attention_backend_cls: type[AttentionBackend],
builder_cls: type[AttentionMetadataBuilder[M]]