Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -10,21 +10,18 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class VerifyAndUpdateConfig:
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Gemma3TextModelConfig:
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
@@ -32,7 +29,6 @@ class Gemma3TextModelConfig:
|
||||
|
||||
|
||||
class GteNewModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -48,12 +44,11 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_position_embeddings,
|
||||
"base": config.rope_theta,
|
||||
"rope_scaling": getattr(config, "rope_scaling", None)
|
||||
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||
}
|
||||
|
||||
|
||||
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
@@ -62,7 +57,6 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
|
||||
|
||||
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -76,29 +70,27 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_position_embeddings,
|
||||
"base": getattr(config, "rope_theta", config.rotary_emb_base),
|
||||
"rope_scaling": getattr(config, "rope_scaling", None)
|
||||
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||
}
|
||||
|
||||
|
||||
class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "NomicBertConfig"
|
||||
assert config.activation_function in ["swiglu", "gelu"]
|
||||
config.position_embedding_type = getattr(config,
|
||||
"position_embedding_type",
|
||||
"rope")
|
||||
config.position_embedding_type = getattr(
|
||||
config, "position_embedding_type", "rope"
|
||||
)
|
||||
|
||||
if config.activation_function == "swiglu":
|
||||
config.hidden_act = "silu"
|
||||
else:
|
||||
config.hidden_act = config.activation_function
|
||||
|
||||
assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
|
||||
config.qkv_proj_bias)
|
||||
assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
|
||||
config.bias = config.qkv_proj_bias
|
||||
|
||||
assert config.rotary_emb_scale_base is None
|
||||
@@ -117,7 +109,7 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
"rotary_dim": rotary_emb_dim,
|
||||
"max_position": max_trained_positions,
|
||||
"base": getattr(config, "rope_theta", config.rotary_emb_base),
|
||||
"rope_scaling": getattr(config, "rope_scaling", None)
|
||||
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||
}
|
||||
|
||||
# we ignore config.rotary_scaling_factor so that for datasets shorter
|
||||
@@ -125,15 +117,18 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
# with SentenceTransformer.
|
||||
# The context extension uses vllm style rope_theta and rope_scaling.
|
||||
# See #17785 #18755
|
||||
if (not vllm_config.model_config.hf_overrides
|
||||
and vllm_config.model_config.original_max_model_len is None):
|
||||
if (
|
||||
not vllm_config.model_config.hf_overrides
|
||||
and vllm_config.model_config.original_max_model_len is None
|
||||
):
|
||||
# Default
|
||||
# Reset max_model_len to max_trained_positions.
|
||||
# nomic-embed-text-v2-moe the length is set to 512
|
||||
# by sentence_bert_config.json.
|
||||
max_model_len_before = vllm_config.model_config.max_model_len
|
||||
max_model_len = min(vllm_config.model_config.max_model_len,
|
||||
max_trained_positions)
|
||||
max_model_len = min(
|
||||
vllm_config.model_config.max_model_len, max_trained_positions
|
||||
)
|
||||
|
||||
vllm_config.recalculate_max_model_len(max_model_len)
|
||||
logger.warning(
|
||||
@@ -141,7 +136,9 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
"Changing max_model_len from %s to %s. "
|
||||
"To enable context extension, see: "
|
||||
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
|
||||
max_model_len_before, vllm_config.model_config.max_model_len)
|
||||
max_model_len_before,
|
||||
vllm_config.model_config.max_model_len,
|
||||
)
|
||||
else:
|
||||
# We need to re-verify max_model_len to avoid lengths
|
||||
# greater than position_embedding.
|
||||
@@ -151,7 +148,8 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
if isinstance(model_config.hf_overrides, dict):
|
||||
# hf_overrides_kw
|
||||
max_model_len = model_config.hf_overrides.get(
|
||||
"max_model_len", vllm_config.model_config.max_model_len)
|
||||
"max_model_len", vllm_config.model_config.max_model_len
|
||||
)
|
||||
else:
|
||||
# hf_overrides_fn
|
||||
# This might be overridden by sentence_bert_config.json.
|
||||
@@ -173,7 +171,6 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
|
||||
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
@@ -183,7 +180,6 @@ class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
|
||||
class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
@@ -193,27 +189,26 @@ class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
|
||||
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
is_original_qwen3_reranker = getattr(config,
|
||||
"is_original_qwen3_reranker",
|
||||
False)
|
||||
is_original_qwen3_reranker = getattr(
|
||||
config, "is_original_qwen3_reranker", False
|
||||
)
|
||||
|
||||
if not is_original_qwen3_reranker:
|
||||
return
|
||||
|
||||
tokens = getattr(config, "classifier_from_token", None)
|
||||
assert tokens is not None and len(tokens) == 2, \
|
||||
("Try loading the original Qwen3 Reranker?, see: "
|
||||
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
|
||||
assert tokens is not None and len(tokens) == 2, (
|
||||
"Try loading the original Qwen3 Reranker?, see: "
|
||||
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py"
|
||||
)
|
||||
vllm_config.model_config.hf_config.method = "from_2_way_softmax"
|
||||
|
||||
|
||||
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -224,7 +219,6 @@ class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
|
||||
|
||||
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -240,12 +234,11 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_position_embeddings,
|
||||
"base": config.rope_theta,
|
||||
"rope_scaling": getattr(config, "rope_scaling", None)
|
||||
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||
}
|
||||
|
||||
|
||||
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
structured_outputs_config = vllm_config.structured_outputs_config
|
||||
@@ -268,12 +261,11 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
||||
cuda_graph_sizes += [i for i in range(256, 993, 16)]
|
||||
scheduler_config.cuda_graph_sizes = cuda_graph_sizes
|
||||
logger.info(
|
||||
"Overriding max cuda graph capture size to "
|
||||
"%d for performance.", 992)
|
||||
"Overriding max cuda graph capture size to %d for performance.", 992
|
||||
)
|
||||
|
||||
|
||||
class MambaModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
@@ -305,22 +297,26 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
||||
]
|
||||
if cache_config.enable_prefix_caching:
|
||||
if model_config.architecture in MAMBA2_MODELS:
|
||||
logger.info("Warning: Prefix caching is currently enabled. "
|
||||
"Its support for Mamba2 layers is experimental. "
|
||||
"Please report any issues you may observe.")
|
||||
logger.info(
|
||||
"Warning: Prefix caching is currently enabled. "
|
||||
"Its support for Mamba2 layers is experimental. "
|
||||
"Please report any issues you may observe."
|
||||
)
|
||||
else:
|
||||
logger.info("Hybrid or mamba-based model detected without "
|
||||
"support for prefix caching: disabling.")
|
||||
logger.info(
|
||||
"Hybrid or mamba-based model detected without "
|
||||
"support for prefix caching: disabling."
|
||||
)
|
||||
cache_config.enable_prefix_caching = False
|
||||
|
||||
# TODO(tdoublep): remove once cascade attention is supported
|
||||
logger.info("Disabling cascade attention since it is not supported "
|
||||
"for hybrid models.")
|
||||
logger.info(
|
||||
"Disabling cascade attention since it is not supported for hybrid models."
|
||||
)
|
||||
model_config.disable_cascade_attn = True
|
||||
|
||||
|
||||
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
@@ -354,7 +350,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype).page_size_bytes
|
||||
dtype=kv_cache_dtype,
|
||||
).page_size_bytes
|
||||
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||
model_config.architecture,
|
||||
@@ -385,10 +382,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
# easily by changing the way we layout chunks in the
|
||||
# mamba2 kernels.
|
||||
chunk_size = model_config.get_mamba_chunk_size()
|
||||
attn_tokens_per_mamba_state = \
|
||||
cdiv(mamba_page_size, attn_page_size_1_token)
|
||||
attn_block_size = chunk_size * \
|
||||
cdiv(attn_tokens_per_mamba_state, chunk_size)
|
||||
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
|
||||
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
|
||||
cache_config.mamba_block_size = attn_block_size
|
||||
else:
|
||||
# Without prefix caching, select minimum valid attention block size
|
||||
@@ -398,23 +393,21 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
# block size to multiple of 16, so let's suggest a value
|
||||
# that would work (note: FA is currently not compatible
|
||||
# with mamba layers, use FlashInfer instead).
|
||||
attn_block_size = 16 * cdiv(mamba_page_size,
|
||||
16 * attn_page_size_1_token)
|
||||
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
|
||||
|
||||
# override attention block size if either (a) the
|
||||
# user has not set it or (b) the user has set it
|
||||
# too small.
|
||||
if (cache_config.block_size is None
|
||||
or cache_config.block_size < attn_block_size):
|
||||
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
|
||||
cache_config.block_size = attn_block_size
|
||||
logger.info(
|
||||
"Setting attention block size to %d tokens "
|
||||
"to ensure that attention page size is >= mamba page size.",
|
||||
attn_block_size)
|
||||
attn_block_size,
|
||||
)
|
||||
|
||||
# compute new attention page size
|
||||
attn_page_size = \
|
||||
cache_config.block_size * attn_page_size_1_token
|
||||
attn_page_size = cache_config.block_size * attn_page_size_1_token
|
||||
|
||||
assert attn_page_size >= mamba_page_size
|
||||
|
||||
@@ -423,19 +416,23 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
return
|
||||
|
||||
# pad mamba page size to exactly match attention
|
||||
if (cache_config.mamba_page_size_padded is None
|
||||
or cache_config.mamba_page_size_padded != attn_page_size):
|
||||
cache_config.mamba_page_size_padded = (attn_page_size)
|
||||
mamba_padding_pct = 100 * (attn_page_size -
|
||||
mamba_page_size) / mamba_page_size
|
||||
if (
|
||||
cache_config.mamba_page_size_padded is None
|
||||
or cache_config.mamba_page_size_padded != attn_page_size
|
||||
):
|
||||
cache_config.mamba_page_size_padded = attn_page_size
|
||||
mamba_padding_pct = (
|
||||
100 * (attn_page_size - mamba_page_size) / mamba_page_size
|
||||
)
|
||||
logger.info(
|
||||
"Padding mamba page size by %.2f%% to ensure "
|
||||
"that mamba page size and attention page size are "
|
||||
"exactly equal.", mamba_padding_pct)
|
||||
"exactly equal.",
|
||||
mamba_padding_pct,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
|
||||
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
@@ -450,8 +447,9 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
|
||||
# For DeepSeekV3.2, we use a custom fp8 format as default (i.e.
|
||||
# "auto")
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config.cache_dtype == "auto" or \
|
||||
cache_config.cache_dtype.startswith("fp8"):
|
||||
if cache_config.cache_dtype == "auto" or cache_config.cache_dtype.startswith(
|
||||
"fp8"
|
||||
):
|
||||
cache_config.cache_dtype = "fp8_ds_mla"
|
||||
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
|
||||
if cache_config.cache_dtype == "bfloat16":
|
||||
|
||||
Reference in New Issue
Block a user