[Model] Systematic support for fp32 head, pooling models part (#23810)
Signed-off-by: wang.yuqi <noooop@126.com>
@@ -745,7 +745,7 @@ class ModelConfig:
         self.pooler_config = self._init_pooler_config()
 
-        self.dtype = _get_and_verify_dtype(
+        self.dtype: torch.dtype = _get_and_verify_dtype(
             self.model,
             self.hf_config,
             self.dtype,
@@ -1751,6 +1751,32 @@ class ModelConfig:
         # `llm as reranker` models default to not using pad_token.
         return getattr(self.hf_config, "use_pad_token", True)
 
+    @property
+    def head_dtype(self) -> torch.dtype:
+        """
+        "head" refers to the last Linear layer(s) of an LLM,
+        such as the lm_head in a generation model,
+        or the score or classifier in a classification model.
+
+        The default head_dtype is chosen based on runner_type:
+        - Pooling models default to an fp32 head; use
+          --hf-overrides '{"head_dtype": "model"}' to disable it.
+        - Generate models default to not using an fp32 head; use
+          --hf-overrides '{"head_dtype": "float32"}' to enable it.
+        """
+        head_dtype = _get_head_dtype(config=self.hf_config,
+                                     dtype=self.dtype,
+                                     runner_type=self.runner_type)
+
+        if head_dtype not in current_platform.supported_dtypes:
+            logger.warning_once(
+                "The current platform does not support [%s] head dtype, "
+                "falling back to model dtype [%s].", head_dtype, self.dtype)
+            return self.dtype
+
+        logger.debug_once("head dtype: %s", head_dtype)
+        return head_dtype
+
     def get_and_verify_max_len(self, max_model_len: int):
         # Consider max_model_len in tokenizer_config only when
         # pooling models use absolute position_embedding.
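The head_dtype property above documents the user-facing override. As a minimal usage sketch (not part of this diff; the model name is a placeholder, and only the hf_overrides keyword of vLLM's offline LLM API is assumed), the same override quoted in the docstring's --hf-overrides examples can also be passed programmatically:

    # Sketch: requesting an fp32 head for a generate model, which would
    # otherwise keep its head in the model dtype.
    from vllm import LLM

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
              hf_overrides={"head_dtype": "float32"})

A pooling model would go the other way: hf_overrides={"head_dtype": "model"} opts it out of the fp32 default.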
@@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
     return torch_dtype
 
 
+def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
+                    runner_type: str) -> torch.dtype:
+    head_dtype: Optional[Union[str,
+                               torch.dtype]] = getattr(config, "head_dtype",
+                                                       None)
+
+    if head_dtype == "model":
+        return dtype
+    elif isinstance(head_dtype, str):
+        head_dtype = head_dtype.lower()
+        if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+            raise ValueError(f"Unknown dtype: {head_dtype!r}")
+        return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
+    elif isinstance(head_dtype, torch.dtype):
+        return head_dtype
+    elif head_dtype is None:
+        if torch.float32 not in current_platform.supported_dtypes:
+            return dtype
+        if runner_type == "pooling":
+            return torch.float32
+        return dtype
+    else:
+        raise ValueError(f"Unknown dtype: {head_dtype}")
+
+
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
     tokenizer_config: Optional[dict],
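To make the precedence implemented by _get_head_dtype concrete, here is a self-contained mirror of its resolution order (illustration only: resolve_head_dtype and _DTYPES are hypothetical stand-ins, the dtype table is trimmed relative to vLLM's _STR_DTYPE_TO_TORCH_DTYPE, and the platform support check is omitted):

    # Standalone sketch of _get_head_dtype's resolution order.
    from typing import Optional, Union

    import torch

    _DTYPES = {"half": torch.float16, "float16": torch.float16,
               "bfloat16": torch.bfloat16,
               "float": torch.float32, "float32": torch.float32}

    def resolve_head_dtype(head_dtype: Optional[Union[str, torch.dtype]],
                           dtype: torch.dtype,
                           runner_type: str) -> torch.dtype:
        if head_dtype == "model":        # explicit opt-out: follow model dtype
            return dtype
        if isinstance(head_dtype, str):  # named dtype, e.g. "float32"
            return _DTYPES[head_dtype.lower()]
        if isinstance(head_dtype, torch.dtype):
            return head_dtype
        # None: pooling runners default to fp32, generate runners keep dtype
        return torch.float32 if runner_type == "pooling" else dtype

    assert resolve_head_dtype(None, torch.bfloat16, "pooling") == torch.float32
    assert resolve_head_dtype(None, torch.bfloat16, "generate") == torch.bfloat16
    assert resolve_head_dtype("model", torch.bfloat16, "pooling") == torch.bfloat16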