Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
"""Backend for GatedDeltaNet attention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -36,29 +35,27 @@ class GDNAttentionMetadata:
|
||||
num_spec_decode_tokens: int
|
||||
num_actual_tokens: int
|
||||
|
||||
has_initial_state: Optional[torch.Tensor] = None
|
||||
has_initial_state: torch.Tensor | None = None
|
||||
|
||||
spec_query_start_loc: Optional[torch.Tensor] = (
|
||||
None # shape: [num_spec_decodes + 1,]
|
||||
)
|
||||
non_spec_query_start_loc: Optional[torch.Tensor] = (
|
||||
spec_query_start_loc: torch.Tensor | None = None # shape: [num_spec_decodes + 1,]
|
||||
non_spec_query_start_loc: torch.Tensor | None = (
|
||||
None # shape: [batch - num_spec_decodes + 1,]
|
||||
)
|
||||
|
||||
spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec]
|
||||
non_spec_state_indices_tensor: Optional[torch.Tensor] = (
|
||||
spec_state_indices_tensor: torch.Tensor | None = None # shape: [batch, num_spec]
|
||||
non_spec_state_indices_tensor: torch.Tensor | None = (
|
||||
None # shape: [batch - num_spec_decodes,]
|
||||
)
|
||||
spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,]
|
||||
spec_token_masks: Optional[torch.Tensor] = (
|
||||
spec_sequence_masks: torch.Tensor | None = None # shape: [batch,]
|
||||
spec_token_masks: torch.Tensor | None = (
|
||||
None # shape: [num_prefill_tokens + num_decode_tokens,]
|
||||
)
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,]
|
||||
num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: Optional[dict] = None
|
||||
batch_ptr: Optional[torch.Tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.Tensor] = None
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
|
||||
@@ -133,8 +130,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
num_decode_draft_tokens_cpu: torch.Tensor | None = None,
|
||||
fast_build: bool = False,
|
||||
) -> GDNAttentionMetadata:
|
||||
m = common_attn_metadata
|
||||
|
||||
Reference in New Issue
Block a user