Merge EmbeddedLLM/vllm-rocm into vLLM main (#1836)
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com> Co-authored-by: Amir Balwel <amoooori04@gmail.com> Co-authored-by: root <kuanfu.liu@akirakan.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: kuanfu <kuanfu.liu@embeddedllm.com> Co-authored-by: miloice <17350011+kliuae@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.utils import get_cpu_memory
|
||||
from vllm.utils import get_cpu_memory, is_hip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -98,12 +98,27 @@ class ModelConfig:
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
if load_format not in [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||
]:
|
||||
supported_load_format = [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||
]
|
||||
rocm_not_supported_load_format = ["safetensors"]
|
||||
if load_format not in supported_load_format:
|
||||
raise ValueError(
|
||||
f"Unknown load format: {self.load_format}. Must be one of "
|
||||
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
|
||||
if is_hip():
|
||||
if load_format in ["safetensors"]:
|
||||
rocm_supported_load_format = [
|
||||
f for f in supported_load_format
|
||||
if (f not in rocm_not_supported_load_format)
|
||||
]
|
||||
raise ValueError(
|
||||
f"load format \'{load_format}\' is not supported in ROCm. "
|
||||
f"Supported load format are "
|
||||
f"{rocm_supported_load_format}")
|
||||
# Force ROCm to load from pt weights if nothing specific is set
|
||||
if load_format == "auto":
|
||||
load_format = "pt"
|
||||
self.load_format = load_format
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
@@ -116,6 +131,7 @@ class ModelConfig:
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = ["awq", "squeezellm"]
|
||||
rocm_not_supported_quantization = ["awq"]
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
@@ -137,6 +153,11 @@ class ModelConfig:
|
||||
raise ValueError(
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}.")
|
||||
if is_hip(
|
||||
) and self.quantization in rocm_not_supported_quantization:
|
||||
raise ValueError(
|
||||
f"{self.quantization} quantization is currently not supported "
|
||||
f"in ROCm.")
|
||||
logger.warning(f"{self.quantization} quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
"non-quantized models.")
|
||||
@@ -364,6 +385,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
|
||||
|
||||
|
||||
def _get_and_verify_dtype(
|
||||
config: PretrainedConfig,
|
||||
@@ -393,6 +416,14 @@ def _get_and_verify_dtype(
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
if is_hip() and torch_dtype == torch.float32:
|
||||
rocm_supported_dtypes = [
|
||||
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
|
||||
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
|
||||
]
|
||||
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
|
||||
f"Supported dtypes are {rocm_supported_dtypes}")
|
||||
|
||||
# Verify the dtype.
|
||||
if torch_dtype != config_dtype:
|
||||
if torch_dtype == torch.float32:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -73,7 +74,12 @@ def initialize_cluster(
|
||||
"Ray is not installed. Please install Ray to use distributed "
|
||||
"serving.")
|
||||
# Connect to a ray cluster.
|
||||
ray.init(address=ray_address, ignore_reinit_error=True)
|
||||
if is_hip():
|
||||
ray.init(address=ray_address,
|
||||
ignore_reinit_error=True,
|
||||
num_gpus=parallel_config.world_size)
|
||||
else:
|
||||
ray.init(address=ray_address, ignore_reinit_error=True)
|
||||
|
||||
if not parallel_config.worker_use_ray:
|
||||
# Initialize cluster locally.
|
||||
|
||||
@@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
|
||||
from vllm._C import ops
|
||||
from vllm._C import cache_ops
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.utils import is_hip
|
||||
|
||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
@@ -160,6 +161,8 @@ class PagedAttention(nn.Module):
|
||||
attn_bias=input_metadata.attn_bias,
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
|
||||
(is_hip()) else None,
|
||||
)
|
||||
output = out.view_as(query)
|
||||
else:
|
||||
|
||||
@@ -7,6 +7,7 @@ from vllm._C import ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.utils import is_hip
|
||||
|
||||
|
||||
class SqueezeLLMConfig(QuantizationConfig):
|
||||
@@ -114,9 +115,14 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
lookup_table = weights["lookup_table"]
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
# NOTE: The output tensor should be zero-initialized.
|
||||
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
|
||||
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
|
||||
if is_hip():
|
||||
out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
|
||||
ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
|
||||
out = out_f.to(dtype=torch.float16)
|
||||
else:
|
||||
# NOTE: The output tensor should be zero-initialized.
|
||||
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
|
||||
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
|
||||
@@ -10,6 +10,10 @@ from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models import *
|
||||
from vllm.model_executor.weight_utils import (get_quant_config,
|
||||
initialize_dummy_weights)
|
||||
from vllm.utils import is_hip
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# TODO(woosuk): Lazy-load the model classes.
|
||||
_MODEL_REGISTRY = {
|
||||
@@ -39,6 +43,18 @@ _MODEL_REGISTRY = {
|
||||
"YiForCausalLM": YiForCausalLM,
|
||||
}
|
||||
|
||||
# Models to be disabled in ROCm
|
||||
_ROCM_UNSUPPORTED_MODELS = []
|
||||
if is_hip():
|
||||
for rocm_model in _ROCM_UNSUPPORTED_MODELS:
|
||||
del _MODEL_REGISTRY[rocm_model]
|
||||
|
||||
# Models partially supported in ROCm
|
||||
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
|
||||
"MistralForCausalLM":
|
||||
"Sliding window attention is not supported in ROCm's flash attention",
|
||||
}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||
@@ -53,7 +69,15 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _MODEL_REGISTRY:
|
||||
if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
|
||||
logger.warning(
|
||||
f"{arch} is not fully supported in ROCm. Reason: "
|
||||
f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
|
||||
return _MODEL_REGISTRY[arch]
|
||||
elif arch in _ROCM_UNSUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Model architecture {arch} is not supported by ROCm for now. \n"
|
||||
f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
|
||||
|
||||
@@ -27,10 +27,14 @@ class Counter:
|
||||
self.counter = 0
|
||||
|
||||
|
||||
def is_hip() -> bool:
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
||||
"""Returns the maximum shared memory per thread block in bytes."""
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
|
||||
max_shared_mem = cuda_utils.get_device_attribute(
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
|
||||
return int(max_shared_mem)
|
||||
|
||||
Reference in New Issue
Block a user