Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -3,7 +3,7 @@
"""High-Performance Triton-only Attention layer."""
from dataclasses import dataclass
from typing import ClassVar, Optional, Union
from typing import ClassVar
import torch
@@ -61,13 +61,13 @@ class TritonAttentionMetadata:
# For cascade attention.
use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: Optional[torch.Tensor]
prefix_kv_lens: Optional[torch.Tensor]
suffix_kv_lens: Optional[torch.Tensor]
cu_prefix_query_lens: torch.Tensor | None
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
# Optional aot scheduling
scheduler_metadata: Optional[torch.Tensor] = None
prefix_scheduler_metadata: Optional[torch.Tensor] = None
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
@@ -159,7 +159,7 @@ class TritonAttentionBackend(AttentionBackend):
return [torch.float16, torch.bfloat16, torch.float32]
@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
@@ -216,13 +216,13 @@ class TritonAttentionImpl(AttentionImpl):
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -272,9 +272,9 @@ class TritonAttentionImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: TritonAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with Paged Attention impl. in Triton.