Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00 (committed by GitHub)
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions

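The conversion is mechanical: yapf's aligned-continuation style is replaced by the Black-compatible layout that ruff format emits, which breaks after an opening bracket, indents arguments one level, and places a trailing comma before the dedented closing bracket. A minimal before/after sketch with a made-up call (not taken from this diff):

    # yapf (old): continuation lines aligned under the opening bracket
    result = compute_output(activations, weights,
                            scales, group_size)

    # ruff format (new): break after the bracket, one argument per line
    # when over the line limit, trailing comma, dedented close
    result = compute_output(
        activations,
        weights,
        scales,
        group_size,
    )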

@@ -40,8 +40,9 @@ def query_marlin_supported_quant_types(
 ):
     if device_capability is None:
         capability_tuple = current_platform.get_device_capability()
-        device_capability = (-1 if capability_tuple is None else
-                             capability_tuple.to_int())
+        device_capability = (
+            -1 if capability_tuple is None else capability_tuple.to_int()
+        )
 
     if device_capability < 80:
         return []
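For context, current_platform.get_device_capability() returns a (major, minor) compute capability, and to_int() packs it as major * 10 + minor, so the < 80 guard filters out pre-Ampere GPUs. A minimal sketch of that packing (the helper below is illustrative, not vLLM API):

    from typing import Optional

    def capability_to_int(capability: Optional[tuple[int, int]]) -> int:
        # Assumed packing: (8, 0) -> 80, (7, 5) -> 75; -1 when no device exists.
        if capability is None:
            return -1
        major, minor = capability
        return major * 10 + minor

    assert capability_to_int((8, 0)) >= 80  # A100-class: Marlin eligible
    assert capability_to_int((7, 5)) < 80   # T4-class: empty type list
    assert capability_to_int(None) == -1    # no CUDA device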
@@ -50,10 +51,12 @@ def query_marlin_supported_quant_types(
     # - has_zp is False: return quant_types that has not zero points
     # - has_zp is None: both
     if has_zp is None:
-        types0 = query_marlin_supported_quant_types(False, include_fp_type,
-                                                    device_capability)
-        types1 = query_marlin_supported_quant_types(True, include_fp_type,
-                                                    device_capability)
+        types0 = query_marlin_supported_quant_types(
+            False, include_fp_type, device_capability
+        )
+        types1 = query_marlin_supported_quant_types(
+            True, include_fp_type, device_capability
+        )
         return types0 + types1
 
     if has_zp:
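The has_zp=None branch unions the two variants via recursion, with the no-zero-point types first. A usage sketch (the concrete ScalarType lists depend on include_fp_type and the device capability, so the comments are illustrative only):

    without_zp = query_marlin_supported_quant_types(False, True, 90)  # e.g. GPTQ-style types
    with_zp = query_marlin_supported_quant_types(True, True, 90)      # e.g. AWQ-style types
    both = query_marlin_supported_quant_types(None, True, 90)
    assert both == without_zp + with_zp  # types0 (has_zp=False) comes first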
@@ -68,108 +71,126 @@ def query_marlin_supported_quant_types(
 
 
 def _check_marlin_supported(
-        quant_type: ScalarType,
-        group_size: Optional[int],
-        has_zp: bool,
-        device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    has_zp: bool,
+    device_capability: Optional[int] = None,
+) -> tuple[bool, Optional[str]]:
     if device_capability is None:
         capability_tuple = current_platform.get_device_capability()
-        device_capability = (-1 if capability_tuple is None else
-                             capability_tuple.to_int())
+        device_capability = (
+            -1 if capability_tuple is None else capability_tuple.to_int()
+        )
 
     supported_types = query_marlin_supported_quant_types(
-        has_zp, True, device_capability)
+        has_zp, True, device_capability
+    )
 
     if quant_type not in supported_types:
-        return (False, f"Marlin does not support weight_bits = {quant_type}. "
-                f"Only types = {supported_types} "
-                f"are supported (for group_size = {group_size}, "
-                f"device_capability = {device_capability}, zp = {has_zp}).")
-    if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
-        return (False, f"Marlin does not support group_size = {group_size}. "
-                f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
-                "are supported.")
+        return (
+            False,
+            f"Marlin does not support weight_bits = {quant_type}. "
+            f"Only types = {supported_types} "
+            f"are supported (for group_size = {group_size}, "
+            f"device_capability = {device_capability}, zp = {has_zp}).",
+        )
+    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
+        return (
+            False,
+            f"Marlin does not support group_size = {group_size}. "
+            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
+            "are supported.",
+        )
 
     return True, None
 
 
-def check_marlin_supported(quant_type: ScalarType,
-                           group_size: int,
-                           has_zp: bool = False,
-                           device_capability: Optional[int] = None) -> bool:
-    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
-                                      device_capability)
+def check_marlin_supported(
+    quant_type: ScalarType,
+    group_size: int,
+    has_zp: bool = False,
+    device_capability: Optional[int] = None,
+) -> bool:
+    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
     return cond
 
 
-def verify_marlin_supported(quant_type: ScalarType,
-                            group_size: int,
-                            has_zp: bool = False) -> None:
+def verify_marlin_supported(
+    quant_type: ScalarType, group_size: int, has_zp: bool = False
+) -> None:
     cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
     if not cond:
         assert err_msg is not None
         raise ValueError(err_msg)
 
 
-def verify_marlin_supports_shape(output_size_per_partition: int,
-                                 input_size_per_partition: int,
-                                 input_size: int, group_size: int) -> None:
+def verify_marlin_supports_shape(
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    input_size: int,
+    group_size: int,
+) -> None:
     # Validate output_size_per_partition
     if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
-        raise ValueError(f"Weight output_size_per_partition = "
-                         f"{output_size_per_partition} is not divisible by "
-                         f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
-                         "Consider reducing tensor_parallel_size or running "
-                         "with --quantization gptq.")
+        raise ValueError(
+            f"Weight output_size_per_partition = "
+            f"{output_size_per_partition} is not divisible by "
+            f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
+            "Consider reducing tensor_parallel_size or running "
+            "with --quantization gptq."
+        )
 
     # Validate input_size_per_partition
     if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
-        raise ValueError(f"Weight input_size_per_partition = "
-                         f"{input_size_per_partition} is not divisible "
-                         f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
-                         "Consider reducing tensor_parallel_size or running "
-                         "with --quantization gptq.")
+        raise ValueError(
+            f"Weight input_size_per_partition = "
+            f"{input_size_per_partition} is not divisible "
+            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
+            "Consider reducing tensor_parallel_size or running "
+            "with --quantization gptq."
+        )
 
-    if (group_size < input_size
-            and input_size_per_partition % group_size != 0):
+    if group_size < input_size and input_size_per_partition % group_size != 0:
         raise ValueError(
             f"Weight input_size_per_partition = {input_size_per_partition}"
             f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
-            "with --quantization gptq.")
+            "with --quantization gptq."
+        )
 
 
-def check_marlin_supports_shape(output_size_per_partition: int,
-                                input_size_per_partition: int,
-                                input_size: int, group_size: int) \
-        -> tuple[bool, Optional[str]]:
+def check_marlin_supports_shape(
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    input_size: int,
+    group_size: int,
+) -> tuple[bool, Optional[str]]:
     try:
-        verify_marlin_supports_shape(output_size_per_partition,
-                                     input_size_per_partition, input_size,
-                                     group_size)
+        verify_marlin_supports_shape(
+            output_size_per_partition, input_size_per_partition, input_size, group_size
+        )
     except ValueError as e:
         return False, e.__str__()
     return True, None
 
 
-def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
-        -> bool:
-    output_size_per_partition = getattr(layer, "output_size_per_partition",
-                                        None) or layer.output_size
-    input_size_per_partition = getattr(layer, "input_size_per_partition",
-                                       None) or layer.input_size
+def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
+    output_size_per_partition = (
+        getattr(layer, "output_size_per_partition", None) or layer.output_size
+    )
+    input_size_per_partition = (
+        getattr(layer, "input_size_per_partition", None) or layer.input_size
+    )
 
     return check_marlin_supports_shape(
         output_size_per_partition=output_size_per_partition,
         input_size_per_partition=input_size_per_partition,
         input_size=layer.input_size,
-        group_size=group_size)[0]
+        group_size=group_size,
+    )[0]
 
 
-def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
-        -> bool:
+def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
     hidden_size = layer.hidden_size
     intermediate_size_per_partition = layer.intermediate_size_per_partition
     # apply_router_weight_on_input is not supported for moe marlin
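Taken together, the checks above reduce to three divisibility constraints on the sharded weight. A self-contained sketch, where min_thread_n = 64 and min_thread_k = 128 are assumed stand-ins for GPTQ_MARLIN_MIN_THREAD_N / GPTQ_MARLIN_MIN_THREAD_K rather than values read from this diff:

    def marlin_shape_ok(
        output_size_per_partition: int,
        input_size_per_partition: int,
        input_size: int,
        group_size: int,
        min_thread_n: int = 64,   # assumed GPTQ_MARLIN_MIN_THREAD_N
        min_thread_k: int = 128,  # assumed GPTQ_MARLIN_MIN_THREAD_K
    ) -> bool:
        # Mirrors verify_marlin_supports_shape: n and k per partition must
        # tile evenly, and grouped quantization must divide the local k.
        return (
            output_size_per_partition % min_thread_n == 0
            and input_size_per_partition % min_thread_k == 0
            and (group_size >= input_size or input_size_per_partition % group_size == 0)
        )

    # A 4096x4096 layer sharded to n=2048 with group_size=128 passes;
    # an odd shard like n=100 does not tile and fails.
    assert marlin_shape_ok(2048, 4096, 4096, 128)
    assert not marlin_shape_ok(100, 4096, 4096, 128)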
@@ -180,51 +201,58 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
     # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
     # down: (n, k) = (hidden_size, intermediate_size_per_partition)
     # moe marlin requires n % 128 == 0 and k % 64 == 0
-    supports_shape = hidden_size % 128 == 0 and \
-        intermediate_size_per_partition % max(64, group_size) == 0
+    supports_shape = (
+        hidden_size % 128 == 0
+        and intermediate_size_per_partition % max(64, group_size) == 0
+    )
     supports_group_size = group_size in [-1, 32, 64, 128]
-    return supports_shape and supports_group_size and \
-        supports_router_weight and supports_activation
+    return (
+        supports_shape
+        and supports_group_size
+        and supports_router_weight
+        and supports_activation
+    )
 
 
-def marlin_moe_intermediate_size(w1_packed: torch.Tensor,
-                                 w2_packed: torch.Tensor):
+def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tensor):
     """
     Given Marlin packed weight matrices w1_packed, and w2_packed,
-    return the MoE intermediate size N 
+    return the MoE intermediate size N
     """
     marlin_tile_size = 16
     return w2_packed.size(1) * marlin_tile_size
 
 
-def marlin_make_workspace(output_size_per_partition: int,
-                          device: torch.device) -> torch.Tensor:
-    max_workspace_size = (output_size_per_partition //
-                          GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
+def marlin_make_workspace(
+    output_size_per_partition: int, device: torch.device
+) -> torch.Tensor:
+    max_workspace_size = (
+        output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
+    ) * GPTQ_MARLIN_MAX_PARALLEL
 
-    return torch.zeros(max_workspace_size,
-                       dtype=torch.int,
-                       device=device,
-                       requires_grad=False)
+    return torch.zeros(
+        max_workspace_size, dtype=torch.int, device=device, requires_grad=False
+    )
 
 
-def marlin_make_workspace_new(device: torch.device,
-                              max_blocks_per_sm: int = 1) -> torch.Tensor:
+def marlin_make_workspace_new(
+    device: torch.device, max_blocks_per_sm: int = 1
+) -> torch.Tensor:
     # In the new marlin kernel, we use the num of threadblocks as workspace
     # size. The num of threadblocks is sms_count * max_blocks_per_sm.
     sms = torch.cuda.get_device_properties(device).multi_processor_count
-    return torch.zeros(sms * max_blocks_per_sm,
-                       dtype=torch.int,
-                       device=device,
-                       requires_grad=False)
+    return torch.zeros(
+        sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
+    )
 
 
 def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
     return (not act_order) or (act_order and not is_row_parallel)
 
 
-def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
-                                      is_row_parallel: bool) -> bool:
+def marlin_repeat_scales_on_all_ranks(
+    act_order: bool, group_size: int, is_row_parallel: bool
+) -> bool:
     # Need to repeat scales on every rank if act_ordering or
     # channelwise and RowParallelLinear
     is_channelwise = group_size == -1
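marlin_make_workspace sizes the buffer as one int per (n-tile, parallel slot) pair. A worked example under assumed constants (the real values ship alongside the Marlin kernels):

    GPTQ_MARLIN_MIN_THREAD_N = 64   # assumption for illustration
    GPTQ_MARLIN_MAX_PARALLEL = 16   # assumption for illustration

    output_size_per_partition = 4096
    max_workspace_size = (
        output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    ) * GPTQ_MARLIN_MAX_PARALLEL
    assert max_workspace_size == 1024  # 64 n-tiles x 16 parallel reduction slots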
@@ -232,17 +260,18 @@ def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
 
 
 def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
-    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
-                              requires_grad=False)
+    return torch.nn.Parameter(
+        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
+    )
 
 
 def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
-    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
-                              requires_grad=False)
+    return torch.nn.Parameter(
+        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
+    )
 
 
-def marlin_sort_g_idx(
-        g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
     return g_idx[g_idx_sort_indices], g_idx_sort_indices
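marlin_sort_g_idx pairs the sorted activation-order group indices with the permutation that sorts them, so the kernel can reorder rows accordingly. A standalone demo:

    import torch

    g_idx = torch.tensor([2, 0, 1, 0], dtype=torch.int)
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)

    print(g_idx[g_idx_sort_indices])  # tensor([0, 0, 1, 2], dtype=torch.int32)
    print(g_idx_sort_indices)         # e.g. tensor([1, 3, 2, 0]); tie order may vary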
@@ -253,14 +282,13 @@ def get_scale_perms():
         scale_perm.extend([i + 8 * j for j in range(8)])
     scale_perm_single: list[int] = []
     for i in range(4):
-        scale_perm_single.extend(
-            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
     return scale_perm, scale_perm_single
 
 
-def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
-                          group_size: int) -> torch.Tensor:
-
+def marlin_permute_scales(
+    s: torch.Tensor, size_k: int, size_n: int, group_size: int
+) -> torch.Tensor:
     scale_perm, scale_perm_single = get_scale_perms()
     if group_size < size_k and group_size != -1:
         s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
@@ -296,8 +324,9 @@ def marlin_moe_permute_scales(
     return output
 
 
-def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
-                       num_bits: int) -> torch.Tensor:
+def marlin_zero_points(
+    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
     # Permute zero-points in a similar way to scales, but do not use the
     # "single" permutation, since zero-points are applied on every MMA
     scale_perm, _ = get_scale_perms()
@@ -318,8 +347,9 @@ def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
     return zp
 
 
-def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
-                              size_n: int, num_bits: int) -> torch.Tensor:
+def awq_to_marlin_zero_points(
+    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
     # AWQ zero-points are quantized and packed on the column dim.
     # In addition, the values are permuted based on dequantizer.
     # Here we undo both of these, and then apply marlin permutation
@@ -341,8 +371,9 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
     return marlin_zp
 
 
-def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
-                                  size_n: int, num_bits: int):
+def moe_awq_to_marlin_zero_points(
+    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+):
     num_experts = q_zp_packed.shape[0]
     output = torch.empty(
         (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
@@ -350,8 +381,7 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
         dtype=q_zp_packed.dtype,
     )
     for e in range(num_experts):
-        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
-                                              num_bits)
+        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
     return output
@@ -363,7 +393,8 @@ def maybe_warn_marlin_atomic_add(device, dtype):
         logger.info_once(
             "You are running Marlin kernel with bf16 on GPUs before SM90. "
             "You can consider change to fp16 to achieve better performance "
-            "if possible.")
+            "if possible."
+        )
 
 
 def maybe_warn_marlin_atomic_add_env():
@@ -375,12 +406,13 @@ def maybe_warn_marlin_atomic_add_env():
             "Marlin kernel can achieve better performance for small size_n "
             "with experimental use_atomic_add feature. "
             "You can consider set environment variable "
-            "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
+            "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible."
+        )
 
 
-def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
-                                 dtype: torch.dtype) -> bool:
-
+def should_use_atomic_add_reduce(
+    m: int, n: int, k: int, device: torch.device, dtype: torch.dtype
+) -> bool:
     # the performance of atomicAdd is better than global reduce
     # only when m*n is small and k is large
     if n >= 2048 or k < 2048 or device.type != "cuda":
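The guard visible here encodes the comment directly: atomicAdd only pays off for narrow outputs (n < 2048) with a deep reduction dimension (k >= 2048), and only on CUDA. A sketch of just this visible portion (the rest of the function, truncated by the hunk boundary, applies further checks):

    import torch

    def atomic_add_candidate(n: int, k: int, device: torch.device) -> bool:
        # Mirrors only the early-exit shown above.
        if n >= 2048 or k < 2048 or device.type != "cuda":
            return False
        return True

    assert atomic_add_candidate(1024, 4096, torch.device("cuda"))
    assert not atomic_add_candidate(4096, 4096, torch.device("cuda"))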
@@ -402,88 +434,98 @@ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
 
 
 def apply_gptq_marlin_linear(
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        weight_zp: torch.Tensor,
-        g_idx: torch.Tensor,
-        g_idx_sort_indices: torch.Tensor,
-        workspace: torch.Tensor,
-        wtype: ScalarType,
-        output_size_per_partition: int,
-        input_size_per_partition: int,
-        is_k_full: bool,
-        bias: Optional[torch.Tensor] = None,
-        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zp: torch.Tensor,
+    g_idx: torch.Tensor,
+    g_idx_sort_indices: torch.Tensor,
+    workspace: torch.Tensor,
+    wtype: ScalarType,
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    is_k_full: bool,
+    bias: Optional[torch.Tensor] = None,
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
     reshaped_x = input.reshape(-1, input.shape[-1])
-    out_shape = input.shape[:-1] + (output_size_per_partition, )
+    out_shape = input.shape[:-1] + (output_size_per_partition,)
 
-    use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
-                                                  n=output_size_per_partition,
-                                                  k=reshaped_x.size(1),
-                                                  device=input.device,
-                                                  dtype=input.dtype)
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0),
+        n=output_size_per_partition,
+        k=reshaped_x.size(1),
+        device=input.device,
+        dtype=input.dtype,
+    )
 
-    output = ops.gptq_marlin_gemm(reshaped_x,
-                                  None,
-                                  weight,
-                                  bias,
-                                  weight_scale,
-                                  None,
-                                  weight_zp,
-                                  g_idx,
-                                  g_idx_sort_indices,
-                                  workspace,
-                                  wtype,
-                                  size_m=reshaped_x.shape[0],
-                                  size_n=output_size_per_partition,
-                                  size_k=input_size_per_partition,
-                                  is_k_full=is_k_full,
-                                  use_atomic_add=use_atomic_add,
-                                  use_fp32_reduce=use_fp32_reduce,
-                                  is_zp_float=False)
+    output = ops.gptq_marlin_gemm(
+        reshaped_x,
+        None,
+        weight,
+        bias,
+        weight_scale,
+        None,
+        weight_zp,
+        g_idx,
+        g_idx_sort_indices,
+        workspace,
+        wtype,
+        size_m=reshaped_x.shape[0],
+        size_n=output_size_per_partition,
+        size_k=input_size_per_partition,
+        is_k_full=is_k_full,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=False,
+    )
 
     return output.reshape(out_shape)
 
 
 def apply_awq_marlin_linear(
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        weight_zp: torch.Tensor,
-        g_idx: torch.Tensor,
-        g_idx_sort_indices: torch.Tensor,
-        workspace: torch.Tensor,
-        quant_type: ScalarType,
-        output_size_per_partition: int,
-        input_size_per_partition: int,
-        bias: Optional[torch.Tensor] = None,
-        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zp: torch.Tensor,
+    g_idx: torch.Tensor,
+    g_idx_sort_indices: torch.Tensor,
+    workspace: torch.Tensor,
+    quant_type: ScalarType,
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    bias: Optional[torch.Tensor] = None,
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
     reshaped_x = input.reshape(-1, input.shape[-1])
-    out_shape = input.shape[:-1] + (output_size_per_partition, )
+    out_shape = input.shape[:-1] + (output_size_per_partition,)
 
-    use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
-                                                  n=output_size_per_partition,
-                                                  k=reshaped_x.size(1),
-                                                  device=input.device,
-                                                  dtype=input.dtype)
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0),
+        n=output_size_per_partition,
+        k=reshaped_x.size(1),
+        device=input.device,
+        dtype=input.dtype,
+    )
 
-    output = ops.gptq_marlin_gemm(reshaped_x,
-                                  None,
-                                  weight,
-                                  bias,
-                                  weight_scale,
-                                  None,
-                                  weight_zp,
-                                  g_idx,
-                                  g_idx_sort_indices,
-                                  workspace,
-                                  quant_type,
-                                  size_m=reshaped_x.shape[0],
-                                  size_n=output_size_per_partition,
-                                  size_k=input_size_per_partition,
-                                  use_atomic_add=use_atomic_add,
-                                  use_fp32_reduce=use_fp32_reduce,
-                                  is_zp_float=False)
+    output = ops.gptq_marlin_gemm(
+        reshaped_x,
+        None,
+        weight,
+        bias,
+        weight_scale,
+        None,
+        weight_zp,
+        g_idx,
+        g_idx_sort_indices,
+        workspace,
+        quant_type,
+        size_m=reshaped_x.shape[0],
+        size_n=output_size_per_partition,
+        size_k=input_size_per_partition,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=False,
+    )
 
     return output.reshape(out_shape)
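Both wrappers share the same shape choreography: collapse every leading dimension into the GEMM m, run the quantized matmul, then restore the prefix with output_size_per_partition as the new last dimension. A shape-only sketch with a stand-in for ops.gptq_marlin_gemm:

    import torch

    batch, seq, k, n = 2, 128, 4096, 2048  # n = output_size_per_partition

    x = torch.randn(batch, seq, k)
    reshaped_x = x.reshape(-1, x.shape[-1])  # (256, 4096): m x k for the GEMM
    out_shape = x.shape[:-1] + (n,)          # (2, 128, 2048)

    gemm_out = torch.empty(reshaped_x.shape[0], n)  # what the kernel returns: m x n
    assert gemm_out.reshape(out_shape).shape == (batch, seq, n)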