Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -9,10 +9,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinLinearMethod
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod)
|
||||
UnquantizedEmbeddingMethod,
|
||||
)
|
||||
|
||||
PROMPT = "On the surface of Mars, we found"
|
||||
|
||||
@@ -31,20 +31,20 @@ def test_lm_head(
|
||||
) -> None:
|
||||
# `LLM.apply_model` requires pickling a function.
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
with vllm_runner(model_id, dtype=torch.float16,
|
||||
max_model_len=2048) as vllm_model:
|
||||
with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as vllm_model:
|
||||
|
||||
def check_model(model):
|
||||
lm_head_layer = model.lm_head
|
||||
if lm_head_quantized:
|
||||
assert isinstance(lm_head_layer.quant_method,
|
||||
(GPTQLinearMethod, GPTQMarlinLinearMethod))
|
||||
assert isinstance(
|
||||
lm_head_layer.quant_method,
|
||||
(GPTQLinearMethod, GPTQMarlinLinearMethod),
|
||||
)
|
||||
else:
|
||||
assert isinstance(lm_head_layer.quant_method,
|
||||
UnquantizedEmbeddingMethod)
|
||||
assert isinstance(
|
||||
lm_head_layer.quant_method, UnquantizedEmbeddingMethod
|
||||
)
|
||||
|
||||
vllm_model.apply_model(check_model)
|
||||
|
||||
print(
|
||||
vllm_model.generate_greedy(["Hello my name is"],
|
||||
max_tokens=10)[0][1])
|
||||
print(vllm_model.generate_greedy(["Hello my name is"], max_tokens=10)[0][1])
|
||||
|
||||
Reference in New Issue
Block a user