[MISC] Consolidate FP8 kv-cache tests (#8131)
@@ -16,18 +16,6 @@ MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
 ]
-E5M2_KV_MODELS = [
-    "facebook/opt-125m",
-    "meta-llama/Llama-2-7b-chat-hf",
-]
-E4M3_KV_MODELS = [
-    "meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
-    "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
-]
-KV_CACHE_QUANTIZATION_PATHS = {
-    "meta-llama/Llama-2-7b-chat-hf":
-    "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json"
-}


 @pytest.mark.parametrize("model", MODELS)
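The removed `E5M2_KV_MODELS`, `E4M3_KV_MODELS`, and `KV_CACHE_QUANTIZATION_PATHS` blocks existed only to drive FP8 kv-cache variants of the chunked-prefill test; that coverage now lives in `test_fp8.py`. As a rough, hedged sketch of what the dropped scales-path map was for (assuming vLLM's `LLM` entrypoint accepts `kv_cache_dtype` and `quantization_param_path`; argument names may differ by version):

```python
# Illustrative only, not part of the diff: how an external kv-cache scales file
# (like the one the removed KV_CACHE_QUANTIZATION_PATHS entry pointed at) would
# be fed to vLLM when the checkpoint itself carries no FP8 kv-cache scales.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    kv_cache_dtype="fp8_e4m3",
    # Per-layer kv-cache scaling factors; checkpoints such as the
    # compressed-tensors TinyLlama model carry these scales themselves,
    # which is why the path map could be dropped from this test.
    quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json",
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=4))
```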
@@ -78,10 +66,10 @@ def test_models(
     )


-@pytest.mark.parametrize("kv_cache_dtype,model",
-                         [("fp8_e5m2", m)
-                          for m in E5M2_KV_MODELS] + [("fp8_e4m3", m)
-                                                      for m in E4M3_KV_MODELS])
+@pytest.mark.parametrize(
+    "kv_cache_dtype,model",
+    [("fp8_e4m3",
+      "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")])
 # Due to low-precision numerical divergence, we only test logprob of 4 tokens
 @pytest.mark.parametrize("max_tokens", [4])
 @pytest.mark.parametrize("chunked_prefill_token_size", [4, 16])
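The old decorator built its `(kv_cache_dtype, model)` cases as a concatenation of two list comprehensions over the removed model lists; the new one pins a single E4M3 case whose checkpoint already carries kv-cache scales. A standalone illustration of what the two expressions expand to (plain Python, no pytest needed):

```python
# Illustration only: the case lists generated by the removed vs. the new
# parametrize expression. Each pair is further crossed with max_tokens,
# chunked_prefill_token_size, etc. by the other decorators.
E5M2_KV_MODELS = ["facebook/opt-125m", "meta-llama/Llama-2-7b-chat-hf"]
E4M3_KV_MODELS = [
    "meta-llama/Llama-2-7b-chat-hf",
    "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
    "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme",
]

old_cases = [("fp8_e5m2", m) for m in E5M2_KV_MODELS] + \
            [("fp8_e4m3", m) for m in E4M3_KV_MODELS]
new_cases = [("fp8_e4m3",
              "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")]

print(len(old_cases), len(new_cases))  # 5 1
```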
@@ -104,30 +92,15 @@ def test_models_with_fp8_kv_cache(
     disable_async_output_proc: bool,
 ) -> None:
     """
-    Only checks log probs match between chunked-prefill and
-    non-chunked-prefill version of vLLM model runner.
-
-    This test is used when there is discrepancy in kernels
-    / numerics (e.g. when using lower-precision types like FP8).
+    Check output logprobs match between no_chunked_prefill and chunked_prefill
+    with fp8 kv cache. General fp8 kv-cache tests are covered in test_fp8.py,
+    so here we only check chunked prefill.
     """
     NUM_LOG_PROBS = 8

-    if model == "facebook/opt-125m":
-        pytest.skip(
-            "#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m"
-        )
-    if ((model, kv_cache_dtype, chunked_prefill_token_size) == (
-            "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", "fp8_e4m3", 4)):
-        pytest.skip("flakey test, see: #7874 #8051")
-
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size

-    extra_kwargs = {}
-    if model in KV_CACHE_QUANTIZATION_PATHS:
-        extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[
-            model]
-
     with vllm_runner(
             model,
             tensor_parallel_size=tensor_parallel_size,
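With only the compressed-tensors TinyLlama model left, the per-model skips and the `extra_kwargs` plumbing for `quantization_param_path` become dead code. The chunking itself comes from capping both `max_num_seqs` and `max_num_batched_tokens` at `chunked_prefill_token_size` (4 or 16), so any prompt longer than that token budget must be prefilled over several scheduler steps. A back-of-the-envelope sketch (the prompt length here is made up for illustration):

```python
# Sketch of why a small chunked_prefill_token_size forces chunked prefill.
import math

def num_prefill_chunks(prompt_len: int, max_num_batched_tokens: int) -> int:
    # With chunked prefill enabled and nothing else competing for the token
    # budget, a prompt is prefilled in ceil(prompt_len / budget) steps.
    return math.ceil(prompt_len / max_num_batched_tokens)

for budget in (4, 16):  # the chunked_prefill_token_size values parametrized above
    print(budget, num_prefill_chunks(prompt_len=33, max_num_batched_tokens=budget))
# budget=4 -> 9 chunks; budget=16 -> 3 chunks (assuming a 33-token prompt)
```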
@@ -135,7 +108,6 @@ def test_models_with_fp8_kv_cache(
             max_num_seqs=max_num_seqs,
             kv_cache_dtype=kv_cache_dtype,
             disable_async_output_proc=disable_async_output_proc,
-            **extra_kwargs,
     ) as vllm_model:
         no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, NUM_LOG_PROBS)
@@ -149,7 +121,6 @@ def test_models_with_fp8_kv_cache(
             max_num_seqs=max_num_seqs,
             kv_cache_dtype=kv_cache_dtype,
             disable_async_output_proc=disable_async_output_proc,
-            **extra_kwargs,
     ) as vllm_model:
         chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, NUM_LOG_PROBS)
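Both runs collect greedy completions plus the top-8 logprobs per position; the assertion that compares the two result lists sits further down in the file and is unchanged by this diff. As a hedged sketch of that kind of check (assuming each output is a `(token_ids, text, per_token_logprobs)` tuple, which is what `generate_greedy_logprobs` is expected to return; the real test relies on the shared test utilities):

```python
# Hypothetical helper, not the code in this repo: accept small divergences by
# requiring each mismatched greedy token to appear in the other run's top-k
# logprobs at that position.
def assert_logprobs_close(outputs_a, outputs_b, name_a, name_b):
    for idx, ((ids_a, _, lp_a), (ids_b, _, lp_b)) in enumerate(
            zip(outputs_a, outputs_b)):
        for pos, (tok_a, tok_b) in enumerate(zip(ids_a, ids_b)):
            if tok_a == tok_b:
                continue
            # Tokens may diverge under a low-precision kv-cache; the step is
            # still acceptable if each chosen token is in the other's top-k.
            assert tok_a in lp_b[pos] and tok_b in lp_a[pos], (
                f"prompt {idx}, position {pos}: "
                f"{name_a}={tok_a} vs {name_b}={tok_b}")

# assert_logprobs_close(no_chunked_prefill_outputs, chunked_prefill_outputs,
#                       "no_chunked_prefill", "chunked_prefill")
```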