Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -9,6 +9,7 @@ Note: Marlin internally uses locks to synchronize the threads. This can
|
||||
result in very slight nondeterminism for Marlin. As a result, we re-run the test
|
||||
up to 3 times to see if we pass.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
@@ -26,20 +27,20 @@ MAX_MODEL_LEN = 1024
|
||||
MODELS = [
|
||||
# act_order==True, group_size=128
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"),
|
||||
|
||||
# 8-bit, act_order==True, group_size=channelwise
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"),
|
||||
|
||||
# 4-bit, act_order==True, group_size=128
|
||||
("TechxGenus/gemma-1.1-2b-it-GPTQ", "main")
|
||||
("TechxGenus/gemma-1.1-2b-it-GPTQ", "main"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin")
|
||||
or current_platform.is_rocm()
|
||||
or not current_platform.is_cuda(),
|
||||
reason="gptq_marlin is not supported on this GPU type.")
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin")
|
||||
or current_platform.is_rocm()
|
||||
or not current_platform.is_cuda(),
|
||||
reason="gptq_marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@@ -55,29 +56,34 @@ def test_models(
|
||||
model_name, revision = model
|
||||
|
||||
# Run marlin.
|
||||
with vllm_runner(model_name=model_name,
|
||||
revision=revision,
|
||||
dtype=dtype,
|
||||
quantization="marlin",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=1) as gptq_marlin_model:
|
||||
|
||||
with vllm_runner(
|
||||
model_name=model_name,
|
||||
revision=revision,
|
||||
dtype=dtype,
|
||||
quantization="marlin",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=1,
|
||||
) as gptq_marlin_model:
|
||||
gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
|
||||
example_prompts[:-1], max_tokens, num_logprobs)
|
||||
example_prompts[:-1], max_tokens, num_logprobs
|
||||
)
|
||||
_ROPE_DICT.clear() # clear rope cache to avoid rope dtype error
|
||||
|
||||
# Run gptq.
|
||||
# The naive gptq kernel doesn't support bf16 yet.
|
||||
# Here we always compare fp16/bf16 gpt marlin kernel
|
||||
# to fp16 gptq kernel.
|
||||
with vllm_runner(model_name=model_name,
|
||||
revision=revision,
|
||||
dtype="half",
|
||||
quantization="gptq",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=1) as gptq_model:
|
||||
with vllm_runner(
|
||||
model_name=model_name,
|
||||
revision=revision,
|
||||
dtype="half",
|
||||
quantization="gptq",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=1,
|
||||
) as gptq_model:
|
||||
gptq_outputs = gptq_model.generate_greedy_logprobs(
|
||||
example_prompts[:-1], max_tokens, num_logprobs)
|
||||
example_prompts[:-1], max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=gptq_outputs,
|
||||
|
||||
Reference in New Issue
Block a user