[Hardware][ROCM] using current_platform.is_rocm (#9642)
Signed-off-by: wangshuai09 <391746016@qq.com>
@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                              get_hf_image_processor_config,
                                              get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        is_hip, print_warning_once)
+                        print_warning_once)

 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
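Aside (not part of the diff): the pattern this commit moves to is to query the platform object from vllm.platforms instead of the removed is_hip() helper. A minimal sketch, where the helper name and return values are hypothetical and chosen only for illustration:

    # Illustrative sketch of the new platform-check idiom.
    from vllm.platforms import current_platform

    def pick_allreduce_backend() -> str:
        # Hypothetical helper: on ROCm the custom all-reduce kernel is
        # disabled (see the ParallelConfig hunk below), so fall back to
        # the RCCL-based all-reduce path.
        if current_platform.is_rocm():
            return "rccl"
        return "custom_all_reduce"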
@@ -43,7 +43,7 @@ class ModelConfig:
 
     Args:
         model: Name or path of the huggingface model to use.
             It is also used as the content for `model_name` tag in metrics
             output when `served_model_name` is not specified.
         task: The task to use the model for. Each vLLM instance only supports
             one task, even if the same model can be used for multiple tasks.
@@ -99,15 +99,15 @@ class ModelConfig:
         skip_tokenizer_init: If true, skip initialization of tokenizer and
             detokenizer.
         served_model_name: The model name used in metrics tag `model_name`,
             matches the model name exposed via the APIs. If multiple model
             names provided, the first name will be used. If not specified,
             the model name will be the same as `model`.
         limit_mm_per_prompt: Maximum number of data instances per modality
             per prompt. Only applicable for multimodal models.
         override_neuron_config: Initialize non default neuron config or
             override default neuron config that are specific to Neuron devices,
             this argument will be used to configure the neuron config that
             can not be gathered from the vllm arguments.
         config_format: The config format which shall be loaded.
             Defaults to 'auto' which defaults to 'hf'.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
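Aside: a usage sketch for two of the fields documented above, assuming (as in vLLM at this point) that extra keyword arguments passed to LLM are forwarded to EngineArgs; the model id and limits are placeholders:

    # Illustrative only; values are placeholders, not recommendations.
    from vllm import LLM

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",    # placeholder multimodal model
        served_model_name="my-llava",        # name reported in the `model_name` metrics tag
        limit_mm_per_prompt={"image": 2},    # at most two images per prompt
    )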
@@ -350,7 +350,7 @@ class ModelConfig:
                 raise ValueError(
                     f"Unknown quantization method: {self.quantization}. Must "
                     f"be one of {supported_quantization}.")
-            if is_hip(
+            if current_platform.is_rocm(
             ) and self.quantization not in rocm_supported_quantization:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
@@ -365,7 +365,7 @@ class ModelConfig:
                     "%s quantization is not fully "
                     "optimized yet. The speed can be slower than "
                     "non-quantized models.", self.quantization)
-            if (self.quantization == "awq" and is_hip()
+            if (self.quantization == "awq" and current_platform.is_rocm()
                     and not envs.VLLM_USE_TRITON_AWQ):
                 logger.warning(
                     "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
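Aside: a hedged restatement of what this guard does conceptually: when AWQ quantization is requested on ROCm and VLLM_USE_TRITON_AWQ is not set, vLLM warns and switches to the Triton AWQ path. The sketch below is illustrative, uses os.environ directly rather than vllm.envs, and its function name is made up:

    import logging
    import os

    logger = logging.getLogger("rocm-awq-sketch")

    def maybe_enable_triton_awq(quantization: str, is_rocm: bool) -> bool:
        """Illustrative re-statement of the guard in this hunk."""
        use_triton_awq = os.environ.get("VLLM_USE_TRITON_AWQ", "0") == "1"
        if quantization == "awq" and is_rocm and not use_triton_awq:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ "
                "is not set; enabling the Triton AWQ path.")
            use_triton_awq = True
        return use_triton_awq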
@@ -385,7 +385,7 @@ class ModelConfig:
 
     def _verify_bnb_config(self) -> None:
         """
         The current version of bitsandbytes (0.44.0) with 8-bit models does not
         yet support CUDA graph.
         """
         is_bitsandbytes = self.quantization == "bitsandbytes"
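Aside: the docstring above motivates a validation step: with bitsandbytes 0.44.0, 8-bit models cannot run under CUDA graph capture, so eager execution has to be forced. A hedged sketch of such a check (the helper name and its parameters are illustrative, not vLLM's actual code):

    import warnings

    def resolve_enforce_eager(quantization: str, load_in_8bit: bool,
                              enforce_eager: bool) -> bool:
        """Return the enforce_eager setting to use (illustrative only)."""
        if quantization == "bitsandbytes" and load_in_8bit and not enforce_eager:
            warnings.warn(
                "CUDA graph is not supported for 8-bit bitsandbytes models; "
                "falling back to eager mode.")
            return True
        return enforce_eager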
@@ -810,7 +810,7 @@ class LoadConfig:
                 fast weight loading.
             "bitsandbytes" will load nf4 type weights.
         ignore_patterns: The list of patterns to ignore when loading the model.
             Default to "original/**/*" to avoid repeated loading of llama's
             checkpoints.

     """
@@ -843,7 +843,8 @@ class LoadConfig:
         self.load_format = LoadFormat(load_format)

         rocm_not_supported_load_format: List[str] = []
-        if is_hip() and load_format in rocm_not_supported_load_format:
+        if current_platform.is_rocm(
+        ) and load_format in rocm_not_supported_load_format:
             rocm_supported_load_format = [
                 f for f in LoadFormat.__members__
                 if (f not in rocm_not_supported_load_format)
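Aside: a standalone sketch of the filtering this hunk performs: on ROCm, any load format in the not-supported list is rejected and the error lists the remaining formats. The enum members and the helper below are stand-ins for the real ones in vllm.config, shown only to make the logic self-contained:

    import enum
    from typing import List

    class LoadFormat(str, enum.Enum):
        # Stand-in for vllm.config.LoadFormat; members are illustrative.
        AUTO = "auto"
        PT = "pt"
        SAFETENSORS = "safetensors"

    def check_rocm_load_format(load_format: LoadFormat, is_rocm: bool) -> None:
        rocm_not_supported_load_format: List[str] = []  # currently empty upstream
        if is_rocm and load_format.value in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f.value for f in LoadFormat
                if f.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format.value}' is not supported in ROCm. "
                f"Supported load formats are {rocm_supported_load_format}")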
@@ -967,7 +968,7 @@ class ParallelConfig:
         if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
-        if is_hip():
+        if current_platform.is_rocm():
             self.disable_custom_all_reduce = True
             logger.info(
                 "Disabled the custom all-reduce kernel because it is not "
@@ -996,7 +997,7 @@ class SchedulerConfig:
             prompt latency) before scheduling next prompt.
         enable_chunked_prefill: If True, prefill requests can be chunked based
             on the remaining max_num_batched_tokens.
         preemption_mode: Whether to perform preemption by swapping or
             recomputation. If not specified, we determine the mode as follows:
             We use recomputation by default since it incurs lower overhead than
             swapping. However, when the sequence group has multiple sequences
@@ -1215,7 +1216,7 @@ class SpeculativeConfig:
             typical_acceptance_sampler_posterior_threshold (Optional[float]):
                 A threshold value that sets a lower bound on the posterior
                 probability of a token in the target model for it to be
                 accepted. This threshold is used only when we use the
                 TypicalAcceptanceSampler for token acceptance.
             typical_acceptance_sampler_posterior_alpha (Optional[float]):
                 A scaling factor for the entropy-based threshold in the
@@ -1225,7 +1226,7 @@ class SpeculativeConfig:
                 If set to False, token log probabilities are returned
                 according to the log probability settings in SamplingParams.
                 If not specified, it defaults to True.

         Returns:
             Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                 the necessary conditions are met, else None.
@@ -1470,13 +1471,13 @@ class SpeculativeConfig:
             typical_acceptance_sampler_posterior_threshold (Optional[float]):
                 A threshold value that sets a lower bound on the posterior
                 probability of a token in the target model for it to be
                 accepted. This threshold is used only when we use the
                 TypicalAcceptanceSampler for token acceptance.
             typical_acceptance_sampler_posterior_alpha (Optional[float]):
                 A scaling factor for the entropy-based threshold in the
                 TypicalAcceptanceSampler.
             disable_logprobs: If set to True, token log probabilities will not
                 be returned even if requested by sampling parameters. This
                 reduces latency by skipping logprob calculation in proposal
                 sampling, target sampling, and after accepted tokens are
                 determined. If set to False, log probabilities will be
@@ -1843,10 +1844,10 @@ def get_min_sliding_window(
 def get_served_model_name(model: str,
                           served_model_name: Optional[Union[str, List[str]]]):
     """
     If the input is a non-empty list, the first model_name in
     `served_model_name` is taken.
     If the input is a non-empty string, it is used directly.
     For cases where the input is either an empty string or an
     empty list, the fallback is to use `self.model`.
     """
     if not served_model_name:
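Aside: a self-contained sketch of the fallback described in this docstring, restated outside the config module (the `self.model` mentioned above corresponds to the `model` argument here):

    from typing import List, Optional, Union

    def get_served_model_name(model: str,
                              served_model_name: Optional[Union[str, List[str]]]):
        # Empty string, empty list, or None: fall back to the model path/name.
        if not served_model_name:
            return model
        # Non-empty list: the first entry wins.
        if isinstance(served_model_name, list):
            return served_model_name[0]
        # Non-empty string: used directly.
        return served_model_name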