Merge EmbeddedLLM/vllm-rocm into vLLM main (#1836)

Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Amir Balwel <amoooori04@gmail.com>
Co-authored-by: root <kuanfu.liu@akirakan.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: kuanfu <kuanfu.liu@embeddedllm.com>
Co-authored-by: miloice <17350011+kliuae@users.noreply.github.com>
This commit is contained in:
TJian
2023-12-08 15:16:52 +08:00
committed by GitHub
parent c8e7eb1eb3
commit 6ccc0bfffb
29 changed files with 873 additions and 118 deletions

View File

@@ -6,7 +6,7 @@ from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory
from vllm.utils import get_cpu_memory, is_hip
logger = init_logger(__name__)
@@ -98,12 +98,27 @@ class ModelConfig:
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
if load_format not in [
"auto", "pt", "safetensors", "npcache", "dummy"
]:
supported_load_format = [
"auto", "pt", "safetensors", "npcache", "dummy"
]
rocm_not_supported_load_format = ["safetensors"]
if load_format not in supported_load_format:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
if is_hip():
if load_format in ["safetensors"]:
rocm_supported_load_format = [
f for f in supported_load_format
if (f not in rocm_not_supported_load_format)
]
raise ValueError(
f"load format \'{load_format}\' is not supported in ROCm. "
f"Supported load format are "
f"{rocm_supported_load_format}")
# Force ROCm to load from pt weights if nothing specific is set
if load_format == "auto":
load_format = "pt"
self.load_format = load_format
def _verify_tokenizer_mode(self) -> None:
@@ -116,6 +131,7 @@ class ModelConfig:
def _verify_quantization(self) -> None:
supported_quantization = ["awq", "squeezellm"]
rocm_not_supported_quantization = ["awq"]
if self.quantization is not None:
self.quantization = self.quantization.lower()
@@ -137,6 +153,11 @@ class ModelConfig:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
if is_hip(
) and self.quantization in rocm_not_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not supported "
f"in ROCm.")
logger.warning(f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")
@@ -364,6 +385,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16,
}
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
def _get_and_verify_dtype(
config: PretrainedConfig,
@@ -393,6 +416,14 @@ def _get_and_verify_dtype(
else:
raise ValueError(f"Unknown dtype: {dtype}")
if is_hip() and torch_dtype == torch.float32:
rocm_supported_dtypes = [
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
]
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
f"Supported dtypes are {rocm_supported_dtypes}")
# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32: