Update Optional[x] to x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-12 17:51:31 +01:00
Committed by: GitHub
Parent: 9bb38130cb
Commit: 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions
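For context, the pattern applied throughout the commit is the PEP 604 union syntax available since Python 3.10: Optional[X] becomes X | None and Union[X, Y] becomes X | Y, removing the need to import either name from typing. A minimal before/after sketch of the same change (illustrative only; the function and parameter names below are not taken from the diff):

    # Old spellings: require imports from typing.
    from typing import Optional, Union

    def clamp(x: Union[int, float], limit: Optional[float] = None) -> float:
        return min(float(x), limit) if limit is not None else float(x)

    # New spellings: built-in `|` unions, no typing import needed.
    def clamp_new(x: int | float, limit: float | None = None) -> float:
        return min(float(x), limit) if limit is not None else float(x)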


@@ -3,7 +3,6 @@
 """Attention layer with AiterFlashAttention."""
 from dataclasses import dataclass
-from typing import Optional, Union
 import torch
@@ -160,8 +159,8 @@ if current_platform.is_rocm():
 max_seqlen_q: int,
 max_seqlen_k: int,
 softmax_scale: float,
-window_size: Optional[list[int]], # -1 means infinite context window
-alibi_slopes: Optional[list[float]],
+window_size: list[int] | None, # -1 means infinite context window
+alibi_slopes: list[float] | None,
 block_table: torch.Tensor,
 k_scale: torch.Tensor,
 v_scale: torch.Tensor,
@@ -209,8 +208,8 @@ if current_platform.is_rocm():
 max_seqlen_q: int,
 max_seqlen_k: int,
 softmax_scale: float,
-window_size: Optional[list[int]], # -1 means infinite context window
-alibi_slopes: Optional[list[float]],
+window_size: list[int] | None, # -1 means infinite context window
+alibi_slopes: list[float] | None,
 block_table: torch.Tensor,
 k_scale: torch.Tensor,
 v_scale: torch.Tensor,
@@ -249,7 +248,7 @@ class AiterFlashAttentionMetadata:
 seq_lens: torch.Tensor
 slot_mapping: torch.Tensor
 block_table: torch.Tensor
-cu_seq_lens: Optional[torch.Tensor]
+cu_seq_lens: torch.Tensor | None
 # For cascade attention.
 use_cascade: bool
@@ -283,7 +282,7 @@ class AiterFlashAttentionMetadataBuilder(
 self.block_size = kv_cache_spec.block_size
 # Sliding window size to be used with the AOT scheduler will be
 # populated on first build() call.
-self.aot_sliding_window: Optional[tuple[int, int]] = None
+self.aot_sliding_window: tuple[int, int] | None = None
 self.total_tokens: int = 0
 def build_for_cudagraph_capture(
@@ -361,7 +360,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
 return [64, 128, 256]
 @staticmethod
-def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+def get_supported_kernel_block_size() -> list[int | MultipleOf]:
 return [MultipleOf(16)]
 @classmethod
@@ -412,12 +411,12 @@ class AiterFlashAttentionImpl(AttentionImpl):
 head_size: int,
 scale: float,
 num_kv_heads: int,
-alibi_slopes: Optional[list[float]],
-sliding_window: Optional[int],
+alibi_slopes: list[float] | None,
+sliding_window: int | None,
 kv_cache_dtype: str,
-logits_soft_cap: Optional[float] = None,
+logits_soft_cap: float | None = None,
 attn_type: AttentionType = AttentionType.DECODER,
-kv_sharing_target_layer_name: Optional[int] = None,
+kv_sharing_target_layer_name: int | None = None,
 ) -> None:
 self.num_heads = num_heads
 self.head_size = head_size
@@ -458,9 +457,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
 value: torch.Tensor,
 kv_cache: torch.Tensor,
 attn_metadata: AiterFlashAttentionMetadata,
-output: Optional[torch.Tensor] = None,
-output_scale: Optional[torch.Tensor] = None,
-output_block_scale: Optional[torch.Tensor] = None,
+output: torch.Tensor | None = None,
+output_scale: torch.Tensor | None = None,
+output_block_scale: torch.Tensor | None = None,
 ) -> torch.Tensor:
 """Forward pass with AiterFlashAttention.