Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -24,7 +24,7 @@ def rejection_sampler():
|
||||
def create_logits_tensor(
|
||||
output_token_ids: list[list[int]],
|
||||
vocab_size: int = 100,
|
||||
token_idx_to_override: Optional[int] = None,
|
||||
token_idx_to_override: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Helper function to create logits tensor that
|
||||
will produce desired token ids on argmax"""
|
||||
@@ -43,18 +43,18 @@ def create_logits_tensor(
|
||||
|
||||
def create_sampling_metadata(
|
||||
all_greedy: bool,
|
||||
output_token_ids: Optional[list[list[int]]] = None,
|
||||
prompt_token_ids: Optional[torch.Tensor] = None,
|
||||
spec_token_ids: Optional[torch.Tensor] = None,
|
||||
temperature: Optional[torch.Tensor] = None,
|
||||
top_k: Optional[torch.Tensor] = None,
|
||||
top_p: Optional[torch.Tensor] = None,
|
||||
generators: Optional[dict[int, Any]] = None,
|
||||
frequency_penalties: Optional[list[float]] = None,
|
||||
presence_penalties: Optional[list[float]] = None,
|
||||
repetition_penalties: Optional[list[float]] = None,
|
||||
bad_words_token_ids: Optional[dict[int, list[list[int]]]] = None,
|
||||
allowed_token_ids_mask: Optional[torch.Tensor] = None,
|
||||
output_token_ids: list[list[int]] | None = None,
|
||||
prompt_token_ids: torch.Tensor | None = None,
|
||||
spec_token_ids: torch.Tensor | None = None,
|
||||
temperature: torch.Tensor | None = None,
|
||||
top_k: torch.Tensor | None = None,
|
||||
top_p: torch.Tensor | None = None,
|
||||
generators: dict[int, Any] | None = None,
|
||||
frequency_penalties: list[float] | None = None,
|
||||
presence_penalties: list[float] | None = None,
|
||||
repetition_penalties: list[float] | None = None,
|
||||
bad_words_token_ids: dict[int, list[list[int]]] | None = None,
|
||||
allowed_token_ids_mask: torch.Tensor | None = None,
|
||||
) -> SamplingMetadata:
|
||||
"""Create a v1 sampling metadata object with all_greedy set
|
||||
to the given value. Either all greedy or all random sampling
|
||||
|
||||
Reference in New Issue
Block a user