Update Optional[x] -> x | None and Union[x, y] -> x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -2,8 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A layer that compute logits from hidden_stats."""
 
-from typing import Optional
-
 import torch
 
 from vllm.distributed import (
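Side note on the hunk above: with `from typing import Optional` removed, the annotations rely on PEP 604 union syntax. A minimal sketch (illustration only, not code from this commit) showing that the two spellings describe the same type; it assumes Python 3.10+ (or `from __future__ import annotations` on older interpreters):

from typing import Optional, Union

# PEP 604: the "X | Y" operator replaces typing.Union / typing.Optional
# in annotations. Evaluating it at runtime requires Python 3.10 or newer
# (or `from __future__ import annotations` to defer evaluation).
def old_style(soft_cap: Optional[float] = None) -> Union[float, None]:
    return soft_cap if soft_cap is not None else 0.0

def new_style(soft_cap: float | None = None) -> float | None:
    return soft_cap if soft_cap is not None else 0.0

# Both spellings compare equal as typing objects on Python 3.10+.
assert Optional[float] == (float | None)
assert Union[int, float] == (int | float)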
@@ -28,10 +26,10 @@ class LogitsProcessor(CustomOp):
     def __init__(
         self,
         vocab_size: int,
-        org_vocab_size: Optional[int] = None,
+        org_vocab_size: int | None = None,
         scale: float = 1.0,
         logits_as_input: bool = False,
-        soft_cap: Optional[float] = None,
+        soft_cap: float | None = None,
     ) -> None:
         """
         Args:
@@ -53,8 +51,8 @@ class LogitsProcessor(CustomOp):
         self,
         lm_head: VocabParallelEmbedding,
         hidden_states: torch.Tensor,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Optional[torch.Tensor]:
+        embedding_bias: torch.Tensor | None = None,
+    ) -> torch.Tensor | None:
         if self.logits_as_input:
             logits = hidden_states
         else:
@@ -88,8 +86,8 @@ class LogitsProcessor(CustomOp):
         self,
         hidden_states: torch.Tensor,
         lm_head: VocabParallelEmbedding,
-        embedding_bias: Optional[torch.Tensor],
-    ) -> Optional[torch.Tensor]:
+        embedding_bias: torch.Tensor | None,
+    ) -> torch.Tensor | None:
         # Get the logits for the next tokens.
         logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)
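For context on the last hunk: the `embedding_bias: torch.Tensor | None` parameter feeds the lm_head projection that produces the logits. A rough sketch of the unquantized path using the new annotation style (the function name and the soft-cap handling shown here are illustrative assumptions, not the commit's code):

import torch
import torch.nn.functional as F

def get_logits_sketch(
    lm_head_weight: torch.Tensor,         # [vocab_size, hidden_size]
    hidden_states: torch.Tensor,          # [num_tokens, hidden_size]
    embedding_bias: torch.Tensor | None,  # optional [vocab_size] bias
    soft_cap: float | None = None,
) -> torch.Tensor:
    # Roughly what lm_head.quant_method.apply(...) reduces to for an
    # unquantized embedding: a linear projection onto the vocabulary,
    # with the optional bias from the signature above.
    logits = F.linear(hidden_states, lm_head_weight, embedding_bias)
    # Optional logit soft-capping, corresponding to the soft_cap argument
    # of LogitsProcessor.__init__ (applied as soft_cap * tanh(x / soft_cap)).
    if soft_cap is not None:
        logits = soft_cap * torch.tanh(logits / soft_cap)
    return logits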