Update Optional[x] -> x | None and Union[x, y] -> x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -3,7 +3,6 @@
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import Optional
import torch
@@ -36,29 +35,27 @@ class GDNAttentionMetadata:
num_spec_decode_tokens: int
num_actual_tokens: int
has_initial_state: Optional[torch.Tensor] = None
has_initial_state: torch.Tensor | None = None
spec_query_start_loc: Optional[torch.Tensor] = (
None # shape: [num_spec_decodes + 1,]
)
non_spec_query_start_loc: Optional[torch.Tensor] = (
spec_query_start_loc: torch.Tensor | None = None # shape: [num_spec_decodes + 1,]
non_spec_query_start_loc: torch.Tensor | None = (
None # shape: [batch - num_spec_decodes + 1,]
)
spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec]
non_spec_state_indices_tensor: Optional[torch.Tensor] = (
spec_state_indices_tensor: torch.Tensor | None = None # shape: [batch, num_spec]
non_spec_state_indices_tensor: torch.Tensor | None = (
None # shape: [batch - num_spec_decodes,]
)
spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,]
spec_token_masks: Optional[torch.Tensor] = (
spec_sequence_masks: torch.Tensor | None = None # shape: [batch,]
spec_token_masks: torch.Tensor | None = (
None # shape: [num_prefill_tokens + num_decode_tokens,]
)
num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,]
num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
@@ -133,8 +130,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
num_accepted_tokens: Optional[torch.Tensor] = None,
num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
num_accepted_tokens: torch.Tensor | None = None,
num_decode_draft_tokens_cpu: torch.Tensor | None = None,
fast_build: bool = False,
) -> GDNAttentionMetadata:
m = common_attn_metadata