Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -50,7 +49,7 @@ def ref_masked_attention(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||
if attn_mask is not None:
|
||||
@@ -69,7 +68,7 @@ def ref_single_query_cached_kv_attention(
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
alibi_slopes: torch.Tensor | None,
|
||||
) -> None:
|
||||
num_query_heads = query.shape[1]
|
||||
num_kv_heads = value_cache.shape[1]
|
||||
@@ -415,7 +414,7 @@ def ref_multi_query_kv_attention(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
alibi_bias: Optional[list[torch.Tensor]],
|
||||
alibi_bias: list[torch.Tensor] | None,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(cu_seq_lens) - 1
|
||||
|
||||
Reference in New Issue
Block a user