[TPU] Fix import error in tpu launch (#28758)
Signed-off-by: Qiliang Cui <derrhein@gmail.com>
This commit is contained in:
@@ -9,20 +9,25 @@ from tpu_info import device
|
|||||||
|
|
||||||
from vllm.inputs import ProcessorInputs, PromptType
|
from vllm.inputs import ProcessorInputs, PromptType
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
|
||||||
|
|
||||||
from .interface import Platform, PlatformEnum
|
from .interface import Platform, PlatformEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from typing import TypeAlias
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.cache import BlockSize
|
from vllm.config.cache import BlockSize
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
ParamsType: TypeAlias = SamplingParams | PoolingParams
|
||||||
else:
|
else:
|
||||||
BlockSize = None
|
BlockSize = None
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
PoolingParams = None
|
PoolingParams = None
|
||||||
AttentionBackendEnum = None
|
AttentionBackendEnum = None
|
||||||
|
ParamsType = None
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -203,10 +208,12 @@ class TpuPlatform(Platform):
|
|||||||
def validate_request(
|
def validate_request(
|
||||||
cls,
|
cls,
|
||||||
prompt: PromptType,
|
prompt: PromptType,
|
||||||
params: SamplingParams | PoolingParams,
|
params: ParamsType,
|
||||||
processed_inputs: ProcessorInputs,
|
processed_inputs: ProcessorInputs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Raises if this request is unsupported on this platform"""
|
"""Raises if this request is unsupported on this platform"""
|
||||||
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(params, SamplingParams)
|
isinstance(params, SamplingParams)
|
||||||
and params.sampling_type == SamplingType.RANDOM_SEED
|
and params.sampling_type == SamplingType.RANDOM_SEED
|
||||||
|
|||||||
Reference in New Issue
Block a user