[Bugfix] Fix AWQ models batch invariance issues (#38670)
Signed-off-by: yusuf <yusuf@deeplearningmachine.mynet> Signed-off-by: <> Co-authored-by: yusuf <yusuf@deeplearningmachine.mynet>
This commit is contained in:
@@ -10,6 +10,7 @@ import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.mem_utils import get_max_shared_memory_bytes
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
@@ -177,7 +178,7 @@ def matmul_persistent(
|
||||
},
|
||||
torch.float16: {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_N": _fp16_block_size_n,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_stages": 3,
|
||||
@@ -700,7 +701,7 @@ def bmm_batch_invariant(a, b, *, out=None):
|
||||
},
|
||||
torch.float16: {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_N": _fp16_block_size_n,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"num_stages": 3,
|
||||
"num_warps": 8,
|
||||
@@ -752,7 +753,8 @@ def addmm_batch_invariant(bias, a, b):
|
||||
|
||||
|
||||
def _log_softmax_batch_invariant(input, dim, _half_to_float):
|
||||
assert not _half_to_float, "not implemented"
|
||||
if _half_to_float:
|
||||
return log_softmax(input.float(), dim=dim)
|
||||
return log_softmax(input, dim=dim)
|
||||
|
||||
|
||||
@@ -923,12 +925,15 @@ _original_fp16_reduction_precision = None
|
||||
_original_bf16_reduction_precision = None
|
||||
_original_cublas_workspace_cfg = None
|
||||
_original_cublaslt_workspace_size = None
|
||||
_fp16_block_size_n = 256
|
||||
|
||||
|
||||
def enable_batch_invariant_mode():
|
||||
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
|
||||
global _original_fp16_reduction_precision, _original_bf16_reduction_precision
|
||||
global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
|
||||
global _fp16_block_size_n
|
||||
|
||||
if _batch_invariant_MODE:
|
||||
return
|
||||
|
||||
@@ -944,6 +949,10 @@ def enable_batch_invariant_mode():
|
||||
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
|
||||
|
||||
# Query the shared memory size and set block size
|
||||
# accordingly to avoid triton OutOfResources
|
||||
_fp16_block_size_n = 256 if get_max_shared_memory_bytes() > 106496 else 128
|
||||
else:
|
||||
# Only source of batch non-invariance for Hopper is split-k, can disable through
|
||||
# cuBLAS workspace config
|
||||
|
||||
@@ -8,6 +8,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@@ -273,8 +274,9 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
|
||||
# num_tokens >= threshold
|
||||
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
|
||||
|
||||
if FP16_MATMUL_HEURISTIC_CONDITION:
|
||||
# Batch invariant mode requires torch.matmul path
|
||||
# for Triton override
|
||||
if FP16_MATMUL_HEURISTIC_CONDITION or envs.VLLM_BATCH_INVARIANT:
|
||||
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
out = torch.matmul(reshaped_x, out)
|
||||
else:
|
||||
|
||||
@@ -10,6 +10,7 @@ from transformers import PretrainedConfig
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MPLinearLayerConfig,
|
||||
@@ -233,6 +234,11 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> "QuantizationMethods | None":
|
||||
# Skip override to marlin kernels, as they are not
|
||||
# batch invariant
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
return None
|
||||
|
||||
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
|
||||
|
||||
Reference in New Issue
Block a user