Update Optional[x] -> x | None and Union[x, y] -> x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
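The rewrite below is mechanical: PEP 604 (Python 3.10) lets the `|` operator spell unions, so `Optional[x]` becomes `x | None` and `Union[x, y]` becomes `x | y`; once every use is converted, the `Optional` and `Union` imports themselves become unused, which is what the first hunk removes. A minimal sketch of the equivalence, separate from the diff itself:

    from typing import Optional, Union

    # Optional[int] is shorthand for Union[int, None]; PEP 604 spells
    # both with |, and the resulting objects compare equal.
    assert Optional[int] == Union[int, None] == (int | None)
    assert Union[int, str] == (int | str)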
@@ -3,7 +3,7 @@
 """High-Performance Triton-only Attention layer."""
 
 from dataclasses import dataclass
-from typing import ClassVar, Optional, Union
+from typing import ClassVar
 
 import torch
 
@@ -61,13 +61,13 @@ class TritonAttentionMetadata:
     # For cascade attention.
     use_cascade: bool
     common_prefix_len: int
-    cu_prefix_query_lens: Optional[torch.Tensor]
-    prefix_kv_lens: Optional[torch.Tensor]
-    suffix_kv_lens: Optional[torch.Tensor]
+    cu_prefix_query_lens: torch.Tensor | None
+    prefix_kv_lens: torch.Tensor | None
+    suffix_kv_lens: torch.Tensor | None
 
     # Optional aot scheduling
-    scheduler_metadata: Optional[torch.Tensor] = None
-    prefix_scheduler_metadata: Optional[torch.Tensor] = None
+    scheduler_metadata: torch.Tensor | None = None
+    prefix_scheduler_metadata: torch.Tensor | None = None
 
 
 class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
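One caveat worth noting for fields like these: dataclass annotations are evaluated at class-definition time, and `type.__or__` only exists on Python 3.10+, so the `torch.Tensor | None` spelling assumes a 3.10+ runtime (or a `from __future__ import annotations` at the top of the module). A hedged sketch, with an illustrative class name not taken from the diff:

    from __future__ import annotations  # defers evaluation; only needed below 3.10

    from dataclasses import dataclass

    import torch

    @dataclass
    class ExampleMetadata:
        # Same shape as the fields above: optional tensors, some defaulted.
        cu_prefix_query_lens: torch.Tensor | None
        scheduler_metadata: torch.Tensor | None = None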
@@ -159,7 +159,7 @@ class TritonAttentionBackend(AttentionBackend):
         return [torch.float16, torch.bfloat16, torch.float32]
 
     @staticmethod
-    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
         return [MultipleOf(16)]
 
     @classmethod
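The `|` form also composes inside subscripts like `list[int | MultipleOf]` with no extra parentheses, and since Python 3.10 a `|` union is a real runtime object (`types.UnionType`) that `isinstance()` accepts directly. A small illustration using builtins only:

    # int | str is a types.UnionType instance, usable both as an
    # annotation and as an isinstance() target on Python 3.10+.
    block_size: int | str = 16
    assert isinstance(block_size, int | str)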
@@ -216,13 +216,13 @@ class TritonAttentionImpl(AttentionImpl):
         head_size: int,
         scale: float,
         num_kv_heads: int,
-        alibi_slopes: Optional[list[float]],
-        sliding_window: Optional[int],
+        alibi_slopes: list[float] | None,
+        sliding_window: int | None,
         kv_cache_dtype: str,
-        logits_soft_cap: Optional[float] = None,
+        logits_soft_cap: float | None = None,
         attn_type: AttentionType = AttentionType.DECODER,
-        kv_sharing_target_layer_name: Optional[int] = None,
-        sinks: Optional[torch.Tensor] = None,
+        kv_sharing_target_layer_name: int | None = None,
+        sinks: torch.Tensor | None = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
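Note that `Optional[...]` never meant "optional argument": with either spelling, a parameter is only omissible because of its `= None` default, which is why `alibi_slopes` and `sliding_window` stay required while the others keep their defaults. An illustrative signature (names shortened from the diff):

    def init_impl(
        alibi_slopes: list[float] | None,      # required, may be None
        sliding_window: int | None,            # required, may be None
        logits_soft_cap: float | None = None,  # omissible
    ) -> None: ...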
@@ -272,9 +272,9 @@ class TritonAttentionImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: TritonAttentionMetadata,
-        output: Optional[torch.Tensor] = None,
-        output_scale: Optional[torch.Tensor] = None,
-        output_block_scale: Optional[torch.Tensor] = None,
+        output: torch.Tensor | None = None,
+        output_scale: torch.Tensor | None = None,
+        output_block_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Forward pass with Paged Attention impl. in Triton.
 
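Conversions like this are usually automated rather than hand-edited; for example pyupgrade (assuming a 3.10+ target; the file path here is illustrative):

    pyupgrade --py310-plus vllm/v1/attention/backends/triton_attn.py

Ruff's pyupgrade-derived rules (UP007 and friends) flag the same patterns in CI.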