[Model] Support is_causal HF config field for Qwen2 model (#10621)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -27,7 +27,7 @@ from vllm.transformers_utils.config import (
|
||||
get_hf_text_config, get_pooling_config,
|
||||
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
|
||||
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
|
||||
identity, print_warning_once, resolve_obj_by_qualname)
|
||||
print_warning_once, resolve_obj_by_qualname)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
@@ -183,7 +183,7 @@ class ModelConfig:
|
||||
hf_overrides_fn = hf_overrides
|
||||
else:
|
||||
hf_overrides_kw = hf_overrides
|
||||
hf_overrides_fn = identity
|
||||
hf_overrides_fn = None
|
||||
|
||||
if rope_scaling is not None:
|
||||
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
|
||||
@@ -212,8 +212,15 @@ class ModelConfig:
|
||||
self.skip_tokenizer_init = skip_tokenizer_init
|
||||
|
||||
hf_config = get_config(self.model, trust_remote_code, revision,
|
||||
code_revision, config_format, **hf_overrides_kw)
|
||||
hf_config = hf_overrides_fn(hf_config)
|
||||
code_revision, config_format)
|
||||
|
||||
if hf_overrides_kw:
|
||||
logger.info("Overriding HF config with %s", hf_overrides_kw)
|
||||
hf_config.update(hf_overrides_kw)
|
||||
if hf_overrides_fn:
|
||||
logger.info("Overriding HF config with %s", hf_overrides_fn)
|
||||
hf_config = hf_overrides_fn(hf_config)
|
||||
|
||||
self.hf_config = hf_config
|
||||
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
|
||||
Reference in New Issue
Block a user