Improve Mistral format checks. (#33253)
Signed-off-by: Julien Denize <julien.denize@mistral.ai> Signed-off-by: juliendenize <julien.denize@mistral.ai> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -565,6 +565,7 @@ class ModelConfig:
|
||||
self.dtype,
|
||||
is_pooling_model=self.runner_type == "pooling",
|
||||
revision=self.revision,
|
||||
config_format=self.config_format,
|
||||
)
|
||||
|
||||
self.original_max_model_len = self.max_model_len
|
||||
@@ -1844,9 +1845,10 @@ def _get_and_verify_dtype(
|
||||
*,
|
||||
is_pooling_model: bool,
|
||||
revision: str | None = None,
|
||||
config_format: ConfigFormat = "hf",
|
||||
) -> torch.dtype:
|
||||
config_dtype = ModelArchConfigConvertorBase.get_torch_dtype(
|
||||
config, model_id, revision=revision
|
||||
config, model_id, revision=revision, config_format=config_format
|
||||
)
|
||||
model_type = config.model_type
|
||||
|
||||
|
||||
Reference in New Issue
Block a user