[TPU] support fp8 kv cache quantization (#19292)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Author: Chengji Yao
Committed by: GitHub, 2025-07-19 20:01:00 -07:00
Parent: 2b504eb770
Commit: 3a1d8940ae
6 changed files with 94 additions and 27 deletions
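
For context on what the new test exercises: fp8 KV cache quantization is switched on through vLLM's kv_cache_dtype engine argument, the same knob the diff below passes as kv_cache_dtype=fp8. A minimal offline-inference sketch (the model choice, prompt, and sampling settings are illustrative, not part of this commit):

    from vllm import LLM, SamplingParams

    # Storing the KV cache in fp8 roughly halves its memory footprint versus
    # bf16, allowing longer contexts or more concurrent sequences per chip.
    llm = LLM(
        model="Qwen/Qwen3-1.7B",
        max_model_len=2048,
        max_num_seqs=128,
        kv_cache_dtype="fp8",  # quantize the KV cache to 8-bit float
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)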


@@ -15,15 +15,18 @@ import pytest
 from vllm.platforms import current_platform
 
 MODEL_NAMES = [
     "Qwen/Qwen2-1.5B-Instruct",
     "Qwen/Qwen3-1.7B",
     "google/gemma-3-1b-it",
 ]
+FP8_KV_MODEL_NAMES = [
+    "Qwen/Qwen3-1.7B",
+]
 NUM_CONCURRENT = 500
 TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
 EXPECTED_VALUES = {
     "Qwen/Qwen2-1.5B-Instruct": 0.58,
     "Qwen/Qwen3-1.7B": 0.68,
     "google/gemma-3-1b-it": 0.25,
 }
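
These constants feed the file's run_test helper, which drives lm-eval's vLLM backend on gsm8k and checks the score against EXPECTED_VALUES within RTOL. A rough sketch of that helper, with assumed details where the diff doesn't show them (TASK, FILTER, RTOL, and EXPECTED_VALUES are the module constants above):

    import lm_eval

    def run_test(model_name, more_args=None):
        # Extra engine args (e.g. "kv_cache_dtype=fp8") are appended to the
        # lm-eval model_args string as comma-separated key=value pairs.
        model_args = f"pretrained={model_name},max_model_len=4096"
        if more_args is not None:
            model_args = f"{model_args},{more_args}"
        results = lm_eval.simple_evaluate(
            model="vllm",
            model_args=model_args,
            tasks=TASK,        # "gsm8k"
            batch_size="auto",
        )
        measured = results["results"][TASK][FILTER]
        expected = EXPECTED_VALUES[model_name]
        # The tolerance band absorbs run-to-run noise in the benchmark score.
        assert abs(measured - expected) < RTOL, (
            f"Expected: {expected} | Measured: {measured}")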
@@ -70,10 +73,9 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
         if current_platform.is_tpu():
             # Limit compilation time for TPU V1
-            if model == "google/gemma-3-1b-it":
-                # TPU + google/gemma-3-1b-it + xet doesn't work well.
-                m.setenv("HF_HUB_DISABLE_XET", "1")
+            # xet doesn't work well for both Qwen/Qwen3-1.7B and
+            # google/gemma-3-1b-it
+            m.setenv("HF_HUB_DISABLE_XET", "1")
             more_args = "max_model_len=2048,max_num_seqs=64"
 
         # Add TP test (if provided)
@@ -83,9 +85,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
         run_test(model, more_args)
 
 
-def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
-    """Run with the V0 Engine."""
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
+                    reason="V1 is currently only supported on CUDA and TPU")
+@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
+def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
+        model, monkeypatch: pytest.MonkeyPatch):
+    """Run with the V1 Engine."""
 
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "0")
-        run_test("Qwen/Qwen2-1.5B-Instruct")
+        m.setenv("VLLM_USE_V1", "1")
+        more_args = None
+        if current_platform.is_tpu():
+            # Limit compilation time for TPU V1
+            # xet doesn't work well for Qwen/Qwen3-1.7B
+            m.setenv("HF_HUB_DISABLE_XET", "1")
+            more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
+
+        # Add TP test (if provided)
+        if TPU_TP_TEST_STR:
+            more_args += ",{}".format(TPU_TP_TEST_STR)
+
+        run_test(model, more_args)
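
Outside the test suite, the same configuration maps onto vLLM's server CLI; a hedged shell equivalent of the engine-args string used by the new test:

    # Mirrors more_args: max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8
    vllm serve Qwen/Qwen3-1.7B \
        --max-model-len 2048 \
        --max-num-seqs 128 \
        --kv-cache-dtype fp8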