[platforms] enable platform plugins (#11602)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-12-30 20:24:45 +08:00
committed by GitHub
parent 5dbf854553
commit b12e87f942
23 changed files with 354 additions and 175 deletions

View File

@@ -6,7 +6,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms import CpuArchEnum
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -39,6 +39,7 @@ def maybe_backend_fallback(
if guided_params.backend == "xgrammar":
# xgrammar only has x86 wheels for linux, fallback to outlines
from vllm.platforms import current_platform
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
logger.warning("xgrammar is only supported on x86 CPUs. "
"Falling back to use outlines instead.")

View File

@@ -18,7 +18,6 @@ import cloudpickle
import torch.nn as nn
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
supports_cross_encoding, supports_multimodal,
@@ -273,6 +272,7 @@ def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
from vllm.platforms import current_platform
current_platform.verify_model_arch(model_arch)
try:
return model.load_model_cls()

View File

@@ -3,10 +3,9 @@ from typing import Any, Dict, Optional
import torch
from vllm.platforms import current_platform
def set_random_seed(seed: int) -> None:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
@@ -38,6 +37,7 @@ def set_weight_attrs(
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
# TODO(woosuk): Remove this hack once we have a better solution.
from vllm.platforms import current_platform
if current_platform.is_tpu() and key == "weight_loader":
value = _make_synced_weight_loader(value)
setattr(weight, key, value)