diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
index 6465985f0..b162b469b 100644
--- a/tests/v1/determinism/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -187,7 +187,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
         tensor_parallel_size=tp_size,
         max_num_seqs=128,
         max_model_len=8192,
-        dtype="bfloat16",  # not everything is supported
+        dtype="auto",  # not everything is supported
         gpu_memory_utilization=0.9,
         enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
         attention_config={"backend": backend},
@@ -400,7 +400,7 @@ def test_simple_generation(backend):
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
         gpu_memory_utilization=0.9,
         max_model_len=2048,
-        dtype="bfloat16",
+        dtype="auto",
         enable_prefix_caching=False,
         enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
         attention_config={"backend": backend},
@@ -466,7 +466,7 @@ def test_logprobs_without_batch_invariance_should_fail(
         tensor_parallel_size=tp_size,
         max_num_seqs=32,
         max_model_len=8192,
-        dtype="bfloat16",
+        dtype="auto",
         enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
         attention_config={"backend": backend},
     )
@@ -686,7 +686,7 @@ def test_decode_logprobs_match_prefill_logprobs(
         tensor_parallel_size=tp_size,
         max_num_seqs=32,
         max_model_len=8192,
-        dtype="bfloat16",
+        dtype="auto",
         enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
         attention_config={"backend": backend},
     )
@@ -931,7 +931,7 @@ def LLM_with_max_seqs(
         max_num_seqs=max_num_seqs,
         gpu_memory_utilization=gpu_memory_utilization,
         max_model_len=max_model_len,
-        dtype="bfloat16",
+        dtype="auto",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
         enable_prefix_caching=False,
         enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 54b36d103..98fd6be8f 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -10,6 +10,7 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
+from vllm.utils.mem_utils import get_max_shared_memory_bytes
 from vllm.utils.platform_utils import num_compute_units
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
@@ -177,7 +178,7 @@ def matmul_persistent(
         },
         torch.float16: {
             "BLOCK_SIZE_M": 128,
-            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_N": _fp16_block_size_n,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,
             "num_stages": 3,
@@ -700,7 +701,7 @@ def bmm_batch_invariant(a, b, *, out=None):
         },
         torch.float16: {
             "BLOCK_SIZE_M": 128,
-            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_N": _fp16_block_size_n,
             "BLOCK_SIZE_K": 64,
             "num_stages": 3,
             "num_warps": 8,
@@ -752,7 +753,8 @@ def addmm_batch_invariant(bias, a, b):
 
 
 def _log_softmax_batch_invariant(input, dim, _half_to_float):
-    assert not _half_to_float, "not implemented"
+    if _half_to_float:
+        return log_softmax(input.float(), dim=dim)
     return log_softmax(input, dim=dim)
 
 
@@ -923,12 +925,15 @@
 _original_fp16_reduction_precision = None
 _original_bf16_reduction_precision = None
 _original_cublas_workspace_cfg = None
 _original_cublaslt_workspace_size = None
+_fp16_block_size_n = 256
 
 
 def enable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     global _original_fp16_reduction_precision, _original_bf16_reduction_precision
     global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
+    global _fp16_block_size_n
+
     if _batch_invariant_MODE:
         return
@@ -944,6 +949,10 @@ def enable_batch_invariant_mode():
         _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
         _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
         _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+
+        # Query the shared memory size and set block size
+        # accordingly to avoid triton OutOfResources
+        _fp16_block_size_n = 256 if get_max_shared_memory_bytes() > 106496 else 128
     else:
         # Only source of batch invariance for Hopper is split-k, can disable through
         # cuBLAS workspace config
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 58bb75d0a..37cffcb3d 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -8,6 +8,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from transformers import PretrainedConfig
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import (
@@ -273,8 +274,9 @@ class AWQLinearMethod(LinearMethodBase):
         # num_tokens >= threshold
         FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
-
-        if FP16_MATMUL_HEURISTIC_CONDITION:
+        # Batch invariant mode requires torch.matmul path
+        # for Triton override
+        if FP16_MATMUL_HEURISTIC_CONDITION or envs.VLLM_BATCH_INVARIANT:
             out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
             out = torch.matmul(reshaped_x, out)
         else:
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index eff571ef2..be3001a7f 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -10,6 +10,7 @@ from transformers import PretrainedConfig
 
 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.logger import init_logger
 from vllm.model_executor.kernels.linear import (
     MPLinearLayerConfig,
@@ -233,6 +234,11 @@
     def override_quantization_method(
         cls, hf_quant_cfg, user_quant
     ) -> "QuantizationMethods | None":
+        # Skip override to marlin kernels, as they are not
+        # batch invariant
+        if envs.VLLM_BATCH_INVARIANT:
+            return None
+
         can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
         is_valid_user_quant = (
             user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
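
For reviewers, two hedged sketches follow; neither is part of the diff above. The first illustrates the shared-memory arithmetic behind the `_fp16_block_size_n` heuristic: with the fp16 persistent-matmul config in the diff (`BLOCK_SIZE_M=128`, `BLOCK_SIZE_K=64`, `num_stages=3`), a 256-wide tile needs roughly 144 KiB of shared memory per block, which overflows the limit on devices exposing at most 104 KiB per block and triggers Triton's `OutOfResources`, while a 128-wide tile fits in about 96 KiB. The helper names `smem_bytes` and `pick_fp16_block_size_n` are hypothetical, and the footprint formula approximates Triton's pipelined staging rather than reusing vLLM's `get_max_shared_memory_bytes`.

```python
# Hedged sketch (not the PR's code): the shared-memory arithmetic behind the
# new _fp16_block_size_n heuristic. Triton's pipelined matmul stages a
# (BLOCK_M x BLOCK_K) and a (BLOCK_K x BLOCK_N) fp16 tile in shared memory for
# each of `num_stages` stages; exceeding the per-block limit raises Triton's
# OutOfResources at launch time.
import torch


def smem_bytes(block_m: int, block_n: int, block_k: int,
               num_stages: int, elem_size: int = 2) -> int:
    """Approximate per-block shared-memory footprint (fp16 = 2 bytes/elem)."""
    per_stage = (block_m * block_k + block_k * block_n) * elem_size
    return per_stage * num_stages


def pick_fp16_block_size_n(max_shared_mem_bytes: int) -> int:
    # Same threshold as the diff: 106496 bytes (104 KiB). Above it, the
    # 256-wide tile fits; at or below it, fall back to 128.
    return 256 if max_shared_mem_bytes > 106496 else 128


if __name__ == "__main__":
    # fp16 config from the diff: BLOCK_SIZE_M=128, BLOCK_SIZE_K=64, num_stages=3.
    for n in (256, 128):
        # 256-wide needs ~144 KiB; 128-wide needs ~96 KiB.
        print(f"BLOCK_SIZE_N={n}: {smem_bytes(128, n, 64, 3)} bytes")
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        # Field availability varies across PyTorch versions; hedge with getattr.
        limit = getattr(props, "shared_memory_per_block_optin", 48 * 1024)
        print("chosen BLOCK_SIZE_N:", pick_fp16_block_size_n(limit))
```

The second sketch shows the `aten::_log_softmax` behavior that `_log_softmax_batch_invariant` now mirrors instead of asserting: when `half_to_float=True`, PyTorch upcasts a half-precision input and returns an fp32 result, matching the new `log_softmax(input.float(), dim=dim)` fallback.

```python
# Hedged sketch (illustration only) of the aten::_log_softmax contract the
# override now follows: with half_to_float=True, eager PyTorch upcasts a
# half-precision input and returns an fp32 result, which the fallback
# log_softmax(input.float(), dim=dim) reproduces.
import torch

if torch.cuda.is_available():
    x = torch.randn(2, 5, dtype=torch.float16, device="cuda")
    out = torch.ops.aten._log_softmax(x, 1, True)  # half_to_float=True
    assert out.dtype == torch.float32
    ref = torch.log_softmax(x.float(), dim=1)      # the override's fallback path
    torch.testing.assert_close(out, ref)
```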