[Kernel] Marlin Expansion: Support AutoGPTQ Models with Marlin (#3922)

Co-authored-by: alexm <alexm@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Authored by Robert Shaw on 2024-04-29 12:35:34 -04:00, committed by GitHub
Parent: df29793dc7
Commit: 73c8d677e5
14 changed files with 2627 additions and 105 deletions

tests/models/test_marlin.py

@@ -10,12 +10,12 @@ up to 3 times to see if we pass.
 Run `pytest tests/models/test_marlin.py`.
 """
 from dataclasses import dataclass

 import pytest
 import torch

+from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

 capability = torch.cuda.get_device_capability()
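
The new `check_logprobs_close` import replaces the hand-rolled comparison loop removed in the hunk below. A minimal sketch of what such a helper could look like, reconstructed from that loop (the actual implementation in `tests/models/utils.py` may differ):

```python
from typing import Dict, List, Sequence, Tuple

def check_logprobs_close(
    outputs_0_lst: Sequence[Tuple[List[int], str, List[Dict[int, float]]]],
    outputs_1_lst: Sequence[Tuple[List[int], str, List[Dict[int, float]]]],
    name_0: str,
    name_1: str,
) -> None:
    # Compare prompt by prompt; each element is (token_ids, text, logprobs).
    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1
        for idx, (id_0, id_1) in enumerate(zip(output_ids_0, output_ids_1)):
            # Where the greedy tokens diverge, each model's choice must at
            # least appear in the other model's top-k logprobs at that step.
            if id_0 != id_1:
                assert id_0 in logprobs_1[idx], (
                    f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n"
                    f"{name_1}:\t{output_str_1!r}")
                assert id_1 in logprobs_0[idx], (
                    f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n"
                    f"{name_1}:\t{output_str_1!r}")
                # The sequences diverge from here on, so stop comparing.
                break
```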
@@ -55,43 +55,24 @@ def test_models(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype)
+    marlin_model = vllm_runner(model_pair.model_marlin,
+                               dtype=dtype,
+                               quantization="marlin")
     marlin_outputs = marlin_model.generate_greedy_logprobs(
         example_prompts, max_tokens, num_logprobs)

     # Note: not sure why, but deleting just the model on Ada Lovelace
     # does not free the GPU memory. On Ampere, deleting just the model
     # frees the memory.
     del marlin_model

-    gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype)
+    gptq_model = vllm_runner(model_pair.model_gptq,
+                             dtype=dtype,
+                             quantization="gptq")
     gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
                                                        max_tokens,
                                                        num_logprobs)

     # Note: not sure why, but deleting just the model on Ada Lovelace
     # does not free the GPU memory. On Ampere, deleting just the model
     # frees the memory.
     del gptq_model

-    # Loop through the prompts.
-    for prompt_idx in range(len(example_prompts)):
-        gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[
-            prompt_idx]
-        marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[
-            prompt_idx]
-        for idx, (gptq_output_id, marlin_output_id) in enumerate(
-                zip(gptq_output_ids, marlin_output_ids)):
-            # If the sequences are not an exact match,
-            if marlin_output_id != gptq_output_id:
-                # each predicted token must be in the top 5 of the other's.
-                assert gptq_output_id in marlin_logprobs[idx], (
-                    f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
-                    f"Marlin:\t{marlin_output_str!r}")
-                assert marlin_output_id in gptq_logprobs[idx], (
-                    f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
-                    f"Marlin:\t{marlin_output_str!r}")
-                # Break out since the sequences will now diverge.
-                break
+    check_logprobs_close(
+        outputs_0_lst=gptq_outputs,
+        outputs_1_lst=marlin_outputs,
+        name_0="gptq",
+        name_1="marlin",
+    )
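
For context, the user-visible effect of this PR is that an off-the-shelf AutoGPTQ checkpoint can be run through the Marlin kernels by passing `quantization="marlin"`, just as the updated test does through `vllm_runner`. A sketch using vLLM's offline `LLM` API; the model name is only an example:

```python
from vllm import LLM, SamplingParams

# Example AutoGPTQ checkpoint; any 4-bit GPTQ model with a Marlin-compatible
# configuration should work. quantization="marlin" selects the Marlin kernels.
llm = LLM(model="TheBloke/Llama-2-7B-Chat-GPTQ", quantization="marlin")

# Greedy decoding, matching the deterministic setup the test relies on.
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```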