[VLM] Support caching in merged multi-modal processor (#11396)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,25 +1,31 @@
|
||||
from functools import lru_cache
|
||||
from typing import Any, cast
|
||||
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
|
||||
def get_processor(
|
||||
processor_name: str,
|
||||
*args: Any,
|
||||
trust_remote_code: bool = False,
|
||||
processor_cls: type[ProcessorMixin] = ProcessorMixin,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Load a processor for the given model name via HuggingFace."""
|
||||
# don't put this import at the top level
|
||||
# it will call torch.cuda.device_count()
|
||||
from transformers import AutoProcessor
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
processor_factory = (AutoProcessor
|
||||
if processor_cls == ProcessorMixin else processor_cls)
|
||||
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
processor = processor_factory.from_pretrained(
|
||||
processor_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
except ValueError as e:
|
||||
# If the error pertains to the processor class not existing or not
|
||||
# currently being imported, suggest using the --trust-remote-code flag.
|
||||
|
||||
Reference in New Issue
Block a user