[V1] Make v1 more testable (#9888)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde
2024-11-06 12:57:35 -07:00
committed by GitHub
parent 87bd7e0515
commit d58268c56a
75 changed files with 243 additions and 165 deletions

View File

@@ -1,7 +1,7 @@
import itertools
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)
from tqdm import tqdm
@@ -10,6 +10,7 @@ from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
@@ -31,11 +32,6 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore
else:
from vllm.engine.llm_engine import LLMEngine # type: ignore
logger = init_logger(__name__)
@@ -206,10 +202,21 @@ class LLM:
pooling_returned_token_ids=pooling_returned_token_ids,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self.engine_class = self.get_engine_class()
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
@staticmethod
def get_engine_class() -> Type[LLMEngine]:
    """Pick the engine class to instantiate, decided at call time.

    Returns:
        The V1 ``LLMEngine`` when ``envs.VLLM_USE_V1`` is truthy, otherwise
        the ``LLMEngine`` already imported at module level.
    """
    if not envs.VLLM_USE_V1:
        return LLMEngine

    # Imported lazily, only on the V1 path: the v1 package isn't distributed,
    # so it must not be required just to import this module.
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    return V1LLMEngine  # type: ignore
def get_tokenizer(self) -> AnyTokenizer:
    """Return the ``tokenizer`` attribute of the engine's tokenizer group.

    The group is looked up on ``self.llm_engine`` via
    ``get_tokenizer_group(TokenizerGroup)``.
    """
    tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
    return tokenizer_group.tokenizer
@@ -394,7 +401,7 @@ class LLM:
priority=priority)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
return self.engine_class.validate_outputs(outputs, RequestOutput)
def beam_search(
self,
@@ -769,7 +776,8 @@ class LLM:
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)
def start_profile(self) -> None:
self.llm_engine.start_profile()