[CI Failure] Fix NomicBert max_model_len validation (#31662)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
@@ -113,8 +113,8 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
 
 class NomicBertModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        config = model_config.hf_config
 
         assert config.__class__.__name__ == "NomicBertConfig"
         assert config.activation_function in ["swiglu", "gelu"]
@@ -137,6 +137,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
         config.intermediate_size = config.n_inner
         config.hidden_size = config.n_embd
         config.num_hidden_layers = config.n_layer
+        model_config.model_arch_config.hidden_size = config.hidden_size
+        model_config.model_arch_config.total_num_hidden_layers = (
+            config.num_hidden_layers
+        )
 
         head_dim = config.hidden_size // config.num_attention_heads
         max_trained_positions = getattr(config, "max_trained_positions", 2048)
@@ -153,42 +157,43 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
         # The context extension uses vllm style rope_theta and rope_parameters.
         # See #17785 #18755
         if (
-            not vllm_config.model_config.hf_overrides
-            and vllm_config.model_config.original_max_model_len is None
+            not model_config.hf_overrides
+            and model_config.original_max_model_len is None
         ):
             # Default
             # Reset max_model_len to max_trained_positions.
             # nomic-embed-text-v2-moe the length is set to 512
             # by sentence_bert_config.json.
-            max_model_len_before = vllm_config.model_config.max_model_len
-            max_model_len = min(
-                vllm_config.model_config.max_model_len, max_trained_positions
-            )
+            max_model_len_before = model_config.max_model_len
+            max_model_len = min(model_config.max_model_len, max_trained_positions)
+
+            model_config.max_model_len = model_config.get_and_verify_max_len(
+                max_model_len
+            )
 
-            vllm_config.recalculate_max_model_len(max_model_len)
-            logger.warning(
-                "Nomic context extension is disabled. "
-                "Changing max_model_len from %s to %s. "
-                "To enable context extension, see: "
-                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
-                max_model_len_before,
-                vllm_config.model_config.max_model_len,
-            )
+            if model_config.max_model_len != max_model_len_before:
+                logger.warning(
+                    "Nomic context extension is disabled. "
+                    "Changing max_model_len from %s to %s. "
+                    "To enable context extension, see: "
+                    "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
+                    max_model_len_before,
+                    model_config.max_model_len,
+                )
         else:
             # We need to re-verify max_model_len to avoid lengths
             # greater than position_embedding.
-            model_config = vllm_config.model_config
             hf_text_config = model_config.hf_text_config
 
             if isinstance(model_config.hf_overrides, dict):
                 # hf_overrides_kw
                 max_model_len = model_config.hf_overrides.get(
-                    "max_model_len", vllm_config.model_config.max_model_len
+                    "max_model_len", model_config.max_model_len
                 )
             else:
                 # hf_overrides_fn
                 # This might be overridden by sentence_bert_config.json.
-                max_model_len = vllm_config.model_config.max_model_len
+                max_model_len = model_config.max_model_len
 
             # reset hf_text_config for recalculate_max_model_len.
             if hasattr(hf_text_config, "max_model_len"):
@@ -196,13 +201,21 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
             hf_text_config.max_position_embeddings = max_trained_positions
             hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]
 
+            # Update the cached derived_max_model_len to enforce the limit
+            model_config.model_arch_config.derived_max_model_len_and_key = (
+                float(max_trained_positions),
+                "max_position_embeddings",
+            )
+
             # The priority of sentence_bert_config.json is higher
             # than max_position_embeddings
             encoder_config = deepcopy(model_config.encoder_config)
             encoder_config.pop("max_seq_length", None)
             model_config.encoder_config = encoder_config
 
-            vllm_config.recalculate_max_model_len(max_model_len)
+            model_config.max_model_len = model_config.get_and_verify_max_len(
+                max_model_len
+            )
 
 
 class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
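
For reference, below is a minimal, self-contained sketch of the default-path behavior this patch establishes: clamp max_model_len to max_trained_positions, re-verify the result, and warn only when the value actually changed (the old code warned unconditionally). ModelConfigStub and clamp_to_trained_positions are hypothetical stand-ins for illustration only, not vLLM's real classes or API.

# Sketch only; `ModelConfigStub` and `clamp_to_trained_positions`
# are hypothetical names, not part of vLLM.
from dataclasses import dataclass
import logging

logger = logging.getLogger(__name__)


@dataclass
class ModelConfigStub:
    max_model_len: int

    def get_and_verify_max_len(self, max_model_len: int) -> int:
        # Stand-in for vLLM's re-verification step; here it simply
        # accepts the candidate value unchanged.
        return max_model_len


def clamp_to_trained_positions(
    model_config: ModelConfigStub, max_trained_positions: int = 2048
) -> None:
    before = model_config.max_model_len
    # Never exceed the length the model was trained on.
    candidate = min(model_config.max_model_len, max_trained_positions)
    model_config.max_model_len = model_config.get_and_verify_max_len(candidate)
    # Mirror the patched guard: warn only if the length actually changed.
    if model_config.max_model_len != before:
        logger.warning(
            "Changing max_model_len from %s to %s.",
            before,
            model_config.max_model_len,
        )


clamp_to_trained_positions(ModelConfigStub(max_model_len=8192))  # warns: 8192 -> 2048
clamp_to_trained_positions(ModelConfigStub(max_model_len=512))   # silent: 512 <= 2048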