[Model] Support is_causal HF config field for Qwen2 model (#10621)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-11-25 17:51:20 +08:00
committed by GitHub
parent 05d1f8c9c6
commit ed46f14321
5 changed files with 51 additions and 13 deletions

View File

@@ -27,7 +27,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
identity, print_warning_once, resolve_obj_by_qualname)
print_warning_once, resolve_obj_by_qualname)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -183,7 +183,7 @@ class ModelConfig:
hf_overrides_fn = hf_overrides
else:
hf_overrides_kw = hf_overrides
hf_overrides_fn = identity
hf_overrides_fn = None
if rope_scaling is not None:
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
@@ -212,8 +212,15 @@ class ModelConfig:
self.skip_tokenizer_init = skip_tokenizer_init
hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format, **hf_overrides_kw)
hf_config = hf_overrides_fn(hf_config)
code_revision, config_format)
if hf_overrides_kw:
logger.info("Overriding HF config with %s", hf_overrides_kw)
hf_config.update(hf_overrides_kw)
if hf_overrides_fn:
logger.info("Overriding HF config with %s", hf_overrides_fn)
hf_config = hf_overrides_fn(hf_config)
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)