[Model] Systematic support for fp32 head, pooling models part (#23810)

Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi
2025-09-09 22:29:50 +08:00
committed by GitHub
parent a55cf41a09
commit 19332c0479
14 changed files with 166 additions and 61 deletions


@@ -745,7 +745,7 @@ class ModelConfig:
self.pooler_config = self._init_pooler_config()
-        self.dtype = _get_and_verify_dtype(
+        self.dtype: torch.dtype = _get_and_verify_dtype(
self.model,
self.hf_config,
self.dtype,
@@ -1751,6 +1751,32 @@ class ModelConfig:
# `llm as reranker` models default to not using pad_token.
return getattr(self.hf_config, "use_pad_token", True)

@property
def head_dtype(self) -> torch.dtype:
    """
    "head" refers to the last Linear layer(s) of an LLM,
    such as the lm_head in a generation model,
    or the score or classifier in a classification model.

    The default head_dtype is based on runner_type.\n
    - Pooling models default to using an fp32 head;
    you can use --hf-overrides '{"head_dtype": "model"}' to disable it.\n
    - Generation models default to not using an fp32 head;
    you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
    """
    head_dtype = _get_head_dtype(config=self.hf_config,
                                 dtype=self.dtype,
                                 runner_type=self.runner_type)
    if head_dtype not in current_platform.supported_dtypes:
        logger.warning_once(
            "The current platform does not support [%s] head dtype, "
            "fallback to model dtype [%s].", head_dtype, self.dtype)
        return self.dtype

    logger.debug_once("head dtype: %s", head_dtype)
    return head_dtype

def get_and_verify_max_len(self, max_model_len: int):
# Consider max_model_len in tokenizer_config only when
# pooling models use absolute position_embedding.
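
Aside: the override mentioned in the docstring can be passed on the CLI (--hf-overrides '{"head_dtype": "model"}') or from Python. The snippet below is a minimal sketch, assuming the offline LLM entrypoint with its hf_overrides argument; the model name is illustrative, not part of this commit.

# Sketch only: model name is illustrative; hf_overrides carries the head_dtype override.
from vllm import LLM

# Pooling models now default to an fp32 head; override it to follow the model dtype.
llm = LLM(
    model="BAAI/bge-reranker-v2-m3",       # illustrative pooling/reranker model
    hf_overrides={"head_dtype": "model"},  # disable the fp32 head default
)

# Conversely, a generation model could opt in to an fp32 lm_head:
# llm = LLM(model="...", hf_overrides={"head_dtype": "float32"})
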
@@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
return torch_dtype

def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
                    runner_type: str) -> torch.dtype:
    head_dtype: Optional[Union[str,
                               torch.dtype]] = getattr(config, "head_dtype",
                                                       None)

    if head_dtype == "model":
        return dtype
    elif isinstance(head_dtype, str):
        head_dtype = head_dtype.lower()
        if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {head_dtype!r}")
        return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
    elif isinstance(head_dtype, torch.dtype):
        return head_dtype
    elif head_dtype is None:
        if torch.float32 not in current_platform.supported_dtypes:
            return dtype
        if runner_type == "pooling":
            return torch.float32
        return dtype
    else:
        raise ValueError(f"Unknown dtype: {head_dtype}")

def _get_and_verify_max_len(
hf_config: PretrainedConfig,
tokenizer_config: Optional[dict],
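
To make the resolution order concrete, here is a minimal standalone sketch of the precedence _get_head_dtype implements, with the platform-support fallback omitted; resolve_head_dtype and the local dtype table are hypothetical names used only for this illustration.

# Illustration only: mirrors the precedence of _get_head_dtype, minus the
# current_platform.supported_dtypes check.
import torch

_EXAMPLE_STR_TO_DTYPE = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}

def resolve_head_dtype(head_dtype, dtype, runner_type):
    if head_dtype == "model":           # explicit opt-out: follow the model dtype
        return dtype
    if isinstance(head_dtype, str):     # explicit dtype string, e.g. "float32"
        return _EXAMPLE_STR_TO_DTYPE[head_dtype.lower()]
    if isinstance(head_dtype, torch.dtype):
        return head_dtype
    # unset: pooling runners default to fp32, everything else keeps the model dtype
    return torch.float32 if runner_type == "pooling" else dtype

assert resolve_head_dtype(None, torch.bfloat16, "pooling") == torch.float32
assert resolve_head_dtype(None, torch.bfloat16, "generate") == torch.bfloat16
assert resolve_head_dtype("model", torch.float16, "pooling") == torch.float16
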