fixed mypy warnings for files vllm/v1/attention with TEMPORARY workaround (#31465)
Signed-off-by: Zhuohao Yang <zy242@cornell.edu> Co-authored-by: Zhuohao Yang <zy242@cornell.edu> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -167,7 +167,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
block_size: int | None,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
@@ -354,7 +354,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
aot_schedule = False
|
||||
|
||||
max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible
|
||||
if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size:
|
||||
if (
|
||||
self.use_full_cuda_graph
|
||||
and self.max_cudagraph_size is not None
|
||||
and num_actual_tokens <= self.max_cudagraph_size
|
||||
):
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
@@ -599,6 +603,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
assert self.vllm_flash_attn_version is not None, (
|
||||
"FlashAttention version not detected."
|
||||
)
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
@@ -697,6 +704,11 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
return output
|
||||
else:
|
||||
sliding_window_size = (
|
||||
list(self.sliding_window)
|
||||
if self.sliding_window is not None
|
||||
else None
|
||||
)
|
||||
flash_attn_varlen_func(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
@@ -709,7 +721,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
softmax_scale=self.scale,
|
||||
causal=attn_metadata.causal,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
window_size=sliding_window_size,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
@@ -764,12 +776,19 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
k_descale: torch.Tensor | None = None,
|
||||
v_descale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.vllm_flash_attn_version is not None, (
|
||||
"FlashAttention version not detected."
|
||||
)
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
query = query.contiguous()
|
||||
query_across_dcp = get_dcp_group().all_gather(query, dim=1)
|
||||
sliding_window_size = (
|
||||
list(self.sliding_window) if self.sliding_window is not None else None
|
||||
)
|
||||
context_attn_out, context_lse = flash_attn_varlen_func(
|
||||
q=query_across_dcp,
|
||||
k=key_cache,
|
||||
@@ -782,7 +801,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
window_size=sliding_window_size,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
@@ -813,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
softmax_scale=self.scale,
|
||||
causal=attn_metadata.causal,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
window_size=sliding_window_size,
|
||||
softcap=self.logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
@@ -850,6 +869,10 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata: Encoder attention metadata
|
||||
layer: The attention layer
|
||||
"""
|
||||
assert self.vllm_flash_attn_version is not None, (
|
||||
"FlashAttention version not detected."
|
||||
)
|
||||
|
||||
# For encoder attention, process FP8 quantization if needed
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError(
|
||||
@@ -868,6 +891,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
# Call flash attention directly on Q, K, V tensors
|
||||
sliding_window_size = (
|
||||
list(self.sliding_window) if self.sliding_window is not None else None
|
||||
)
|
||||
flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -880,7 +906,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
softmax_scale=self.scale,
|
||||
causal=False, # Encoder attention is bidirectional
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
window_size=sliding_window_size,
|
||||
softcap=self.logits_soft_cap,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
@@ -1020,7 +1046,7 @@ def cascade_attention(
|
||||
max_seqlen_k=common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=False,
|
||||
window_size=sliding_window,
|
||||
window_size=list(sliding_window),
|
||||
block_table=block_table[:1],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
@@ -1048,7 +1074,7 @@ def cascade_attention(
|
||||
max_seqlen_k=max_kv_len - common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=sliding_window,
|
||||
window_size=list(sliding_window),
|
||||
block_table=block_table[:, num_common_kv_blocks:],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
|
||||
@@ -113,6 +113,9 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
assert self.vllm_flash_attn_version is not None, (
|
||||
"FlashAttention version not detected."
|
||||
)
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
@@ -214,6 +217,11 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
|
||||
)
|
||||
return output
|
||||
else:
|
||||
sliding_window_size = (
|
||||
list(self.sliding_window)
|
||||
if self.sliding_window is not None
|
||||
else None
|
||||
)
|
||||
flash_attn_varlen_func(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
@@ -226,7 +234,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
|
||||
softmax_scale=self.scale,
|
||||
causal=attn_metadata.causal,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
window_size=sliding_window_size,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
|
||||
@@ -530,11 +530,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self._decode_wrappers_cudagraph: dict[
|
||||
int, BatchDecodeWithPagedKVCacheWrapper
|
||||
] = {}
|
||||
self._decode_cudagraph_max_bs = min(
|
||||
(1 + num_spec_tokens) * max_num_reqs,
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
)
|
||||
|
||||
self._decode_cudagraph_max_bs = (1 + num_spec_tokens) * max_num_reqs
|
||||
if self.compilation_config.max_cudagraph_capture_size is not None:
|
||||
self._decode_cudagraph_max_bs = min(
|
||||
self._decode_cudagraph_max_bs,
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
)
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
|
||||
@@ -215,7 +215,7 @@ def physical_to_logical_mapping(
|
||||
)
|
||||
|
||||
# Only process valid blocks to avoid garbage values
|
||||
num_blocks_per_seq = cdiv(seq_lens, block_size)
|
||||
num_blocks_per_seq: torch.Tensor = cdiv(seq_lens, block_size)
|
||||
mask = (
|
||||
torch.arange(max_num_blocks, device=device)[None, :]
|
||||
< num_blocks_per_seq[:, None]
|
||||
|
||||
@@ -75,8 +75,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
|
||||
if self.speculative_config:
|
||||
self.num_spec = self.speculative_config.num_speculative_tokens
|
||||
assert self.speculative_config.num_speculative_tokens is not None
|
||||
self.num_spec: int = self.speculative_config.num_speculative_tokens
|
||||
else:
|
||||
self.num_spec = 0
|
||||
self.use_spec_decode = self.num_spec > 0
|
||||
@@ -85,10 +87,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
self.use_full_cuda_graph = (
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
)
|
||||
self.decode_cudagraph_max_bs = min(
|
||||
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
|
||||
self.decode_cudagraph_max_bs = (
|
||||
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1)
|
||||
)
|
||||
if self.compilation_config.max_cudagraph_capture_size is not None:
|
||||
self.decode_cudagraph_max_bs = min(
|
||||
self.decode_cudagraph_max_bs,
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
)
|
||||
|
||||
self.spec_state_indices_tensor = torch.empty(
|
||||
(self.decode_cudagraph_max_bs, self.num_spec + 1),
|
||||
|
||||
@@ -123,10 +123,11 @@ class Mamba2AttentionMetadataBuilder(
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
|
||||
assert self.chunk_size is not None, (
|
||||
chunk_size = vllm_config.model_config.get_mamba_chunk_size()
|
||||
assert chunk_size is not None, (
|
||||
"chunk_size needs to be set in the model config for Mamba2 models"
|
||||
)
|
||||
self.chunk_size: int = chunk_size
|
||||
|
||||
def _compute_chunk_metadata(
|
||||
self,
|
||||
|
||||
@@ -69,10 +69,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.decode_cudagraph_max_bs = min(
|
||||
self.vllm_config.scheduler_config.max_num_seqs,
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
)
|
||||
self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs
|
||||
if self.compilation_config.max_cudagraph_capture_size is not None:
|
||||
self.decode_cudagraph_max_bs = min(
|
||||
self.decode_cudagraph_max_bs,
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
self.state_indices_tensor = torch.empty(
|
||||
@@ -150,9 +152,13 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
|
||||
)
|
||||
# -1 in case it's non-computed and causes later issues with indexing
|
||||
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
|
||||
block_idx_last_computed_token = torch.clamp(
|
||||
block_idx_last_computed_token, min=0
|
||||
)
|
||||
# -1 in the case we have a padded request (0 seq-len)
|
||||
block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)
|
||||
block_idx_last_scheduled_token = torch.clamp(
|
||||
block_idx_last_scheduled_token, min=0
|
||||
)
|
||||
|
||||
return (
|
||||
block_idx_last_computed_token,
|
||||
|
||||
@@ -62,7 +62,7 @@ class AiterTritonMLAImpl(AiterMLAImpl):
|
||||
k,
|
||||
v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_lse=return_softmax_lse,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
# Transpose the LSE if Triton MHA is used:
|
||||
|
||||
@@ -202,6 +202,7 @@ from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl,
|
||||
)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
@@ -251,13 +252,15 @@ class QueryLenSupport(Enum):
|
||||
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
|
||||
flash_attn_varlen_func,
|
||||
)
|
||||
|
||||
is_vllm_fa = True
|
||||
except ImportError:
|
||||
# For rocm use upstream flash attention
|
||||
if current_platform.is_rocm():
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
|
||||
is_vllm_fa = False
|
||||
|
||||
try:
|
||||
@@ -386,7 +389,7 @@ D = TypeVar("D", bound=MLACommonDecodeMetadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonMetadata(Generic[D]):
|
||||
class MLACommonMetadata(AttentionMetadata, Generic[D]):
|
||||
"""Metadata for MLACommon.
|
||||
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
@@ -434,7 +437,7 @@ class MLACommonMetadata(Generic[D]):
|
||||
|
||||
|
||||
M = TypeVar("M", bound=MLACommonMetadata)
|
||||
A = TypeVar("A")
|
||||
A = TypeVar("A", bound=AttentionMetadata)
|
||||
|
||||
|
||||
def use_flashinfer_prefill() -> bool:
|
||||
@@ -617,7 +620,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []
|
||||
|
||||
self._global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
|
||||
get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) # type: ignore[type-abstract]
|
||||
)
|
||||
|
||||
if self._use_trtllm_ragged_prefill:
|
||||
@@ -874,7 +877,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
)
|
||||
# Note(qcs): The max local context lengths
|
||||
# padded to `dcp_local_block_size`.
|
||||
padded_local_context_lens_cpu = (
|
||||
padded_local_context_lens_cpu: torch.Tensor = (
|
||||
cdiv(
|
||||
context_lens_cpu,
|
||||
self.dcp_virtual_block_size,
|
||||
@@ -1171,7 +1174,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
)
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
if layer.quant_method is not None and not isinstance(
|
||||
layer.quant_method, UnquantizedLinearMethod
|
||||
):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(
|
||||
layer.input_size_per_partition,
|
||||
@@ -1327,12 +1332,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
# v with 0s to match the qk head dim for attention backends that do
|
||||
# not support different headdims
|
||||
# We don't need to pad V if we are on a hopper system with FA3
|
||||
device_capability = current_platform.get_device_capability()
|
||||
self._pad_v = self.vllm_flash_attn_version is None or not (
|
||||
self.vllm_flash_attn_version == 3
|
||||
and current_platform.get_device_capability()[0] == 9
|
||||
and device_capability is not None
|
||||
and device_capability[0] == 9
|
||||
)
|
||||
|
||||
self.dcp_world_size: int | None = None
|
||||
self.dcp_world_size: int = -1
|
||||
|
||||
self.chunked_prefill_workspace_size = (
|
||||
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
|
||||
@@ -1583,7 +1590,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
)
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
if layer.quant_method is not None and not isinstance(
|
||||
layer.quant_method, UnquantizedLinearMethod
|
||||
):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(
|
||||
layer.input_size_per_partition,
|
||||
@@ -1875,7 +1884,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
) -> None:
|
||||
# TODO (zyongye): Prefill function here
|
||||
assert attn_metadata.prefill is not None
|
||||
assert self.dcp_world_size is not None
|
||||
assert self.dcp_world_size != -1
|
||||
|
||||
has_context = attn_metadata.prefill.chunked_context is not None
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
@@ -1975,7 +1984,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
if self.dcp_world_size is None:
|
||||
if self.dcp_world_size == -1:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
|
||||
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
||||
|
||||
@@ -33,7 +33,10 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
|
||||
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
|
||||
flash_attn_varlen_func,
|
||||
get_scheduler_metadata,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -181,7 +184,11 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
|
||||
# For Flash Attention MLA + full cudagraph
|
||||
max_num_splits = 0
|
||||
if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
|
||||
if (
|
||||
self.use_full_cuda_graph
|
||||
and self.max_cudagraph_size is not None
|
||||
and num_decode_tokens <= self.max_cudagraph_size
|
||||
):
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
|
||||
@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
@@ -124,7 +125,7 @@ class FlashMLASparseBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadata:
|
||||
class FlashMLASparseMetadata(AttentionMetadata):
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
@@ -718,7 +719,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
)
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer = indexer.topk_indices_buffer
|
||||
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
|
||||
self.padding = 128 if current_platform.is_device_capability_family(100) else 64
|
||||
|
||||
if kv_cache_dtype == "fp8_ds_mla":
|
||||
@@ -980,6 +981,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
q = q[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
@@ -236,7 +236,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_lse=return_softmax_lse,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -251,6 +251,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
assert attn_metadata.decode.max_qo_len is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
@@ -43,7 +43,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
|
||||
return "ROCM_AITER_MLA_SPARSE"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type[AttentionMetadata]:
|
||||
def get_metadata_cls() -> type["ROCMAiterMLASparseMetadata"]:
|
||||
return ROCMAiterMLASparseMetadata
|
||||
|
||||
@staticmethod
|
||||
@@ -74,7 +74,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ROCMAiterMLASparseMetadata:
|
||||
class ROCMAiterMLASparseMetadata(AttentionMetadata):
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
@@ -223,7 +223,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
|
||||
)
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer = indexer.topk_indices_buffer
|
||||
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
|
||||
self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
|
||||
|
||||
def _forward_bf16_kv(
|
||||
@@ -294,6 +294,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
ql_nope = ql_nope.transpose(0, 1)
|
||||
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
topk_indices_global = triton_convert_req_index_to_global_index(
|
||||
|
||||
@@ -155,7 +155,9 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
|
||||
spec_config = vllm_config.speculative_config
|
||||
spec_token_tree = (spec := spec_config) and spec.speculative_token_tree
|
||||
spec_token_tree: str | None = None
|
||||
if spec := spec_config:
|
||||
spec_token_tree = spec.speculative_token_tree
|
||||
tree_choices: list[tuple[int, ...]] = (
|
||||
ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
|
||||
)
|
||||
|
||||
@@ -469,6 +469,7 @@ def get_kv_cache_layout():
|
||||
# Format specified by the code.
|
||||
global _KV_CACHE_LAYOUT_OVERRIDE
|
||||
|
||||
cache_layout: Literal["NHD", "HND"] | None = None
|
||||
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
|
||||
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
|
||||
logger.info_once(
|
||||
@@ -524,7 +525,11 @@ def get_per_layer_parameters(
|
||||
to use during `plan`.
|
||||
"""
|
||||
|
||||
layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names)
|
||||
layers = get_layers_from_vllm_config(
|
||||
vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
layer_names,
|
||||
)
|
||||
per_layer_params: dict[str, PerLayerParameters] = {}
|
||||
|
||||
for key, layer in layers.items():
|
||||
@@ -1125,7 +1130,7 @@ class KVSharingFastPrefillMetadata(Protocol):
|
||||
|
||||
def create_fast_prefill_custom_backend(
|
||||
prefix: str,
|
||||
underlying_attn_backend: AttentionBackend,
|
||||
underlying_attn_backend: type[AttentionBackend],
|
||||
) -> type[AttentionBackend]:
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user