[Kernel][Attention] Separate Attention.kv_scale into k_scale and v_scale (#6081)
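Background, as a hedged illustration (this is not code from the commit): FP8 e4m3 represents magnitudes only up to 448, so a per-tensor dequantization scale maps each tensor's dynamic range onto that budget. A single shared kv_scale must cover max(|K|, |V|); when the key and value caches have different ranges, the tensor with the smaller range loses precision. The sketch below demonstrates this with PyTorch's float8_e4m3fn dtype; the tensors and magnitudes are made up for illustration.

    import torch

    FP8_E4M3_MAX = 448.0  # largest finite value in float8_e4m3fn


    def quant_error(x: torch.Tensor, scale: float) -> float:
        """Max abs error after an fp8 round trip with the given scale."""
        q = (x / scale).to(torch.float8_e4m3fn)
        return (q.to(torch.float32) * scale - x).abs().max().item()


    k = torch.randn(1024) * 0.05  # keys: small dynamic range (made up)
    v = torch.randn(1024) * 2.0   # values: much larger range (made up)

    shared_scale = max(k.abs().max(), v.abs().max()).item() / FP8_E4M3_MAX
    k_scale = k.abs().max().item() / FP8_E4M3_MAX

    print("K error with shared kv_scale:", quant_error(k, shared_scale))
    print("K error with its own k_scale:", quant_error(k, k_scale))  # smaller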
@@ -7,19 +7,49 @@ import torch
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
+                                                         Fp8LinearMethod)
 
 MODELS = [
     "neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
     "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
     "nm-testing/Phi-3-mini-128k-instruct-FP8",
 ]
 
 
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="FP8 is not supported on this GPU type.")
-@pytest.mark.parametrize("model", MODELS)
-def test_model_load_and_run(vllm_runner, model: str):
-    with vllm_runner(model) as llm:
+@pytest.mark.parametrize("model_id", MODELS)
+def test_model_load_and_run(vllm_runner, model_id: str):
+    with vllm_runner(model_id) as llm:
         # note: this does not test accuracy, just that we can run through
         # see lm-eval tests for accuracy
         outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                       max_tokens=10)
         print(outputs[0][1])
+
+
+KV_CACHE_MODELS = [
+    # Deprecated AutoFP8 format using .kv_scale
+    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
+    # AutoFP8 format using separate .k_scale and .v_scale
+    "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
+]
+
+
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
+def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
+    with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
+
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        attn = model.model.layers[0].self_attn.attn
+        assert isinstance(attn.quant_method, Fp8KVCacheMethod)
+        # NOTE: it is valid for scales to be 1.0 (default value), but we know
+        # these checkpoints have scales < 1.0
+        assert 0.0 < attn._k_scale < 1.0
+        assert 0.0 < attn._v_scale < 1.0
+
+        # note: this does not test accuracy, just that we can run through
+        # see lm-eval tests for accuracy
+        outputs = llm.generate_greedy(prompts=["Hello my name is"],
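For the deprecated checkpoints in KV_CACHE_MODELS above, the single serialized kv_scale has to populate both of the new attributes. A minimal sketch of that remapping idea, assuming a plain state-dict; the helper name remap_kv_scales is made up here and this is not vLLM's actual loader code:

    from typing import Dict

    import torch


    def remap_kv_scales(
            state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Expand a deprecated `.kv_scale` entry into `.k_scale`/`.v_scale`."""
        remapped: Dict[str, torch.Tensor] = {}
        for name, tensor in state_dict.items():
            if name.endswith(".kv_scale"):
                base = name[:-len(".kv_scale")]
                # The one shared scale is the best available value for both.
                remapped[base + ".k_scale"] = tensor
                remapped[base + ".v_scale"] = tensor
            else:
                remapped[name] = tensor
        return remapped

Checkpoints that already ship separate .k_scale/.v_scale entries (the second model in the list) would pass through unchanged, which is what lets both formats satisfy the same attn._k_scale/attn._v_scale assertions in the test.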
||||