[V1][Misc] Shorten FinishReason enum and use constant strings (#12760)
@@ -14,11 +14,17 @@ if TYPE_CHECKING:
     from vllm.multimodal.inputs import PlaceholderRange
     from vllm.sampling_params import SamplingParams
 
 
-class RequestFinishedReason(enum.IntEnum):
+# These are possible values of RequestOutput.finish_reason,
+# so form part of the external API.
+FINISH_REASON_STRINGS = ("stop", "length", "abort")
+
+
+class FinishReason(enum.IntEnum):
     """
     Reason a request finished - stop, length, or abort.
 
+    Int rather than Str for more compact serialization.
+
     stop - a stop string was emitted
     length - max_tokens was consumed, or max_model_len was reached
     abort - aborted for another reason
@@ -29,7 +35,7 @@ class RequestFinishedReason(enum.IntEnum):
     ABORT = 2
 
     def __str__(self):
-        return self.name.lower()
+        return FINISH_REASON_STRINGS[self.value]
 
 
 @dataclass
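As a quick illustration of the rename and of the docstring's "Int rather than Str" note, here is a minimal, self-contained sketch. The definitions below mirror vllm.v1.engine but are repeated locally so the snippet runs without vLLM installed.

import enum
import json

# Local stand-ins mirroring vllm.v1.engine, so this example runs on its own.
FINISH_REASON_STRINGS = ("stop", "length", "abort")


class FinishReason(enum.IntEnum):
    STOP = 0
    LENGTH = 1
    ABORT = 2

    def __str__(self):
        # Index into the constant tuple rather than lowercasing self.name,
        # so the externally visible strings are fixed in one place.
        return FINISH_REASON_STRINGS[self.value]


reason = FinishReason.STOP
print(str(reason))   # "stop"  -> what RequestOutput.finish_reason exposes
print(int(reason))   # 0       -> the compact value used when serializing
print(len(json.dumps(int(reason))), len(json.dumps(str(reason))))  # 1 vs 6 characters in JSON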
@@ -62,7 +68,7 @@ class EngineCoreOutput(
     request_id: str
     new_token_ids: List[int]
     finished: bool
-    finish_reason: Optional[RequestFinishedReason] = None
+    finish_reason: Optional[FinishReason] = None
     stop_reason: Union[int, str, None] = None
 
 
@@ -8,8 +8,7 @@ from vllm.logger import init_logger
 from vllm.sampling_params import RequestOutputKind
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
-from vllm.v1.engine import (EngineCoreOutput, EngineCoreRequest,
-                            RequestFinishedReason)
+from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 
 logger = init_logger(__name__)
 
@@ -19,7 +18,7 @@ class DetokenizerOutput:
     output_text: str
     token_ids: List[int]
     finished: bool
-    finish_reason: Optional[RequestFinishedReason] = None
+    finish_reason: Optional[FinishReason] = None
     stop_reason: Union[int, str, None] = None
 
 
@@ -148,7 +147,7 @@ class IncrementalDetokenizer:
             stop_str, truncate_to = stop
             if truncate_to != -1:
                 self.output_text = self.output_text[:truncate_to]
-            finish_reason = RequestFinishedReason.STOP
+            finish_reason = FinishReason.STOP
             stop_reason = stop_str
 
         # TODO: handle stop_token_ids here too?
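For context, the block above truncates the accumulated text when a stop string is detected and records FinishReason.STOP with the matched stop string as stop_reason. A hypothetical standalone helper (not vLLM's actual stop-checking code) illustrating the same idea:

from typing import Optional, Tuple


def truncate_at_stop_string(output_text: str,
                            stop_strings: Tuple[str, ...]) -> Optional[Tuple[str, str]]:
    # Returns (truncated_text, matched_stop_string) if a stop string occurs,
    # otherwise None; mirrors the "truncate_to" handling in the hunk above.
    for stop_str in stop_strings:
        truncate_to = output_text.find(stop_str)
        if truncate_to != -1:
            return output_text[:truncate_to], stop_str
    return None


hit = truncate_at_stop_string("Hello world<|end|> trailing tokens", ("<|end|>",))
if hit is not None:
    output_text, stop_reason = hit
    # In the detokenizer this is where finish_reason = FinishReason.STOP is set.
    print(output_text, "| stop_reason:", stop_reason)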
@@ -9,7 +9,7 @@ import prometheus_client
 
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.v1.engine import RequestFinishedReason
+from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
 
 logger = init_logger(__name__)
@@ -117,13 +117,13 @@ class PrometheusStatLogger(StatLoggerBase):
             documentation="Number of generation tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_request_success: Dict[RequestFinishedReason,
+        self.counter_request_success: Dict[FinishReason,
                                            prometheus_client.Counter] = {}
         counter_request_success_base = prometheus_client.Counter(
             name="vllm:request_success_total",
             documentation="Count of successfully processed requests.",
             labelnames=labelnames + ["finished_reason"])
-        for reason in RequestFinishedReason:
+        for reason in FinishReason:
             self.counter_request_success[
                 reason] = counter_request_success_base.labels(*(labelvalues +
                                                                 [str(reason)]))
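The change above keeps the existing pattern of binding one Prometheus counter child per finish reason up front, so the request-completion path only does a dictionary lookup and increment. A rough sketch of that pattern, using a local stand-in for FinishReason and an illustrative metric name rather than vLLM's StatLogger classes:

import enum
from typing import Dict

import prometheus_client


class FinishReason(enum.IntEnum):
    # Stand-in for vllm.v1.engine.FinishReason so the example is self-contained.
    STOP = 0
    LENGTH = 1
    ABORT = 2

    def __str__(self):
        return ("stop", "length", "abort")[self.value]


# Base counter declared once, with "finished_reason" as a label.
counter_request_success_base = prometheus_client.Counter(
    name="demo_request_success_total",
    documentation="Count of successfully processed requests.",
    labelnames=["model_name", "finished_reason"])

# Pre-bind one child per reason; str(reason) yields the stable label value.
counter_request_success: Dict[FinishReason, prometheus_client.Counter] = {
    reason: counter_request_success_base.labels("demo-model", str(reason))
    for reason in FinishReason
}

# At request completion, a single dict lookup and increment.
counter_request_success[FinishReason.STOP].inc()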
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List
 
 if TYPE_CHECKING:
     from vllm.outputs import RequestOutput
-    from vllm.v1.engine import EngineCoreOutput, RequestFinishedReason
+    from vllm.v1.engine import EngineCoreOutput, FinishReason
 
 
 @dataclass
@@ -32,7 +32,7 @@ class RequestStateStats:
 class FinishedRequestStats:
     """Stats associated with a finished request."""
 
-    finish_reason: "RequestFinishedReason"
+    finish_reason: "FinishReason"
     num_prompt_tokens: int = 0
     num_generation_tokens: int = 0
 
@@ -74,8 +74,7 @@ class IterationStats:
         request_state_stats.num_generation_tokens += num_new_generation_tokens
         request_state_stats.last_token_time = now
 
-    def update_from_finished_request(self,
-                                     finish_reason: "RequestFinishedReason",
+    def update_from_finished_request(self, finish_reason: "FinishReason",
                                      request_output: "RequestOutput",
                                      request_state_stats: RequestStateStats):
         self.finished_requests.append(
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional, Union
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import RequestMetrics
-from vllm.v1.engine import EngineCoreRequest, RequestFinishedReason
+from vllm.v1.engine import EngineCoreRequest, FinishReason
 from vllm.v1.utils import ConstantList
 
 if TYPE_CHECKING:
@@ -109,7 +109,7 @@ class Request:
     def is_finished(self) -> bool:
         return RequestStatus.is_finished(self.status)
 
-    def get_finished_reason(self) -> Union[RequestFinishedReason, None]:
+    def get_finished_reason(self) -> Union[FinishReason, None]:
         return RequestStatus.get_finished_reason(self.status)
 
     def has_encoder_inputs(self) -> bool:
@@ -150,7 +150,7 @@ class RequestStatus(enum.IntEnum):
 
     @staticmethod
     def get_finished_reason(
-            status: "RequestStatus") -> Union[RequestFinishedReason, None]:
+            status: "RequestStatus") -> Union[FinishReason, None]:
         return _FINISHED_REASON_MAP.get(status)
 
 
@@ -159,8 +159,8 @@ class RequestStatus(enum.IntEnum):
 # are longer than the model's length cap. Therefore, the stop
 # reason should also be "length" as in OpenAI API.
 _FINISHED_REASON_MAP = {
-    RequestStatus.FINISHED_STOPPED: RequestFinishedReason.STOP,
-    RequestStatus.FINISHED_LENGTH_CAPPED: RequestFinishedReason.LENGTH,
-    RequestStatus.FINISHED_ABORTED: RequestFinishedReason.ABORT,
-    RequestStatus.FINISHED_IGNORED: RequestFinishedReason.LENGTH,
+    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
+    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
+    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
+    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
 }
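To make the mapping's behaviour concrete, here is a small self-contained sketch using stand-in enums (the member values below are illustrative, not vLLM's actual ones): finished statuses map to a FinishReason, FINISHED_IGNORED reports "length" for OpenAI-API compatibility, and unfinished statuses yield None.

import enum
from typing import Dict, Optional


class FinishReason(enum.IntEnum):
    STOP = 0
    LENGTH = 1
    ABORT = 2


class RequestStatus(enum.IntEnum):
    # Illustrative stand-in; the real RequestStatus has more members.
    RUNNING = 0
    FINISHED_STOPPED = 1
    FINISHED_LENGTH_CAPPED = 2
    FINISHED_ABORTED = 3
    FINISHED_IGNORED = 4


_FINISHED_REASON_MAP: Dict[RequestStatus, FinishReason] = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    # Ignored requests exceeded the model's length cap, so report "length".
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
}


def get_finished_reason(status: RequestStatus) -> Optional[FinishReason]:
    return _FINISHED_REASON_MAP.get(status)


print(get_finished_reason(RequestStatus.FINISHED_IGNORED))  # FinishReason.LENGTH
print(get_finished_reason(RequestStatus.RUNNING))           # None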