[Bugfix] Fix inappropriate content of model_name tag in Prometheus metrics (#3937)

Author: DearPlanet
Date: 2024-05-05 06:39:34 +08:00
Committer: GitHub
parent 021b1a2ab7
commit 4302987069
5 changed files with 76 additions and 14 deletions

--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py

@@ -1,7 +1,7 @@
 import argparse
 import dataclasses
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional, Union
 
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -21,6 +21,7 @@ def nullable_str(val: str):
 class EngineArgs:
     """Arguments for vLLM engine."""
     model: str
+    served_model_name: Optional[Union[str, List[str]]] = None
     tokenizer: Optional[str] = None
     skip_tokenizer_init: bool = False
     tokenizer_mode: str = 'auto'
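
As a quick illustration of the new field, here is a hedged sketch of constructing EngineArgs programmatically (only the arguments relevant to this change are set; the model ID is an arbitrary example):

from vllm.engine.arg_utils import EngineArgs

# served_model_name accepts a single alias or a list of aliases;
# per the help text in the next hunk, the first list entry becomes
# the name reported in API responses and Prometheus metrics.
args = EngineArgs(model="facebook/opt-125m",
                  served_model_name=["opt-125m", "opt-alias"])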
@@ -489,6 +490,21 @@ class EngineArgs:
             'This should be a JSON string that will be '
             'parsed into a dictionary.')
+        parser.add_argument(
+            "--served-model-name",
+            nargs="+",
+            type=str,
+            default=None,
+            help="The model name(s) used in the API. If multiple "
+            "names are provided, the server will respond to any "
+            "of the provided names. The model name in the model "
+            "field of a response will be the first name in this "
+            "list. If not specified, the model name will be the "
+            "same as the `--model` argument. Note that this "
+            "name(s) will also be used in the `model_name` tag "
+            "content of Prometheus metrics; if multiple names are "
+            "provided, the metrics tag will take the first one.")
+
         return parser
 
     @classmethod
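
The multi-name behavior described in the help text above reduces to "first name wins". A minimal sketch of that resolution step, assuming a small helper for it (the name get_served_model_name and its placement are assumptions, not shown in this excerpt of the diff):

from typing import List, Optional, Union

def get_served_model_name(
        model: str,
        served_model_name: Optional[Union[str, List[str]]]) -> str:
    # No --served-model-name given: fall back to the --model value.
    if not served_model_name:
        return model
    # A list of names: the first entry is used in API responses and
    # as the `model_name` tag of Prometheus metrics.
    if isinstance(served_model_name, list):
        return served_model_name[0]
    return served_model_name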
@@ -508,7 +524,7 @@
             self.quantization, self.quantization_param_path,
             self.enforce_eager, self.max_context_len_to_capture,
             self.max_seq_len_to_capture, self.max_logprobs,
-            self.skip_tokenizer_init)
+            self.skip_tokenizer_init, self.served_model_name)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
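
With served_model_name threaded through to ModelConfig (hunk above), the metrics code can label its series with the served name instead of the local model path, which is the inappropriate content this commit fixes. A hedged sketch of that labeling with prometheus_client (the gauge and the hard-coded name below are illustrative assumptions, not code from this diff):

from prometheus_client import Gauge

# Illustrative stand-in: in vLLM the value would come from the
# resolved ModelConfig.served_model_name.
served_model_name = "opt-125m"

# The `model_name` label now carries the user-facing served name
# rather than the on-disk model path.
num_requests_running = Gauge(
    "vllm:num_requests_running",
    "Number of requests currently running on GPU.",
    labelnames=["model_name"])
num_requests_running.labels(model_name=served_model_name).set(0)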