Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -19,21 +19,26 @@ def enable_pickle(monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("modelopt"),
|
||||
reason="ModelOpt FP8 is not supported on this GPU type.")
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("modelopt"),
|
||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
||||
"""Test ModelOpt FP8 checkpoint loading and structure validation."""
|
||||
# TODO: provide a small publicly available test checkpoint
|
||||
model_path = ("/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/"
|
||||
"TinyLlama-1.1B-Chat-v1.0-fp8-0710")
|
||||
model_path = (
|
||||
"/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/"
|
||||
"TinyLlama-1.1B-Chat-v1.0-fp8-0710"
|
||||
)
|
||||
|
||||
# Skip test if checkpoint doesn't exist
|
||||
if not os.path.exists(model_path):
|
||||
pytest.skip(f"Test checkpoint not found at {model_path}. "
|
||||
"This test requires a local ModelOpt FP8 checkpoint.")
|
||||
pytest.skip(
|
||||
f"Test checkpoint not found at {model_path}. "
|
||||
"This test requires a local ModelOpt FP8 checkpoint."
|
||||
)
|
||||
|
||||
with vllm_runner(model_path, quantization="modelopt",
|
||||
enforce_eager=True) as llm:
|
||||
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
@@ -45,11 +50,12 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
||||
|
||||
# Check that ModelOpt quantization method is properly applied
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptFp8LinearMethod)
|
||||
ModelOptFp8LinearMethod,
|
||||
)
|
||||
|
||||
assert isinstance(qkv_proj.quant_method, ModelOptFp8LinearMethod)
|
||||
assert isinstance(o_proj.quant_method, ModelOptFp8LinearMethod)
|
||||
assert isinstance(gate_up_proj.quant_method,
|
||||
ModelOptFp8LinearMethod)
|
||||
assert isinstance(gate_up_proj.quant_method, ModelOptFp8LinearMethod)
|
||||
assert isinstance(down_proj.quant_method, ModelOptFp8LinearMethod)
|
||||
|
||||
# Check weight dtype is FP8
|
||||
@@ -59,23 +65,23 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
||||
assert down_proj.weight.dtype == torch.float8_e4m3fn
|
||||
|
||||
# Check scales are present and have correct dtype
|
||||
assert hasattr(qkv_proj, 'weight_scale')
|
||||
assert hasattr(qkv_proj, 'input_scale')
|
||||
assert hasattr(qkv_proj, "weight_scale")
|
||||
assert hasattr(qkv_proj, "input_scale")
|
||||
assert qkv_proj.weight_scale.dtype == torch.float32
|
||||
assert qkv_proj.input_scale.dtype == torch.float32
|
||||
|
||||
assert hasattr(o_proj, 'weight_scale')
|
||||
assert hasattr(o_proj, 'input_scale')
|
||||
assert hasattr(o_proj, "weight_scale")
|
||||
assert hasattr(o_proj, "input_scale")
|
||||
assert o_proj.weight_scale.dtype == torch.float32
|
||||
assert o_proj.input_scale.dtype == torch.float32
|
||||
|
||||
assert hasattr(gate_up_proj, 'weight_scale')
|
||||
assert hasattr(gate_up_proj, 'input_scale')
|
||||
assert hasattr(gate_up_proj, "weight_scale")
|
||||
assert hasattr(gate_up_proj, "input_scale")
|
||||
assert gate_up_proj.weight_scale.dtype == torch.float32
|
||||
assert gate_up_proj.input_scale.dtype == torch.float32
|
||||
|
||||
assert hasattr(down_proj, 'weight_scale')
|
||||
assert hasattr(down_proj, 'input_scale')
|
||||
assert hasattr(down_proj, "weight_scale")
|
||||
assert hasattr(down_proj, "input_scale")
|
||||
assert down_proj.weight_scale.dtype == torch.float32
|
||||
assert down_proj.input_scale.dtype == torch.float32
|
||||
|
||||
|
||||
Reference in New Issue
Block a user