Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -10,21 +10,18 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class VerifyAndUpdateConfig:
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
raise NotImplementedError
class Gemma3TextModelConfig:
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
hf_config = vllm_config.model_config.hf_config
@@ -32,7 +29,6 @@ class Gemma3TextModelConfig:
class GteNewModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
@@ -48,12 +44,11 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
"rope_scaling": getattr(config, "rope_scaling", None),
}
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
pooler_config = vllm_config.model_config.pooler_config
@@ -62,7 +57,6 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
@@ -76,29 +70,27 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
"rope_scaling": getattr(config, "rope_scaling", None),
}
class NomicBertModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function in ["swiglu", "gelu"]
config.position_embedding_type = getattr(config,
"position_embedding_type",
"rope")
config.position_embedding_type = getattr(
config, "position_embedding_type", "rope"
)
if config.activation_function == "swiglu":
config.hidden_act = "silu"
else:
config.hidden_act = config.activation_function
assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
config.qkv_proj_bias)
assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
config.bias = config.qkv_proj_bias
assert config.rotary_emb_scale_base is None
@@ -117,7 +109,7 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
"rotary_dim": rotary_emb_dim,
"max_position": max_trained_positions,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
"rope_scaling": getattr(config, "rope_scaling", None),
}
# we ignore config.rotary_scaling_factor so that for datasets shorter
@@ -125,15 +117,18 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
# with SentenceTransformer.
# The context extension uses vllm style rope_theta and rope_scaling.
# See #17785 #18755
if (not vllm_config.model_config.hf_overrides
and vllm_config.model_config.original_max_model_len is None):
if (
not vllm_config.model_config.hf_overrides
and vllm_config.model_config.original_max_model_len is None
):
# Default
# Reset max_model_len to max_trained_positions.
# nomic-embed-text-v2-moe the length is set to 512
# by sentence_bert_config.json.
max_model_len_before = vllm_config.model_config.max_model_len
max_model_len = min(vllm_config.model_config.max_model_len,
max_trained_positions)
max_model_len = min(
vllm_config.model_config.max_model_len, max_trained_positions
)
vllm_config.recalculate_max_model_len(max_model_len)
logger.warning(
@@ -141,7 +136,9 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
"Changing max_model_len from %s to %s. "
"To enable context extension, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
max_model_len_before, vllm_config.model_config.max_model_len)
max_model_len_before,
vllm_config.model_config.max_model_len,
)
else:
# We need to re-verify max_model_len to avoid lengths
# greater than position_embedding.
@@ -151,7 +148,8 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
if isinstance(model_config.hf_overrides, dict):
# hf_overrides_kw
max_model_len = model_config.hf_overrides.get(
"max_model_len", vllm_config.model_config.max_model_len)
"max_model_len", vllm_config.model_config.max_model_len
)
else:
# hf_overrides_fn
# This might be overridden by sentence_bert_config.json.
@@ -173,7 +171,6 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
pooler_config = vllm_config.model_config.pooler_config
@@ -183,7 +180,6 @@ class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
pooler_config = vllm_config.model_config.pooler_config
@@ -193,27 +189,26 @@ class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
is_original_qwen3_reranker = getattr(config,
"is_original_qwen3_reranker",
False)
is_original_qwen3_reranker = getattr(
config, "is_original_qwen3_reranker", False
)
if not is_original_qwen3_reranker:
return
tokens = getattr(config, "classifier_from_token", None)
assert tokens is not None and len(tokens) == 2, \
("Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
assert tokens is not None and len(tokens) == 2, (
"Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py"
)
vllm_config.model_config.hf_config.method = "from_2_way_softmax"
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
@@ -224,7 +219,6 @@ class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
@@ -240,12 +234,11 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
"rope_scaling": getattr(config, "rope_scaling", None),
}
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
structured_outputs_config = vllm_config.structured_outputs_config
@@ -268,12 +261,11 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
cuda_graph_sizes += [i for i in range(256, 993, 16)]
scheduler_config.cuda_graph_sizes = cuda_graph_sizes
logger.info(
"Overriding max cuda graph capture size to "
"%d for performance.", 992)
"Overriding max cuda graph capture size to %d for performance.", 992
)
class MambaModelConfig(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
@@ -305,22 +297,26 @@ class MambaModelConfig(VerifyAndUpdateConfig):
]
if cache_config.enable_prefix_caching:
if model_config.architecture in MAMBA2_MODELS:
logger.info("Warning: Prefix caching is currently enabled. "
"Its support for Mamba2 layers is experimental. "
"Please report any issues you may observe.")
logger.info(
"Warning: Prefix caching is currently enabled. "
"Its support for Mamba2 layers is experimental. "
"Please report any issues you may observe."
)
else:
logger.info("Hybrid or mamba-based model detected without "
"support for prefix caching: disabling.")
logger.info(
"Hybrid or mamba-based model detected without "
"support for prefix caching: disabling."
)
cache_config.enable_prefix_caching = False
# TODO(tdoublep): remove once cascade attention is supported
logger.info("Disabling cascade attention since it is not supported "
"for hybrid models.")
logger.info(
"Disabling cascade attention since it is not supported for hybrid models."
)
model_config.disable_cascade_attn = True
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
@@ -354,7 +350,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype).page_size_bytes
dtype=kv_cache_dtype,
).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
@@ -385,10 +382,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# easily by changing the way we layout chunks in the
# mamba2 kernels.
chunk_size = model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = \
cdiv(mamba_page_size, attn_page_size_1_token)
attn_block_size = chunk_size * \
cdiv(attn_tokens_per_mamba_state, chunk_size)
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
# Without prefix caching, select minimum valid attention block size
@@ -398,23 +393,21 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = 16 * cdiv(mamba_page_size,
16 * attn_page_size_1_token)
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
# override attention block size if either (a) the
# user has not set it or (b) the user has set it
# too small.
if (cache_config.block_size is None
or cache_config.block_size < attn_block_size):
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size
logger.info(
"Setting attention block size to %d tokens "
"to ensure that attention page size is >= mamba page size.",
attn_block_size)
attn_block_size,
)
# compute new attention page size
attn_page_size = \
cache_config.block_size * attn_page_size_1_token
attn_page_size = cache_config.block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size
@@ -423,19 +416,23 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
return
# pad mamba page size to exactly match attention
if (cache_config.mamba_page_size_padded is None
or cache_config.mamba_page_size_padded != attn_page_size):
cache_config.mamba_page_size_padded = (attn_page_size)
mamba_padding_pct = 100 * (attn_page_size -
mamba_page_size) / mamba_page_size
if (
cache_config.mamba_page_size_padded is None
or cache_config.mamba_page_size_padded != attn_page_size
):
cache_config.mamba_page_size_padded = attn_page_size
mamba_padding_pct = (
100 * (attn_page_size - mamba_page_size) / mamba_page_size
)
logger.info(
"Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are "
"exactly equal.", mamba_padding_pct)
"exactly equal.",
mamba_padding_pct,
)
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
@@ -450,8 +447,9 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
# For DeepSeekV3.2, we use a custom fp8 format as default (i.e.
# "auto")
cache_config = vllm_config.cache_config
if cache_config.cache_dtype == "auto" or \
cache_config.cache_dtype.startswith("fp8"):
if cache_config.cache_dtype == "auto" or cache_config.cache_dtype.startswith(
"fp8"
):
cache_config.cache_dtype = "fp8_ds_mla"
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
if cache_config.cache_dtype == "bfloat16":