Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
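
The hunks below are purely mechanical reformatting. As a rough, hypothetical sketch of the style change (the names here are illustrative and do not appear in this diff), yapf's aligned-continuation layout becomes ruff's black-compatible layout:

# Hypothetical before/after sketch of the formatter change; illustrative names only.

# yapf: continuation arguments aligned under the opening parenthesis, closing
# bracket attached to the last argument.
supported_sizes_old = sorted([128, 64, 32, -1],
                             key=abs,
                             reverse=False)

# ruff format (black-compatible): 4-space hanging indent, one argument per line,
# trailing comma, closing bracket on its own line.
supported_sizes_new = sorted(
    [128, 64, 32, -1],
    key=abs,
    reverse=False,
)

assert supported_sizes_old == supported_sizes_new  # formatting only, same result

ruff format keeps a call on one line whenever it fits within the configured line length; the exploded, trailing-comma form appears only for calls that do not fit (or that already carry a magic trailing comma), which is the pattern visible throughout the hunks below.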


@@ -9,12 +9,16 @@ from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
-from vllm.model_executor.layers.quantization import (QuantizationConfig,
-                                                     QuantizationMethods)
-from vllm.model_executor.parameter import (BasevLLMParameter,
-                                           ChannelQuantScaleParameter,
-                                           GroupQuantScaleParameter,
-                                           PackedvLLMParameter)
+from vllm.model_executor.layers.quantization import (
+    QuantizationConfig,
+    QuantizationMethods,
+)
+from vllm.model_executor.parameter import (
+    BasevLLMParameter,
+    ChannelQuantScaleParameter,
+    GroupQuantScaleParameter,
+    PackedvLLMParameter,
+)
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -24,15 +28,12 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
 GPTQ_MARLIN_24_MIN_THREAD_K = 128
 GPTQ_MARLIN_24_MAX_PARALLEL = 64
 
-GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
-    scalar_types.uint4b8, scalar_types.uint8b128
-]
+GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
 GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
 
 
 class GPTQMarlin24Config(QuantizationConfig):
-    """Config class for Marlin24.
-    """
+    """Config class for Marlin24."""
 
     def __init__(
         self,
@@ -48,17 +49,18 @@ class GPTQMarlin24Config(QuantizationConfig):
         self.group_size = group_size
 
         # Verify
-        if quant_type is None or \
-                quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
+        if quant_type is None or quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
             raise ValueError(
                 f"Marlin_24 does not support quant_type = {quant_type}. "
                 f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
-                "are supported.")
+                "are supported."
+            )
         if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
             raise ValueError(
                 f"Marlin_24 does not support group_size = {self.group_size}. "
                 f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
-                "are supported.")
+                "are supported."
+            )
 
         self.quant_type = quant_type
 
@@ -83,7 +85,8 @@ class GPTQMarlin24Config(QuantizationConfig):
 
     def __repr__(self) -> str:
         return "Marlin24Config(quant_type={}, group_size={})".format(
-            self.quant_type, self.group_size)
+            self.quant_type, self.group_size
+        )
 
     @classmethod
     def get_name(cls) -> QuantizationMethods:
@@ -110,23 +113,26 @@ class GPTQMarlin24Config(QuantizationConfig):
 
     @classmethod
     def override_quantization_method(
-            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
-        is_marlin_24_format = (
-            hf_quant_cfg.get("checkpoint_format") == "marlin_24")
+        cls, hf_quant_cfg, user_quant
+    ) -> Optional[QuantizationMethods]:
+        is_marlin_24_format = hf_quant_cfg.get("checkpoint_format") == "marlin_24"
 
-        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
-                               or user_quant == "gptq_marlin_24")
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "gptq" or user_quant == "gptq_marlin_24"
+        )
 
         if is_marlin_24_format and is_valid_user_quant:
-            msg = ("The model is serialized in {} format. "
-                   "Using {} kernel.".format(cls.get_name(), cls.get_name()))
+            msg = "The model is serialized in {} format. Using {} kernel.".format(
+                cls.get_name(), cls.get_name()
+            )
             logger.info(msg)
             return cls.get_name()
 
         return None
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["GPTQMarlin24LinearMethod"]:
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["GPTQMarlin24LinearMethod"]:
         if isinstance(layer, LinearBase):
             return GPTQMarlin24LinearMethod(self)
         return None
@@ -156,7 +162,8 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
         weight_loader = extra_weight_attrs["weight_loader"]
         if params_dtype != torch.float16:
             raise ValueError(
-                f"The params dtype must be float16, but got {params_dtype}")
+                f"The params dtype must be float16, but got {params_dtype}"
+            )
 
         # Validate output_size_per_partition
         output_size_per_partition = sum(output_partition_sizes)
@@ -164,38 +171,46 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
             raise ValueError(
                 f"Weight output_size_per_partition = "
                 f"{output_size_per_partition} is not divisible by "
-                f"min_n_threads = {self.quant_config.min_n_threads}.")
+                f"min_n_threads = {self.quant_config.min_n_threads}."
+            )
         if output_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 f"Weight output_size_per_partition = "
                 f"{output_size_per_partition} is not divisible by "
-                f"pack_factor = {self.quant_config.pack_factor}.")
+                f"pack_factor = {self.quant_config.pack_factor}."
+            )
 
         # Validate input_size_per_partition
         if input_size_per_partition % self.quant_config.min_k_threads != 0:
             raise ValueError(
                 f"Weight input_size_per_partition = "
                 f"{input_size_per_partition} is not divisible by "
-                f"min_k_threads = {self.quant_config.min_k_threads}.")
-        if (self.quant_config.group_size != -1 and
-                input_size_per_partition % self.quant_config.group_size != 0):
-            raise ValueError(f"Weight input_size_per_partition = "
-                             f"{input_size_per_partition} is not divisible by "
-                             f"group_size = {self.quant_config.group_size}.")
+                f"min_k_threads = {self.quant_config.min_k_threads}."
+            )
+        if (
+            self.quant_config.group_size != -1
+            and input_size_per_partition % self.quant_config.group_size != 0
+        ):
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"group_size = {self.quant_config.group_size}."
+            )
 
         # Check that we have at least 4 tiles horizontally in the shard
         num_tiles_per_perm = self.quant_config.perm_len // (
-            self.quant_config.tile_size**2)
+            self.quant_config.tile_size**2
+        )
         if output_size_per_partition % num_tiles_per_perm != 0:
-            raise ValueError(
-                "Each permutation group must reside on the same gpu")
+            raise ValueError("Each permutation group must reside on the same gpu")
 
         # Quantized 4Bit weights packed into Int32.
         qweight = PackedvLLMParameter(
             data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size // 2,
-                output_size_per_partition * self.quant_config.tile_size //
-                self.quant_config.pack_factor,
+                output_size_per_partition
+                * self.quant_config.tile_size
+                // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
@@ -204,55 +219,57 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
             packed_dim=1,
             packed_factor=self.quant_config.pack_factor,
             marlin_tile_size=self.quant_config.tile_size,
-            weight_loader=weight_loader)
+            weight_loader=weight_loader,
+        )
 
         # Meta
-        meta = PackedvLLMParameter(data=torch.empty(
-            input_size_per_partition // 8 // 2 // 2,
-            output_size_per_partition * 2,
-            device="cuda",
-            dtype=torch.int16,
-        ),
-                                   input_dim=0,
-                                   output_dim=1,
-                                   packed_dim=1,
-                                   packed_factor=1,
-                                   marlin_tile_size=2,
-                                   weight_loader=weight_loader)
+        meta = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // 8 // 2 // 2,
+                output_size_per_partition * 2,
+                device="cuda",
+                dtype=torch.int16,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=1,
+            marlin_tile_size=2,
+            weight_loader=weight_loader,
+        )
 
         # Determine if channelwise or not
-        input_groups = (1 if self.quant_config.group_size == -1 else
-                        input_size_per_partition //
-                        self.quant_config.group_size)
+        input_groups = (
+            1
+            if self.quant_config.group_size == -1
+            else input_size_per_partition // self.quant_config.group_size
+        )
 
         weight_scale_args = {
-            "data":
-            torch.empty(
+            "data": torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            "weight_loader":
-            weight_loader
+            "weight_loader": weight_loader,
         }
         if input_groups == 1:
-            scales = ChannelQuantScaleParameter(output_dim=1,
-                                                **weight_scale_args)
+            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
         else:
-            scales = GroupQuantScaleParameter(output_dim=1,
-                                              input_dim=0,
-                                              **weight_scale_args)
+            scales = GroupQuantScaleParameter(
+                output_dim=1, input_dim=0, **weight_scale_args
+            )
 
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
-            output_size_per_partition //
-            self.quant_config.min_n_threads) * self.quant_config.max_parallel
+            output_size_per_partition // self.quant_config.min_n_threads
+        ) * self.quant_config.max_parallel
 
-        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
-                                                       device="cuda",
-                                                       dtype=torch.int),
-                                      weight_loader=weight_loader)
+        workspace = BasevLLMParameter(
+            data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
+            weight_loader=weight_loader,
+        )
 
         layer.register_parameter("B_24", qweight)
         layer.register_parameter("B_meta", meta)
@@ -283,12 +300,19 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
         size_k = x_2d.shape[1]
         size_n = scales.shape[1]
 
-        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
-                                            workspace,
-                                            self.quant_config.quant_type,
-                                            size_m, size_n, size_k)
+        output_2d = ops.gptq_marlin_24_gemm(
+            x_2d,
+            qweight,
+            meta,
+            scales,
+            workspace,
+            self.quant_config.quant_type,
+            size_m,
+            size_n,
+            size_k,
+        )
 
-        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
 
         if bias is not None:
             output.add_(bias)  # In-place add