[Kernel][Attention] Separate Attention.kv_scale into k_scale and v_scale (#6081)

Michael Goin
2024-07-16 18:31:32 -04:00
committed by GitHub
parent 160e1d8c99
commit 978aed5300
33 changed files with 317 additions and 185 deletions
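
The commit replaces the single per-layer KV-cache scaling factor (kv_scale) with separate key and value scales (k_scale and v_scale). As a rough illustration of what that means for existing AutoFP8 checkpoints — a standalone sketch, not vLLM's actual loader code; the helper name split_kv_scale and the plain-dict state_dict are assumptions — a deprecated ".kv_scale" entry can be remapped onto both new names:

    # Minimal sketch (not the vLLM implementation): remap a legacy AutoFP8
    # checkpoint that stores one ".kv_scale" per attention layer onto the
    # separate ".k_scale"/".v_scale" names introduced by this commit.
    from typing import Dict

    import torch


    def split_kv_scale(
            state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Replace each deprecated '<prefix>.kv_scale' entry with
        '<prefix>.k_scale' and '<prefix>.v_scale' holding the same value."""
        remapped: Dict[str, torch.Tensor] = {}
        for name, value in state_dict.items():
            if name.endswith(".kv_scale"):
                prefix = name[:-len(".kv_scale")]
                # The old format only has one shared scale, so it is reused
                # for both the key and the value cache.
                remapped[prefix + ".k_scale"] = value
                remapped[prefix + ".v_scale"] = value.clone()
            else:
                remapped[name] = value
        return remapped


    # Example: one legacy entry becomes two separate scale entries.
    legacy = {"model.layers.0.self_attn.kv_scale": torch.tensor(0.023)}
    print(sorted(split_kv_scale(legacy).keys()))

Checkpoints produced in the newer AutoFP8 format already carry distinct k_scale and v_scale values, which is what the added test below exercises.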


@@ -7,19 +7,49 @@ import torch
 from tests.quantization.utils import is_quant_method_supported
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
+                                                         Fp8LinearMethod)
 
 MODELS = [
     "neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
     "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
     "nm-testing/Phi-3-mini-128k-instruct-FP8",
 ]
 
 
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="FP8 is not supported on this GPU type.")
-@pytest.mark.parametrize("model", MODELS)
-def test_model_load_and_run(vllm_runner, model: str):
-    with vllm_runner(model) as llm:
+@pytest.mark.parametrize("model_id", MODELS)
+def test_model_load_and_run(vllm_runner, model_id: str):
+    with vllm_runner(model_id) as llm:
+        # note: this does not test accuracy, just that we can run through
+        # see lm-eval tests for accuracy
+        outputs = llm.generate_greedy(prompts=["Hello my name is"],
+                                      max_tokens=10)
+        print(outputs[0][1])
+
+
+KV_CACHE_MODELS = [
+    # Deprecated AutoFP8 format using .kv_scale
+    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
+    # AutoFP8 format using separate .k_scale and .v_scale
+    "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
+]
+
+
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
+def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
+    with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        attn = model.model.layers[0].self_attn.attn
+        assert isinstance(attn.quant_method, Fp8KVCacheMethod)
+
+        # NOTE: it is valid for scales to be 1.0 (default value), but we know
+        # these checkpoints have scales < 1.0
+        assert 0.0 < attn._k_scale < 1.0
+        assert 0.0 < attn._v_scale < 1.0
+
         # note: this does not test accuracy, just that we can run through
         # see lm-eval tests for accuracy
         outputs = llm.generate_greedy(prompts=["Hello my name is"],