enable skipping of SW attention layers when using FP8 KV cache (#33695)
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
@@ -466,3 +466,26 @@ def test_fp8_reloading(
             weight_loader(param, torch.zeros(shape))  # cannot use empty
 
         method.process_weights_after_loading(layer)
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
+def test_kv_cache_dtype_skip_layers(vllm_runner, monkeypatch):
+    """Test that kv_cache_dtype_skip_layers skips quantization for specified layers."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+    with vllm_runner(
+        "facebook/opt-125m",
+        kv_cache_dtype="fp8",
+        kv_cache_dtype_skip_layers=["0", "2"],
+        enforce_eager=True,
+    ) as llm:
+
+        def check_layers(model):
+            for i, layer in enumerate(model.model.decoder.layers):
+                expected = "auto" if str(i) in ["0", "2"] else "fp8"
+                assert layer.self_attn.attn.kv_cache_dtype == expected
+
+        llm.apply_model(check_layers)
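The added test exercises the new behavior end to end: the model is loaded with kv_cache_dtype="fp8", the layers named in kv_cache_dtype_skip_layers keep the default ("auto") KV-cache dtype, and all other layers are quantized to FP8. Below is a minimal offline usage sketch, assuming the new kv_cache_dtype_skip_layers option is forwarded to the vllm.LLM constructor the same way vllm_runner forwards its keyword arguments in the test above; the model name and prompt are only placeholders.

# Usage sketch (not part of the diff). Assumes kv_cache_dtype_skip_layers is
# accepted by the LLM constructor, mirroring the vllm_runner call in the test.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    kv_cache_dtype="fp8",                   # quantize the KV cache to FP8 ...
    kv_cache_dtype_skip_layers=["0", "2"],  # ... except layers 0 and 2, which
                                            # keep the default ("auto") dtype
    enforce_eager=True,
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)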