[Bugfix] Fix AWQ models batch invariance issues (#38670)

Signed-off-by: yusuf <yusuf@deeplearningmachine.mynet>
Co-authored-by: yusuf <yusuf@deeplearningmachine.mynet>
Author: Yusuf Mohammad
Date: 2026-04-03 15:54:15 +01:00
Committed by: GitHub
parent 6b4872240f
commit 46f02e00f2
4 changed files with 27 additions and 10 deletions
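
Batch invariance here means a prompt should produce the same output whether it runs alone or packed with other requests. A quick way to exercise that with an AWQ checkpoint, as a sketch under assumptions (the model name and prompts are illustrative, not from this commit):

import os
os.environ["VLLM_BATCH_INVARIANT"] = "1"   # read via vllm.envs, so set it before vLLM starts

from vllm import LLM, SamplingParams

llm = LLM(model="TheBloke/Llama-2-7B-AWQ", quantization="awq")  # illustrative AWQ checkpoint
params = SamplingParams(temperature=0.0, max_tokens=32)

prompt = "The capital of France is"
solo = llm.generate([prompt], params)[0].outputs[0].text
batched = llm.generate([prompt] + ["padding prompt"] * 7, params)[0].outputs[0].text
assert solo == batched, "output should not depend on batch composition"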

View File

@@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.mem_utils import get_max_shared_memory_bytes
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@@ -177,7 +178,7 @@ def matmul_persistent(
},
torch.float16: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_N": _fp16_block_size_n,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 3,
@@ -700,7 +701,7 @@ def bmm_batch_invariant(a, b, *, out=None):
},
torch.float16: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_N": _fp16_block_size_n,
"BLOCK_SIZE_K": 64,
"num_stages": 3,
"num_warps": 8,
@@ -752,7 +753,8 @@ def addmm_batch_invariant(bias, a, b):
def _log_softmax_batch_invariant(input, dim, _half_to_float):
assert not _half_to_float, "not implemented"
if _half_to_float:
return log_softmax(input.float(), dim=dim)
return log_softmax(input, dim=dim)
@@ -923,12 +925,15 @@ _original_fp16_reduction_precision = None
_original_bf16_reduction_precision = None
_original_cublas_workspace_cfg = None
_original_cublaslt_workspace_size = None
_fp16_block_size_n = 256
def enable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
global _original_fp16_reduction_precision, _original_bf16_reduction_precision
global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
global _fp16_block_size_n
if _batch_invariant_MODE:
return
@@ -944,6 +949,10 @@ def enable_batch_invariant_mode():
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
# Query the shared memory size and set block size
# accordingly to avoid triton OutOfResources
_fp16_block_size_n = 256 if get_max_shared_memory_bytes() > 106496 else 128
else:
# Only source of batch invariance for Hopper is split-k, can disable through
# cuBLAS workspace config
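
For context on the 106496-byte cutoff above: the fp16 persistent matmul config pairs BLOCK_SIZE_N=256 with BLOCK_SIZE_M=128, BLOCK_SIZE_K=64 and num_stages=3, and a pipeline that buffers fp16 A and B tiles of that size per stage needs roughly 144 KiB of shared memory, which overflows GPUs that only expose around 100 KiB per block; dropping BLOCK_SIZE_N to 128 brings it back under budget. A back-of-the-envelope sketch (the exact accounting Triton performs may differ):

# Assumption: each pipeline stage keeps one A tile (BLOCK_M x BLOCK_K) and
# one B tile (BLOCK_K x BLOCK_N) in shared memory, 2 bytes per fp16 element.
def estimated_smem_bytes(block_m, block_n, block_k, num_stages, elem_bytes=2):
    return num_stages * (block_m * block_k + block_k * block_n) * elem_bytes

print(estimated_smem_bytes(128, 256, 64, 3))  # 147456 B (~144 KiB): can trip OutOfResources on ~100 KiB parts
print(estimated_smem_bytes(128, 128, 64, 3))  # 98304 B (~96 KiB): fits below the 106496 B threshold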

View File

@@ -8,6 +8,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import (
@@ -273,8 +274,9 @@ class AWQLinearMethod(LinearMethodBase):
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
# Batch invariant mode requires torch.matmul path
# for Triton override
if FP16_MATMUL_HEURISTIC_CONDITION or envs.VLLM_BATCH_INVARIANT:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
out = torch.matmul(reshaped_x, out)
else:
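
The new comment is the crux of the AWQ fix: the batch-invariant library only overrides aten operators, so the GEMM must be routed through torch.matmul to pick up the deterministic Triton kernel, whereas the awq_gemm custom op would bypass the override entirely. A minimal illustration of that dispatch behaviour (standalone sketch, not code from this commit; needs a CUDA device):

import torch

lib = torch.library.Library("aten", "IMPL")

def traced_matmul(a, b):
    # stand-in for the real Triton kernel; only proves the override is reached
    print("batch-invariant matmul override called")
    return torch.mm(a, b)

lib.impl("matmul", traced_matmul, "CUDA")

a = torch.randn(4, 8, device="cuda", dtype=torch.float16)
b = torch.randn(8, 16, device="cuda", dtype=torch.float16)
torch.matmul(a, b)   # dispatches to traced_matmul
# ops.awq_gemm(...) would not: custom CUDA ops never go through aten::matmul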

View File

@@ -10,6 +10,7 @@ from transformers import PretrainedConfig
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
MPLinearLayerConfig,
@@ -233,6 +234,11 @@ class AWQMarlinConfig(QuantizationConfig):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> "QuantizationMethods | None":
# Skip override to marlin kernels, as they are not
# batch invariant
if envs.VLLM_BATCH_INVARIANT:
return None
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
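
This hook is what vLLM consults when deciding whether to transparently upgrade an "awq" checkpoint to the Marlin kernels; returning None keeps the model on plain AWQ, which now takes the torch.matmul path patched above. A sketch of the effect (the config dict is illustrative and follows the usual AutoAWQ quantization_config fields):

import os
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig

hf_quant_cfg = {"quant_method": "awq", "bits": 4, "group_size": 128,
                "zero_point": True, "version": "gemm"}
# With the env var set this returns None, so the checkpoint stays on "awq";
# without it, a compatible config on a capable GPU would yield "awq_marlin".
print(AWQMarlinConfig.override_quantization_method(hf_quant_cfg, None))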