Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
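The two spellings are interchangeable at runtime on Python 3.10+ (PEP 604); a minimal sketch, separate from the commit itself, checking the equivalence:

    from typing import Optional, Union

    # PEP 604: the | operator on types builds the same union object
    # that typing.Optional / typing.Union spell out.
    assert (int | None) == Optional[int]
    assert (int | str) == Union[int, str]

Once every use is migrated, the typing imports become dead code, which is why the first hunk below drops them.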
@@ -3,7 +3,6 @@
 """Attention layer with AiterFlashAttention."""
 
 from dataclasses import dataclass
-from typing import Optional, Union
 
 import torch
 
@@ -160,8 +159,8 @@ if current_platform.is_rocm():
         max_seqlen_q: int,
         max_seqlen_k: int,
         softmax_scale: float,
-        window_size: Optional[list[int]],  # -1 means infinite context window
-        alibi_slopes: Optional[list[float]],
+        window_size: list[int] | None,  # -1 means infinite context window
+        alibi_slopes: list[float] | None,
         block_table: torch.Tensor,
         k_scale: torch.Tensor,
         v_scale: torch.Tensor,
@@ -209,8 +208,8 @@ if current_platform.is_rocm():
         max_seqlen_q: int,
         max_seqlen_k: int,
         softmax_scale: float,
-        window_size: Optional[list[int]],  # -1 means infinite context window
-        alibi_slopes: Optional[list[float]],
+        window_size: list[int] | None,  # -1 means infinite context window
+        alibi_slopes: list[float] | None,
         block_table: torch.Tensor,
         k_scale: torch.Tensor,
         v_scale: torch.Tensor,
@@ -249,7 +248,7 @@ class AiterFlashAttentionMetadata:
     seq_lens: torch.Tensor
     slot_mapping: torch.Tensor
     block_table: torch.Tensor
-    cu_seq_lens: Optional[torch.Tensor]
+    cu_seq_lens: torch.Tensor | None
 
     # For cascade attention.
     use_cascade: bool
@@ -283,7 +282,7 @@ class AiterFlashAttentionMetadataBuilder(
         self.block_size = kv_cache_spec.block_size
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
-        self.aot_sliding_window: Optional[tuple[int, int]] = None
+        self.aot_sliding_window: tuple[int, int] | None = None
         self.total_tokens: int = 0
 
     def build_for_cudagraph_capture(
@@ -361,7 +360,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
         return [64, 128, 256]
 
     @staticmethod
-    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
         return [MultipleOf(16)]
 
     @classmethod
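One runtime caveat with the hunk above (an aside, not part of the commit): `int | MultipleOf` in an annotation is evaluated when the `def` statement runs, so it needs Python 3.10+ unless evaluation is deferred via PEP 563. A minimal sketch with a stand-in `MultipleOf` class:

    from __future__ import annotations  # PEP 563: annotations stay as strings,
                                        # so the | syntax also parses on 3.9

    class MultipleOf:  # stand-in for vLLM's MultipleOf, illustration only
        def __init__(self, base: int) -> None:
            self.base = base

    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        return [MultipleOf(16)]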
@@ -412,12 +411,12 @@ class AiterFlashAttentionImpl(AttentionImpl):
         head_size: int,
         scale: float,
         num_kv_heads: int,
-        alibi_slopes: Optional[list[float]],
-        sliding_window: Optional[int],
+        alibi_slopes: list[float] | None,
+        sliding_window: int | None,
         kv_cache_dtype: str,
-        logits_soft_cap: Optional[float] = None,
+        logits_soft_cap: float | None = None,
         attn_type: AttentionType = AttentionType.DECODER,
-        kv_sharing_target_layer_name: Optional[int] = None,
+        kv_sharing_target_layer_name: int | None = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -458,9 +457,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AiterFlashAttentionMetadata,
-        output: Optional[torch.Tensor] = None,
-        output_scale: Optional[torch.Tensor] = None,
-        output_block_scale: Optional[torch.Tensor] = None,
+        output: torch.Tensor | None = None,
+        output_scale: torch.Tensor | None = None,
+        output_block_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Forward pass with AiterFlashAttention.
 
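As the final hunk shows, only the annotation spelling changes; the `= None` defaults and all `is None` checks at call sites behave exactly as before. A minimal sketch with a hypothetical signature:

    import torch

    def forward_sketch(
        query: torch.Tensor,
        output: torch.Tensor | None = None,  # same default as the Optional form
    ) -> torch.Tensor:
        # `torch.Tensor | None` and `Optional[torch.Tensor]` admit the same
        # values, so the None check is unchanged.
        return query if output is None else output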