[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2025-07-11 16:01:13 +08:00
Committed by: GitHub
Parent: 762be26a8e
Commit: 8020e98c9f

8 changed files with 561 additions and 88 deletions
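
For orientation, here is a minimal usage sketch of what this change enables: in-flight bitsandbytes 4-bit quantization of the MoE checkpoint exercised by the new test, driven through vLLM's public LLM API. The snippet is illustrative only (it is not part of this diff), and the exact flags may vary between vLLM versions.

# Sketch (not from this diff): load an MoE model with in-flight
# bitsandbytes quantization; weights are quantized while loading.
from vllm import LLM, SamplingParams

llm = LLM(model="allenai/OLMoE-1B-7B-0125-Instruct",
          quantization="bitsandbytes")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)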


@@ -14,7 +14,7 @@ from transformers import BitsAndBytesConfig
 from tests.quantization.utils import is_quant_method_supported
 from ...utils import compare_two_settings, multi_gpu_test
-from ..utils import check_embeddings_close
+from ..utils import check_embeddings_close, check_logprobs_close
 
 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
@@ -26,6 +26,10 @@ models_4bit_to_embedding_test = [
     ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
 ]
+
+models_4bit_to_moe_test = [
+    ("allenai/OLMoE-1B-7B-0125-Instruct", "quantize moe model inflight"),
+]
 
 models_pre_qaunt_4bit_to_test = [
     ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
      'read pre-quantized 4-bit FP4 model'),
@@ -115,6 +119,35 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
     compare_two_settings(model_name, common_args, pp_args)
 
 
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test)
+def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts,
+                            model_name, description) -> None:
+    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True,
+    ))
+    with vllm_runner(model_name,
+                     quantization='bitsandbytes',
+                     enforce_eager=False) as llm:
+        vllm_outputs = llm.generate_greedy_logprobs(example_prompts,
+                                                    max_tokens=32,
+                                                    num_logprobs=5)
+
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        transformers_outputs = llm.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens=32, num_logprobs=5)
+
+    check_logprobs_close(
+        outputs_0_lst=transformers_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="transformers",
+        name_1="vllm",
+    )
+
+
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
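
As a side note, the hf_model_kwargs in the new test above correspond to the following standalone Transformers load (NF4 with double quantization). This is a sketch of the Hugging Face reference path, not code from this commit, and the prompt is illustrative.

# Sketch: HF-side reference load matching the test's BitsAndBytesConfig.
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_quant_type="nf4",
                                bnb_4bit_use_double_quant=True)
tok = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0125-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "allenai/OLMoE-1B-7B-0125-Instruct",
    quantization_config=bnb_config,
    device_map="auto")
inputs = tok("The capital of France is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tok.decode(out[0], skip_special_tokens=True))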
@@ -182,7 +215,8 @@ def validate_generated_texts(hf_runner,
                              model_name,
                              pre_quant=False,
                              hf_model_kwargs=None,
-                             vllm_tp_size=1):
+                             vllm_tp_size=1,
+                             max_tokens=8):
 
     # NOTE: run vLLM first, as it requires a clean process
     # when using distributed inference
@@ -190,7 +224,8 @@ def validate_generated_texts(hf_runner,
                      quantization=None if pre_quant else 'bitsandbytes',
                      tensor_parallel_size=vllm_tp_size,
                      enforce_eager=False) as llm:
-        vllm_outputs = llm.generate_greedy(prompts, 8)
+        vllm_outputs = llm.generate_greedy(prompts, max_tokens)
         vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
 
     # Clean up the GPU memory for the next test
@@ -202,19 +237,17 @@ def validate_generated_texts(hf_runner,
 
     # Run with HF runner
     with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
-        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_outputs = llm.generate_greedy(prompts, max_tokens)
         hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
 
     # Clean up the GPU memory for the next test
     gc.collect()
     torch.cuda.empty_cache()
 
     # Compare the generated strings
     for hf_log, vllm_log in zip(hf_logs, vllm_logs):
         hf_str = hf_log["generated_text"]
         vllm_str = vllm_log["generated_text"]
         prompt = hf_log["prompt"]
         assert hf_str == vllm_str, (f"Model: {model_name}"
                                     f"Mismatch between HF and vLLM outputs:\n"
                                     f"Prompt: {prompt}\n"