[Bugfix] Expand quantization method support in perf metrics (#37231)
Signed-off-by: Thillai Chithambaram <thillaichithambaram.a@gmail.com>
This commit is contained in:
committed by
GitHub
parent
577df69b26
commit
828f862acb
@@ -7,6 +7,7 @@ Tests for the analytic estimators in metrics/flops.py.
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
|
||||
from transformers.models.llama4.configuration_llama4 import (
|
||||
Llama4Config,
|
||||
@@ -21,10 +22,12 @@ from vllm.transformers_utils.model_arch_config_convertor import (
|
||||
ModelArchConfigConvertorBase,
|
||||
)
|
||||
from vllm.v1.metrics.perf import (
|
||||
_QUANT_WEIGHT_BYTE_SIZE,
|
||||
AttentionMetrics,
|
||||
BaseConfigParser,
|
||||
ExecutionContext,
|
||||
FfnMetrics,
|
||||
InvalidComponent,
|
||||
ModelMetrics,
|
||||
ParsedArgs,
|
||||
UnembedMetrics,
|
||||
@@ -905,3 +908,116 @@ def test_attention_per_gpu_heads_not_evenly_divisible():
|
||||
assert per_gpu_flops > 0
|
||||
assert global_flops > 0
|
||||
assert global_flops > per_gpu_flops
|
||||
|
||||
|
||||
# Quantization methods that pack each weight into half a byte (INT4 / FP4).
_INT4_FP4_METHODS = [
    name for name, byte_size in _QUANT_WEIGHT_BYTE_SIZE.items() if byte_size == 0.5
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("quant_method", _INT4_FP4_METHODS)
def test_quantization_config_parser_int4_methods(quant_method):
    """Test quantization parsers with INT4/FP4 methods (0.5 bytes)."""

    class MockQuantConfig:
        def get_name(self):
            return quant_method

    vllm_config = create_mock_vllm_config(
        Qwen3Config(
            hidden_size=2048,
            num_attention_heads=16,
            intermediate_size=8192,
            num_hidden_layers=1,
        ),
        quant_config=MockQuantConfig(),
    )

    # Attention and FFN parsers must both resolve the method to 0.5 bytes.
    for metrics_cls in (AttentionMetrics, FfnMetrics):
        parsed = metrics_cls.get_parser().parse(vllm_config)
        assert parsed.weight_byte_size == 0.5, (
            f"Expected 0.5 for {quant_method}, got {parsed.weight_byte_size}"
        )
|
||||
|
||||
|
||||
# Quantization methods that store each weight in a full byte (FP8 / INT8).
_FP8_INT8_METHODS = [
    name for name, byte_size in _QUANT_WEIGHT_BYTE_SIZE.items() if byte_size == 1
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("quant_method", _FP8_INT8_METHODS)
def test_quantization_config_parser_fp8_methods(quant_method):
    """Test quantization parsers with FP8/INT8 methods (1 byte)."""

    class MockQuantConfig:
        def get_name(self):
            return quant_method

    vllm_config = create_mock_vllm_config(
        Qwen3Config(
            hidden_size=2048,
            num_attention_heads=16,
            intermediate_size=8192,
            num_hidden_layers=1,
        ),
        quant_config=MockQuantConfig(),
    )

    # Attention and FFN parsers must both resolve the method to 1 byte.
    for metrics_cls in (AttentionMetrics, FfnMetrics):
        parsed = metrics_cls.get_parser().parse(vllm_config)
        assert parsed.weight_byte_size == 1, (
            f"Expected 1 for {quant_method}, got {parsed.weight_byte_size}"
        )
|
||||
|
||||
|
||||
def test_quantization_config_parser_unknown_method():
    """Test that an unrecognized quant method raises InvalidComponent."""

    class MockQuantConfig:
        def get_name(self):
            return "unknown_quant_method"

    vllm_config = create_mock_vllm_config(
        Qwen3Config(
            hidden_size=2048,
            num_attention_heads=16,
            intermediate_size=8192,
            num_hidden_layers=1,
        ),
        quant_config=MockQuantConfig(),
    )

    # Both parsers must reject a method that is not in the mapping.
    for metrics_cls in (AttentionMetrics, FfnMetrics):
        with pytest.raises(InvalidComponent):
            metrics_cls.get_parser().parse(vllm_config)
|
||||
|
||||
|
||||
def test_quantized_model_metrics_aggregation():
    """Test that ModelMetrics works end-to-end with a quantized model config."""

    class MockQuantConfig:
        def get_name(self):
            return "gptq"

    vllm_config = create_mock_vllm_config(
        Qwen3Config(
            hidden_size=2048,
            num_attention_heads=16,
            num_hidden_layers=12,
            vocab_size=32000,
            intermediate_size=8192,
        ),
        quant_config=MockQuantConfig(),
    )

    metrics = ModelMetrics(vllm_config)
    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    # Should not crash and should produce valid metrics
    total_flops = metrics.get_num_flops(ctx)
    per_component_flops = metrics.get_num_flops_breakdown(ctx)

    assert total_flops > 0
    # The per-component breakdown must sum exactly to the aggregate figure.
    assert total_flops == sum(per_component_flops.values())
|
||||
|
||||
@@ -40,6 +40,42 @@ class InvalidComponent(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Mapping from quantization method name to effective weight byte size.
|
||||
# Used by both AttentionQuantizationConfigParser and
|
||||
# FfnQuantizationConfigParser to determine the weight_byte_size for
|
||||
# flops/memory estimation.
|
||||
#
|
||||
# NOTE: Methods like GPTQ and BitsAndBytes support variable bit-widths
|
||||
# (e.g., 4-bit and 8-bit). We default to 4-bit (0.5 bytes) since this
|
||||
# is by far the most common configuration.
|
||||
_QUANT_WEIGHT_BYTE_SIZE: dict[str, float] = {
|
||||
# FP8 methods (1 byte per weight)
|
||||
"fp8": 1,
|
||||
"fbgemm_fp8": 1,
|
||||
"ptpc_fp8": 1,
|
||||
"fp_quant": 1,
|
||||
"modelopt": 1,
|
||||
"modelopt_mxfp8": 1,
|
||||
# FP4 / INT4 methods (0.5 bytes per weight)
|
||||
"mxfp4": 0.5,
|
||||
"awq": 0.5,
|
||||
"awq_marlin": 0.5,
|
||||
"gptq": 0.5,
|
||||
"gptq_marlin": 0.5,
|
||||
"bitsandbytes": 0.5,
|
||||
"modelopt_fp4": 0.5,
|
||||
"petit_nvfp4": 0.5,
|
||||
"gguf": 0.5,
|
||||
"compressed-tensors": 0.5,
|
||||
"torchao": 0.5,
|
||||
"quark": 0.5,
|
||||
"moe_wna16": 0.5,
|
||||
"inc": 0.5,
|
||||
"cpu_awq": 0.5,
|
||||
"experts_int8": 1,
|
||||
}
|
||||
|
||||
|
||||
#### Basic Data Types ####
|
||||
|
||||
|
||||
@@ -350,17 +386,12 @@ class AttentionQuantizationConfigParser(Parser):
|
||||
return args
|
||||
|
||||
quant_method = cfg.get_name()
|
||||
if quant_method in ["fp8", "fbgemm_fp8"]:
|
||||
# FIXME: This is a hacky coarse-grained fp8 quantization detection.
|
||||
# FIXME: These configs also have concept of "ignored layers" and we
|
||||
# need to solve the same problem as above.
|
||||
args.weight_byte_size = 1
|
||||
elif quant_method == "mxfp4":
|
||||
# FIXME: Also has "ignored layers" issue above
|
||||
args.weight_byte_size = 0.5
|
||||
if quant_method in _QUANT_WEIGHT_BYTE_SIZE:
|
||||
args.weight_byte_size = _QUANT_WEIGHT_BYTE_SIZE[quant_method]
|
||||
else:
|
||||
# FIXME: Add more parsing logic for different quant methods.
|
||||
raise InvalidComponent
|
||||
raise InvalidComponent(
|
||||
f"Unsupported quantization method for attention metrics: {quant_method}"
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
@@ -617,19 +648,12 @@ class FfnQuantizationConfigParser(Parser):
|
||||
return args
|
||||
|
||||
quant_method = cfg.get_name()
|
||||
if quant_method in ["fp8", "fbgemm_fp8"]:
|
||||
# FIXME: This is a hacky coarse-grained fp8 quantization detection.
|
||||
# (there might be more quantization methods for fp8).
|
||||
# FIXME: These configs also have concept of "ignored layers" and we
|
||||
# need to solve the same problem as above.
|
||||
args.weight_byte_size = 1
|
||||
pass
|
||||
elif quant_method == "mxfp4":
|
||||
# FIXME: Also has "ignored layers" issue above
|
||||
args.weight_byte_size = 0.5
|
||||
if quant_method in _QUANT_WEIGHT_BYTE_SIZE:
|
||||
args.weight_byte_size = _QUANT_WEIGHT_BYTE_SIZE[quant_method]
|
||||
else:
|
||||
# FIXME: Add more parsing logic for different quant methods.
|
||||
raise InvalidComponent
|
||||
raise InvalidComponent(
|
||||
f"Unsupported quantization method for FFN metrics: {quant_method}"
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
Reference in New Issue
Block a user