Consolidate Intel Quantization Toolkit Integration in vLLM (#31716)

Signed-off-by: yiliu30 <yi4.liu@intel.com>
Yi Liu
2026-01-14 15:11:30 +08:00
committed by GitHub
parent 6fa6e7ef0c
commit 50632adc58
10 changed files with 531 additions and 660 deletions

View File

@@ -5,12 +5,11 @@ Quantization trades off model precision for smaller memory footprint, allowing l
Contents:
- [AutoAWQ](auto_awq.md)
- [AutoRound](auto_round.md)
- [BitsAndBytes](bnb.md)
- [BitBLAS](bitblas.md)
- [GGUF](gguf.md)
- [GPTQModel](gptqmodel.md)
- [INC](inc.md)
- [Intel Neural Compressor](inc.md)
- [INT4 W4A16](int4.md)
- [INT8 W8A8](int8.md)
- [FP8 W8A8](fp8.md)
@@ -43,23 +42,23 @@ th:not(:first-child) {
}
</style>
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU |
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ |
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ |
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU |
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------|
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ |
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ |
| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ |
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
- ✅︎ indicates that the quantization method is supported on the specified hardware.
- ❌ indicates that the quantization method is not supported on the specified hardware.
- All Intel Gaudi quantization support has been migrated to [vLLM-Gaudi](https://github.com/vllm-project/vllm-gaudi).
!!! note
For information on quantization support on Google TPU, please refer to the [TPU-Inference Recommended Models and Features](https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features/) documentation.

View File

@@ -1,103 +0,0 @@
# AutoRound
[AutoRound](https://github.com/intel/auto-round) is Intel's advanced quantization algorithm designed to produce highly efficient **INT2, INT3, INT4, and INT8**
quantized large language models, striking an optimal balance between accuracy and deployment performance.
AutoRound applies weight-only quantization to transformer-based models, enabling significant memory savings and faster
inference while maintaining near-original accuracy. It supports a wide range of hardware platforms, including **CPUs,
Intel GPUs, HPUs, and CUDA-enabled devices**.
Please refer to the [AutoRound guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) for more details.
Key Features:
✅ **AutoRound, AutoAWQ, AutoGPTQ, and GGUF** are supported
✅ **10+ vision-language models (VLMs)** are supported
✅ **Per-layer mixed-bit quantization** for fine-grained control
✅ **RTN (Round-To-Nearest) mode** for quick quantization with slight accuracy loss
✅ **Multiple quantization recipes**: best, base, and light
✅ Advanced utilities such as immediate packing and support for **10+ backends**
## Installation
```bash
uv pip install auto-round
```
## Quantizing a model
For VLMs, please change to `auto-round-mllm` in CLI usage and `AutoRoundMLLM` in API usage (a sketch of the CLI variant follows the examples below).
### CLI usage
```bash
auto-round \
--model Qwen/Qwen3-0.6B \
--bits 4 \
--group_size 128 \
--format "auto_round" \
--output_dir ./tmp_autoround
```
```bash
auto-round \
--model Qwen/Qwen3-0.6B \
--format "gguf:q4_k_m" \
--output_dir ./tmp_autoround
```
### API usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound
model_name = "Qwen/Qwen3-0.6B"
model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
bits, group_size, sym = 4, 128, True
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym)
# the best accuracy, 4-5X slower, low_gpu_mem_usage could save ~20G but ~30% slower
# autoround = AutoRound(model, tokenizer, nsamples=512, iters=1000, low_gpu_mem_usage=True, bits=bits, group_size=group_size, sym=sym)
# 2-3X speedup, slight accuracy drop at W4G128
# autoround = AutoRound(model, tokenizer, nsamples=128, iters=50, lr=5e-3, bits=bits, group_size=group_size, sym=sym )
output_dir = "./tmp_autoround"
# format= 'auto_round'(default), 'auto_gptq', 'auto_awq'
autoround.quantize_and_save(output_dir, format="auto_round")
```
## Running a quantized model with vLLM
Here is some example code to run auto-round format in vLLM:
```python
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(temperature=0.6, top_p=0.95)
model_name = "Intel/DeepSeek-R1-0528-Qwen3-8B-int4-AutoRound"
llm = LLM(model=model_name)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
## Acknowledgement
Special thanks to open-source low precision libraries such as AutoGPTQ, AutoAWQ, GPTQModel, Triton, Marlin, and
ExLLaMAV2 for providing low-precision CUDA kernels, which are leveraged in AutoRound.

View File

@@ -1,50 +1,89 @@
# FP8 INC
# Intel Quantization Support
vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators.
Currently, quantization has been validated only on Llama models.
[AutoRound](https://github.com/intel/auto-round) is Intel's advanced quantization algorithm designed for large language models (LLMs). It produces highly efficient **INT2, INT3, INT4, INT8, MXFP8, MXFP4, NVFP4**, and **GGUF** quantized models, balancing accuracy and inference performance. AutoRound is also part of the [Intel® Neural Compressor](https://github.com/intel/neural-compressor). For a deeper introduction, see the [AutoRound step-by-step guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md).
Intel Gaudi supports quantization of various modules and functions, including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`. For more information, please refer to:
[Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules).
## Key Features
!!! note
Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.
✅ Superior Accuracy: delivers strong performance even at 2-3 bits ([example models](https://huggingface.co/collections/OPEA/2-3-bits))
!!! note
`QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options).
The measurement configuration file is used during the calibration procedure to collect measurements for a given model. The quantization configuration is used during inference.
✅ Fast Mixed `Bits`/`Dtypes` Scheme Generation: automatically configured in minutes
## Run Online Inference Using FP8
✅ Support for exporting **AutoRound, AutoAWQ, AutoGPTQ, and GGUF** formats
Once you've completed the model calibration process and collected the measurements, you can run FP8 inference with vLLM using the following command:
✅ **10+ vision-language models (VLMs)** are supported
✅ **Per-layer mixed-bit quantization** for fine-grained control
✅ **RTN (Round-To-Nearest) mode** for quick quantization with slight accuracy loss
✅ **Multiple quantization recipes**: best, base, and light
✅ Advanced utilities such as immediate packing and support for **10+ backends**
## Supported Recipes on Intel Platforms
On Intel platforms, AutoRound recipes are being enabled progressively by format and hardware. Currently, vLLM supports:
- **`W4A16`**: weight-only, 4-bit weights with 16-bit activations
- **`W8A16`**: weight-only, 8-bit weights with 16-bit activations
Additional recipes and formats will be supported in future releases.
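For example, picking the `W8A16` recipe is assumed to require only changing the `--scheme` flag of the CLI command shown below:

```bash
# Assumed W8A16 variant; everything except --scheme matches the W4A16 example below
auto-round \
    --model Qwen/Qwen3-0.6B \
    --scheme W8A16 \
    --format auto_round \
    --output_dir ./tmp_autoround_w8a16
```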
## Quantizing a Model
### Installation
```bash
export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxabs_measure_g3.json
vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor-parallel-size 8
uv pip install auto-round
```
!!! tip
When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the environment variables below:
- `VLLM_ENGINE_ITERATION_TIMEOUT_S` - adjusts the vLLM server timeout. The value is in seconds, e.g., 600 equals 10 minutes.
- `VLLM_RPC_TIMEOUT` - adjusts the RPC protocol timeout used by the OpenAI-compatible API. The value is in milliseconds, e.g., 600000 equals 10 minutes.
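A minimal sketch of raising both timeouts to 10 minutes before launching the server (values are illustrative):

```bash
export VLLM_ENGINE_ITERATION_TIMEOUT_S=600   # seconds
export VLLM_RPC_TIMEOUT=600000               # milliseconds
# ...then launch `vllm serve` as shown above
```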
### Quantize with CLI
## Run Offline Inference Using FP8
```bash
auto-round \
--model Qwen/Qwen3-0.6B \
--scheme W4A16 \
--format auto_round \
--output_dir ./tmp_autoround
```
To run offline inference (after completing the model calibration process):
* Set the "QUANT_CONFIG" environment variable to point to a JSON configuration file with QUANTIZE mode.
* Pass `quantization=inc` and `kv_cache_dtype=fp8_inc` as parameters to the `LLM` object.
* Call the shutdown method of the `model_executor` at the end of the run.
### Quantize with Python API
```python
from vllm import LLM
llm = LLM("llama3.1/Meta-Llama-3.1-8B-Instruct", quantization="inc", kv_cache_dtype="fp8_inc")
...
# Call llm.generate on the required prompts and sampling params.
...
llm.llm_engine.model_executor.shutdown()
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound
model_name = "Qwen/Qwen3-0.6B"
autoround = AutoRound(model_name, scheme="W4A16")
# Best accuracy, 4-5X slower; low_gpu_mem_usage can save ~20G at ~30% extra runtime:
# autoround = AutoRound(model_name, scheme="W4A16", nsamples=512, iters=1000, low_gpu_mem_usage=True)
# 2-3X speedup, slight accuracy drop at W4G128:
# autoround = AutoRound(model_name, scheme="W4A16", nsamples=128, iters=50, lr=5e-3)
output_dir = "./tmp_autoround"
# format= 'auto_round'(default), 'auto_gptq', 'auto_awq'
autoround.quantize_and_save(output_dir, format="auto_round")
```
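As a quick sanity check, the exported directory can also be loaded with vLLM's offline API. A minimal sketch, assuming the `./tmp_autoround` output directory produced by `quantize_and_save()` above:

```python
from vllm import LLM, SamplingParams

# Local path produced by quantize_and_save() above (illustrative)
llm = LLM(model="./tmp_autoround", max_model_len=4096)
# On Intel GPU/CPU you may also need enforce_eager=True (see the note below)
sampling_params = SamplingParams(temperature=0.6, top_p=0.95)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
```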
## Device for the Model's Weights Uploading
## Deploying AutoRound Quantized Models in vLLM
The unquantized weights are first loaded onto the CPU, then quantized and transferred to the target device (HPU) for model execution.
This reduces the device memory footprint of model weights, as only quantized weights are stored in the device memory.
```bash
vllm serve Intel/DeepSeek-R1-0528-Qwen3-8B-int4-AutoRound \
--gpu-memory-utilization 0.8 \
--max-model-len 4096
```
!!! note
To deploy `wNa16` models on Intel GPU/CPU, please add `--enforce-eager` for now.
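A sketch of the serve command above with that flag added for Intel GPU/CPU deployments:

```bash
vllm serve Intel/DeepSeek-R1-0528-Qwen3-8B-int4-AutoRound \
    --gpu-memory-utilization 0.8 \
    --max-model-len 4096 \
    --enforce-eager
```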
## Evaluating the Quantized Model with vLLM
```bash
lm_eval --model vllm \
--model_args pretrained="Intel/DeepSeek-R1-0528-Qwen3-8B-int4-AutoRound,max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enforce_eager=True" \
--tasks gsm8k \
--num_fewshot 5 \
--batch_size 128
```

View File

@@ -26,9 +26,7 @@ MODELS = [
)
@pytest.mark.parametrize("model", MODELS)
def test_auto_round(vllm_runner, model):
with vllm_runner(
model, enforce_eager=True, allow_deprecated_quantization=True
) as llm:
with vllm_runner(model, enforce_eager=True) as llm:
output = llm.generate_greedy(["The capital of France is"], max_tokens=8)
assert output
print(f"{output[0][1]}")

View File

@@ -884,6 +884,7 @@ class ModelConfig:
"gptq_bitblas",
"awq_marlin",
"ipex",
"inc",
"moe_wna16",
"modelopt",
"modelopt_fp4",

View File

@@ -223,10 +223,6 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
return type_hints
def is_online_quantization(quantization: Any) -> bool:
return quantization in ["inc"]
NEEDS_HELP = (
any("--help" in arg for arg in sys.argv) # vllm SUBCOMMAND --help
or (argv0 := sys.argv[0]).endswith("mkdocs") # mkdocs SUBCOMMAND
@@ -1304,7 +1300,6 @@ class EngineArgs:
load_format=self.load_format,
download_dir=self.download_dir,
safetensors_load_strategy=self.safetensors_load_strategy,
device="cpu" if is_online_quantization(self.quantization) else None,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
use_tqdm_on_load=self.use_tqdm_on_load,

View File

@@ -33,7 +33,6 @@ QuantizationMethods = Literal[
"quark",
"moe_wna16",
"torchao",
"auto-round",
"rtn",
"inc",
"mxfp4",
@@ -54,7 +53,6 @@ DEPRECATED_QUANTIZATION_METHODS = [
"hqq",
"experts_int8",
"ipex",
"auto-round",
"rtn",
"petit_nvfp4",
]
@@ -120,7 +118,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
# lazy import to avoid triggering `torch.compile` too early
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from .auto_round import AutoRoundConfig
from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig
from .bitblas import BitBLASConfig
@@ -174,8 +171,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"quark": QuarkConfig,
"moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig,
"auto-round": AutoRoundConfig,
"rtn": RTNConfig,
"auto-round": INCConfig,
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"petit_nvfp4": PetitNvFp4Config,

View File

@@ -1,454 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fractions import Fraction
from typing import TYPE_CHECKING, Any
import regex as re
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
QuantizationMethods,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
class AutoRoundConfig(QuantizationConfig):
"""Config class for AutoRound.
Reference: https://arxiv.org/pdf/2309.05516
"""
SUPPORTED_BITS = {2, 3, 4, 8}
SUPPORTED_DTYPES = {"int"}
SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
SUPPORTED_BACKENDS = {
"auto",
"gptq",
"gptq:marlin",
"awq",
"awq:marlin",
"marlin",
"ipex",
}
def __init__(
self,
weight_bits: int,
group_size: int,
sym: bool = True,
packing_format: str = "auto_round:auto_gptq",
block_name_to_quantize: str | list[str] | None = None,
extra_config: dict[str, Any] | None = None,
data_type: str = "int",
backend: str = "auto",
) -> None:
super().__init__()
if weight_bits not in self.SUPPORTED_BITS:
raise ValueError(
f"Unsupported weight_bits: {weight_bits}, "
f"currently only support {self.SUPPORTED_BITS}."
)
if data_type not in self.SUPPORTED_DTYPES:
raise ValueError(
f"Unsupported data_type: {data_type}, "
f"currently only support {self.SUPPORTED_DTYPES}."
)
if packing_format not in self.SUPPORTED_FORMATS:
raise ValueError(
f"Unsupported packing_format: {packing_format}, "
f"currently only support {self.SUPPORTED_FORMATS}."
)
if backend not in self.SUPPORTED_BACKENDS:
raise ValueError(
f"Unsupported backend: {backend}, "
f"currently only support {self.SUPPORTED_BACKENDS}."
)
self.weight_bits = weight_bits
self.group_size = group_size
self.sym = sym
self.packing_format = packing_format
self.block_name_to_quantize = (
block_name_to_quantize.split(",")
if isinstance(block_name_to_quantize, str)
else block_name_to_quantize
)
self.extra_config = extra_config
self.data_type = data_type
self.backend = backend
self.pack_factor = Fraction(32, weight_bits)
def __repr__(self) -> str:
return (
f"AutoRoundConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, sym={self.sym})"
)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "auto-round"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantization_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
return cls(
weight_bits=cls.get_from_keys(config, ["bits"]),
group_size=cls.get_from_keys(config, ["group_size"]),
sym=cls.get_from_keys(config, ["sym"]),
packing_format=cls.get_from_keys_or(
config, ["packing_format"], "auto_round:auto_gptq"
),
block_name_to_quantize=cls.get_from_keys_or(
config, ["block_name_to_quantize", "to_quant_block_names"], None
),
extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], "auto"),
)
def get_layer_config(self, layer, layer_name: str):
def get_config(name: str, quantized: bool = True):
if not self.extra_config:
return (
self.weight_bits if quantized else 16,
self.group_size if quantized else -1,
self.sym if quantized else True,
)
# exact match first
if name in self.extra_config:
cfg = self.extra_config[name]
return (
cfg.get("bits", self.weight_bits if quantized else 16),
cfg.get("group_size", self.group_size if quantized else -1),
cfg.get("sym", self.sym if quantized else True),
)
REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\")
for pattern, cfg in self.extra_config.items():
if not isinstance(pattern, str) or not any(
c in REGEX_SPECIAL_CHARS for c in pattern
):
continue
try:
if re.search(re.compile(pattern), name) is not None:
return (
cfg.get("bits", self.weight_bits if quantized else 16),
cfg.get("group_size", self.group_size if quantized else -1),
cfg.get("sym", self.sym if quantized else True),
)
except re.error:
# Invalid regex, ignore.
continue
return (
self.weight_bits if quantized else 16,
self.group_size if quantized else -1,
self.sym if quantized else True,
)
# 1. Exact match from config
if self.extra_config and layer_name in self.extra_config:
return get_config(layer_name)
# 2. Determine whether layer should be quantized
quantized = not isinstance(layer, ParallelLMHead)
if self.block_name_to_quantize:
quantized = any(
layer_name.startswith(name) for name in self.block_name_to_quantize
)
# 3. Handle fused MoE
if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower():
moe_configs = [
get_config(name, quantized)
for name in self.extra_config
if name.startswith(layer_name)
]
if moe_configs:
if len(set(moe_configs)) == 1:
return moe_configs[0]
raise ValueError(
f"Fused MoE layer '{layer_name}' requires "
f"consistent quant config for all sub-layers"
)
# 4. Handle fused QKV or other patterns
if self.extra_config:
for fusion_key, sub_keys in self.packed_modules_mapping.items():
if fusion_key in layer_name and layer_name.count(fusion_key) == 1:
sub_names = [
layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys
]
sub_configs = [get_config(name, quantized) for name in sub_names]
if len(set(sub_configs)) == 1:
return sub_configs[0]
raise ValueError(
f"Fused module '{layer_name}' requires "
f"consistent quant config for {sub_names}"
)
# 5. Fallback or try a regular expression match
return get_config(layer_name, quantized)
def check_quantized(self, weight_bits: int) -> bool:
return weight_bits < 16
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.block_name_to_quantize is not None:
self.block_name_to_quantize = hf_to_vllm_mapper.apply_list(
self.block_name_to_quantize
)
if self.extra_config is not None:
self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config)
def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
check_moe_marlin_supports_layer,
)
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
logger.debug(
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
prefix,
layer.__class__.__name__,
weight_bits,
group_size,
sym,
)
if backend == "auto" or "marlin" in backend:
AWQ_TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
AWQ_TYPE_MAP[weight_bits], group_size, not sym
)
if isinstance(layer, FusedMoE):
use_marlin = use_marlin and check_moe_marlin_supports_layer(
layer, group_size
)
else:
use_marlin = False
if use_marlin:
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig,
AWQMarlinLinearMethod,
AWQMarlinMoEMethod,
)
quant_args_marlin = AWQMarlinConfig(
weight_bits=weight_bits,
group_size=group_size,
zero_point=not sym,
lm_head_quantized=False,
full_config={},
modules_to_not_convert=[],
)
else:
from vllm.model_executor.layers.quantization.awq import (
AWQConfig,
AWQLinearMethod,
)
quant_args = AWQConfig(
weight_bits=weight_bits,
group_size=group_size,
zero_point=not sym,
)
if isinstance(layer, FusedMoE):
if use_marlin:
return AWQMarlinMoEMethod(quant_args_marlin, layer.moe)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
config = {
"quant_method": "awq",
"bits": weight_bits,
"group_size": group_size,
"zero_point": not sym,
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:
return AWQMarlinLinearMethod(quant_args_marlin)
else:
return AWQLinearMethod(quant_args)
return None
def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
check_moe_marlin_supports_layer,
)
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
logger.debug(
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
prefix,
layer.__class__.__name__,
weight_bits,
group_size,
sym,
)
if backend == "auto" or "marlin" in backend:
GPTQ_TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported(
GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym
)
if isinstance(layer, FusedMoE):
use_marlin = use_marlin and check_moe_marlin_supports_layer(
layer, group_size
)
else:
use_marlin = False
if use_marlin:
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig,
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
quant_args_marlin = GPTQMarlinConfig(
weight_bits=weight_bits,
group_size=group_size,
is_sym=sym,
lm_head_quantized=False,
desc_act=False,
dynamic={},
full_config={},
)
else:
from vllm.model_executor.layers.quantization.gptq import (
GPTQConfig,
GPTQLinearMethod,
)
quant_args = GPTQConfig(
weight_bits=weight_bits,
group_size=group_size,
lm_head_quantized=False,
desc_act=False,
dynamic={},
)
if isinstance(layer, FusedMoE):
if use_marlin:
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config)
else:
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config,
)
config = {
"quant_method": "gptq",
"bits": weight_bits,
"group_size": group_size,
"sym": sym,
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix
)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:
return GPTQMarlinLinearMethod(quant_args_marlin)
else:
return GPTQLinearMethod(quant_args)
return None
def apply_ipex_quant_layer(self, layer, prefix: str):
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
from vllm.model_executor.layers.quantization.ipex_quant import (
IPEXAWQLinearMethod,
IPEXConfig,
IPEXGPTQLinearMethod,
)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if "awq" in self.packing_format:
config = IPEXConfig(
method="awq", weight_bits=weight_bits, group_size=group_size
)
return IPEXAWQLinearMethod(config)
elif "gptq" in self.packing_format:
config = IPEXConfig(
method="gptq", weight_bits=weight_bits, group_size=group_size
)
return IPEXGPTQLinearMethod(config)
else:
raise ValueError(
f"ipex backend only supports awq "
f"and gtpq format,but got {self.packing_format}"
)
else:
return None
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
if prefix and self.extra_config:
for layer_name in self.extra_config:
if (
layer_name == prefix or layer_name == f"model.{prefix}"
) and self.extra_config[layer_name].get("bits", 16) >= 16:
return UnquantizedLinearMethod()
if (
current_platform.is_cpu()
or current_platform.is_xpu()
or self.backend == "ipex"
):
return self.apply_ipex_quant_layer(layer, prefix)
if "gptq" in self.packing_format or "gptq" in self.backend:
return self.apply_gptq_quant_layer(layer, prefix)
if "awq" in self.packing_format or "awq" in self.backend:
return self.apply_awq_quant_layer(layer, prefix)

View File

@@ -1,39 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Intel Gaudi supports quantization of various modules and functions,
# including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`.
# During model loading,
# INC will patch layers with quantization/dequantization operators.
# Meanwhile, INC will convert the original weights to the target datatype
# and load them onto the target device.
# Static scaling should be provided through QUANT_CONFIG:
# `QUANT_CONFIG` is an environment variable,
# that points to the measurement or quantization JSON config file.
# The measurement configuration file is used during the calibration procedure,
# to collect measurements for a given model.
# The quantization configuration is used during inference.
# For more information, please refer to:
# https://docs.habana.ai/en/v1.21.1/PyTorch/vLLM_Inference/vLLM_FP8_Inference.html
from typing import Any, Optional
from fractions import Fraction
from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
UnquantizedFusedMoEMethod,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
QuantizeMethodBase,
QuantizationMethods,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
class INCConfig(QuantizationConfig):
"""Config class for FP8 using Intel Neural Compressor."""
"""Config class for Intel Neural Compressor (INC).
Repo: https://github.com/intel/neural-compressor
"""
SUPPORTED_BITS = {2, 3, 4, 8}
SUPPORTED_DTYPES = {"int"}
SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
SUPPORTED_BACKENDS = {
"auto",
"gptq",
"gptq:marlin",
"awq",
"awq:marlin",
"marlin",
"ipex",
}
def __init__(
self,
weight_bits: int,
group_size: int,
sym: bool = True,
packing_format: str = "auto_round:auto_gptq",
block_name_to_quantize: str | list[str] | None = None,
extra_config: dict[str, Any] | None = None,
data_type: str = "int",
backend: str = "auto",
) -> None:
super().__init__()
if weight_bits not in self.SUPPORTED_BITS:
raise ValueError(
f"Unsupported weight_bits: {weight_bits}, "
f"currently only support {self.SUPPORTED_BITS}."
)
if data_type not in self.SUPPORTED_DTYPES:
raise ValueError(
f"Unsupported data_type: {data_type},"
f" currently only support {self.SUPPORTED_DTYPES}."
)
if packing_format not in self.SUPPORTED_FORMATS:
raise ValueError(
f"Unsupported packing_format: {packing_format}, "
f"currently only support {self.SUPPORTED_FORMATS}."
)
if backend not in self.SUPPORTED_BACKENDS:
raise ValueError(
f"Unsupported backend: {backend}, "
f"currently only support {self.SUPPORTED_BACKENDS}."
)
self.weight_bits = weight_bits
self.group_size = group_size
self.sym = sym
self.packing_format = packing_format
self.block_name_to_quantize = (
block_name_to_quantize.split(",")
if isinstance(block_name_to_quantize, str)
else block_name_to_quantize
)
self.extra_config = extra_config
self.data_type = data_type
self.backend = backend
self.pack_factor = Fraction(32, weight_bits)
def __repr__(self) -> str:
return (
f"INCConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, sym={self.sym})"
)
@classmethod
def get_name(cls) -> QuantizationMethods:
@@ -41,25 +100,365 @@ class INCConfig(QuantizationConfig):
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "INCConfig":
raise AssertionError
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return UnquantizedFusedMoEMethod(layer.moe_config)
return None
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
raise AssertionError
return 60
@staticmethod
def get_config_filenames() -> list[str]:
return []
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantization_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "INCConfig":
return cls(
weight_bits=cls.get_from_keys(config, ["bits"]),
group_size=cls.get_from_keys(config, ["group_size"]),
sym=cls.get_from_keys(config, ["sym"]),
packing_format=cls.get_from_keys_or(
config, ["packing_format"], "auto_round:auto_gptq"
),
block_name_to_quantize=cls.get_from_keys_or(
config, ["block_name_to_quantize", "to_quant_block_names"], None
),
extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], "auto"),
)
def get_layer_config(self, layer, layer_name: str):
def get_config(name: str, quantized: bool = True):
if not self.extra_config:
return (
self.weight_bits if quantized else 16,
self.group_size if quantized else -1,
self.sym if quantized else True,
)
# exact match first
if name in self.extra_config:
cfg = self.extra_config[name]
return (
cfg.get("bits", self.weight_bits if quantized else 16),
cfg.get("group_size", self.group_size if quantized else -1),
cfg.get("sym", self.sym if quantized else True),
)
REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\")
for pattern, cfg in self.extra_config.items():
if not isinstance(pattern, str) or not any(
c in REGEX_SPECIAL_CHARS for c in pattern
):
continue
try:
if re.search(re.compile(pattern), name) is not None:
return (
cfg.get("bits", self.weight_bits if quantized else 16),
cfg.get("group_size", self.group_size if quantized else -1),
cfg.get("sym", self.sym if quantized else True),
)
except re.error:
# Invalid regex, ignore.
continue
return (
self.weight_bits if quantized else 16,
self.group_size if quantized else -1,
self.sym if quantized else True,
)
# 1. Exact match from config
if self.extra_config and layer_name in self.extra_config:
return get_config(layer_name)
# 2. Determine whether layer should be quantized
quantized = not isinstance(layer, ParallelLMHead)
if self.block_name_to_quantize:
quantized = any(
layer_name.startswith(name) for name in self.block_name_to_quantize
)
# 3. Handle fused MoE
if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower():
moe_configs = [
get_config(name, quantized)
for name in self.extra_config
if name.startswith(layer_name)
]
if moe_configs:
if len(set(moe_configs)) == 1:
return moe_configs[0]
raise ValueError(
f"Fused MoE layer '{layer_name}' requires "
f"consistent quant config for all sub-layers"
)
# 4. Handle fused QKV or other patterns
if self.extra_config:
for fusion_key, sub_keys in self.packed_modules_mapping.items():
if fusion_key in layer_name and layer_name.count(fusion_key) == 1:
sub_names = [
layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys
]
sub_configs = [get_config(name, quantized) for name in sub_names]
if len(set(sub_configs)) == 1:
return sub_configs[0]
raise ValueError(
f"Fused module '{layer_name}' requires "
f"consistent quant config for {sub_names}"
)
# 5. Fallback or try a regular expression match
return get_config(layer_name, quantized)
def check_quantized(self, weight_bits: int) -> bool:
return weight_bits < 16
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.block_name_to_quantize is not None:
self.block_name_to_quantize = hf_to_vllm_mapper.apply_list(
self.block_name_to_quantize
)
if self.extra_config is not None:
self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config)
def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
check_moe_marlin_supports_layer,
)
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
logger.debug(
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
prefix,
layer.__class__.__name__,
weight_bits,
group_size,
sym,
)
if backend == "auto" or "marlin" in backend:
AWQ_TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
AWQ_TYPE_MAP[weight_bits], group_size, not sym
)
if isinstance(layer, FusedMoE):
use_marlin = use_marlin and check_moe_marlin_supports_layer(
layer, group_size
)
else:
use_marlin = False
if use_marlin:
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig,
AWQMarlinLinearMethod,
AWQMarlinMoEMethod,
)
quant_args_marlin = AWQMarlinConfig(
weight_bits=weight_bits,
group_size=group_size,
zero_point=not sym,
lm_head_quantized=False,
full_config={},
modules_to_not_convert=[],
)
else:
from vllm.model_executor.layers.quantization.awq import (
AWQConfig,
AWQLinearMethod,
)
quant_args = AWQConfig(
weight_bits=weight_bits,
group_size=group_size,
zero_point=not sym,
)
if isinstance(layer, FusedMoE):
if use_marlin:
return AWQMarlinMoEMethod(quant_args_marlin, layer.moe_config)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
config = {
"quant_method": "awq",
"bits": weight_bits,
"group_size": group_size,
"zero_point": not sym,
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:
return AWQMarlinLinearMethod(quant_args_marlin)
else:
return AWQLinearMethod(quant_args)
return None
def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
check_moe_marlin_supports_layer,
)
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
logger.debug(
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
prefix,
layer.__class__.__name__,
weight_bits,
group_size,
sym,
)
if backend == "auto" or "marlin" in backend:
GPTQ_TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported(
GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym
)
if isinstance(layer, FusedMoE):
use_marlin = use_marlin and check_moe_marlin_supports_layer(
layer, group_size
)
else:
use_marlin = False
if use_marlin:
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig,
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
quant_args_marlin = GPTQMarlinConfig(
weight_bits=weight_bits,
group_size=group_size,
is_sym=sym,
lm_head_quantized=False,
desc_act=False,
dynamic={},
full_config={},
)
else:
from vllm.model_executor.layers.quantization.gptq import (
GPTQConfig,
GPTQLinearMethod,
)
quant_args = GPTQConfig(
weight_bits=weight_bits,
group_size=group_size,
lm_head_quantized=False,
desc_act=False,
dynamic={},
)
if isinstance(layer, FusedMoE):
if use_marlin:
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config)
else:
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config,
)
config = {
"quant_method": "gptq",
"bits": weight_bits,
"group_size": group_size,
"sym": sym,
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix
)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:
return GPTQMarlinLinearMethod(quant_args_marlin)
else:
return GPTQLinearMethod(quant_args)
return None
def apply_ipex_quant_layer(self, layer, prefix: str):
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
from vllm.model_executor.layers.quantization.ipex_quant import (
IPEXAWQLinearMethod,
IPEXConfig,
IPEXGPTQLinearMethod,
)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if "awq" in self.packing_format:
config = IPEXConfig(
method="awq", weight_bits=weight_bits, group_size=group_size
)
return IPEXAWQLinearMethod(config)
elif "gptq" in self.packing_format:
config = IPEXConfig(
method="gptq", weight_bits=weight_bits, group_size=group_size
)
return IPEXGPTQLinearMethod(config)
else:
raise ValueError(
f"ipex backend only supports awq "
f"and gptq format,but got {self.packing_format}"
)
else:
return None
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
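# Dispatch order: layers explicitly configured with >= 16 bits in extra_config stay
# unquantized; CPU/XPU platforms (or an explicit "ipex" backend) use the IPEX kernels;
# otherwise GPTQ- or AWQ-style kernels are selected from the packing format/backend.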
if prefix and self.extra_config:
for layer_name in self.extra_config:
if (
layer_name == prefix or layer_name == f"model.{prefix}"
) and self.extra_config[layer_name].get("bits", 16) >= 16:
return UnquantizedLinearMethod()
if (
current_platform.is_cpu()
or current_platform.is_xpu()
or self.backend == "ipex"
):
return self.apply_ipex_quant_layer(layer, prefix)
if "gptq" in self.packing_format or "gptq" in self.backend:
return self.apply_gptq_quant_layer(layer, prefix)
if "awq" in self.packing_format or "awq" in self.backend:
return self.apply_awq_quant_layer(layer, prefix)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional["QuantizationMethods"]:
"""Override the `auto-round` method to `inc`."""
is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round"
if is_auto_round_format:
return cls.get_name()
return None

View File

@@ -233,7 +233,7 @@ def get_quant_config(
quant_cls = get_quantization_config(model_config.quantization)
# GGUF doesn't have config file
if model_config.quantization in ("gguf", "inc"):
if model_config.quantization == "gguf":
return quant_cls()
# Read the quantization config from the HF model config, if available.