[Kernel] Marlin Expansion: Support AutoGPTQ Models with Marlin (#3922)
Co-authored-by: alexm <alexm@neuralmagic.com> Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
@@ -10,12 +10,12 @@ up to 3 times to see if we pass.
|
||||
|
||||
Run `pytest tests/models/test_marlin.py`.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.models.utils import check_logprobs_close
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
@@ -55,43 +55,24 @@ def test_models(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype)
|
||||
marlin_model = vllm_runner(model_pair.model_marlin,
|
||||
dtype=dtype,
|
||||
quantization="marlin")
|
||||
marlin_outputs = marlin_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
# Note: not sure why, but deleting just the model on Ada Lovelace
|
||||
# does not free the GPU memory. On Ampere, deleting the just model
|
||||
# frees the memory.
|
||||
del marlin_model
|
||||
|
||||
gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype)
|
||||
gptq_model = vllm_runner(model_pair.model_gptq,
|
||||
dtype=dtype,
|
||||
quantization="gptq")
|
||||
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
|
||||
max_tokens,
|
||||
num_logprobs)
|
||||
|
||||
# Note: not sure why, but deleting just the model on Ada Lovelace
|
||||
# does not free the GPU memory. On Ampere, deleting the just model
|
||||
# frees the memory.
|
||||
del gptq_model
|
||||
|
||||
# loop through the prompts
|
||||
for prompt_idx in range(len(example_prompts)):
|
||||
gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[
|
||||
prompt_idx]
|
||||
marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[
|
||||
prompt_idx]
|
||||
|
||||
for idx, (gptq_output_id, marlin_output_id) in enumerate(
|
||||
zip(gptq_output_ids, marlin_output_ids)):
|
||||
# If sequence is not an exact match,
|
||||
if marlin_output_id != gptq_output_id:
|
||||
# Each predicted token must be in top 5 of the other's
|
||||
assert gptq_output_id in marlin_logprobs[idx], (
|
||||
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
|
||||
f"Marlin:\t{marlin_output_str!r}")
|
||||
assert marlin_output_id in gptq_logprobs[idx], (
|
||||
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
|
||||
f"Marlin:\t{marlin_output_str!r}")
|
||||
|
||||
# Break out since sequences will now diverge.
|
||||
break
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=gptq_outputs,
|
||||
outputs_1_lst=marlin_outputs,
|
||||
name_0="gptq",
|
||||
name_1="marlin",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user