[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: Andrew Feldman <afeld2012@gmail.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
130
vllm/utils.py
130
vllm/utils.py
@@ -27,10 +27,93 @@ from typing_extensions import ParamSpec
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
|
||||
SingletonPromptInputs)
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Exception strings for non-implemented encoder/decoder scenarios.
# Each message is assembled from adjacent string literals so that the
# source lines stay short; the resulting text is a single sentence.

STR_NOT_IMPL_ENC_DEC_SWA = (
    "Sliding window attention for encoder/decoder models "
    "is not currently supported.")

STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = (
    "Prefix caching for encoder/decoder models "
    "is not currently supported.")

STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = (
    "Chunked prefill for encoder/decoder models "
    "is not currently supported.")

STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
    "Models with logits_soft_cap require FlashInfer backend, "
    "which is currently not supported for encoder/decoder models.")
|
||||
|
||||
# Message used when LoRA is requested together with an encoder/decoder
# model. The wording follows the sibling messages ("... is not currently
# supported with encoder/decoder models."); the earlier text repeated the
# word "currently".
STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently "
                             "supported with encoder/decoder "
                             "models.")
|
||||
|
||||
# Remaining "not implemented" messages for encoder/decoder models, one per
# unsupported feature. Adjacent literals are joined into one sentence.

STR_NOT_IMPL_ENC_DEC_PP = (
    "Pipeline parallelism is not currently supported "
    "with encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_MM = (
    "Multimodal is not currently supported "
    "with encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_SPEC_DEC = (
    "Speculative decoding is not currently supported "
    "with encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = (
    "CUDAGraph is not currently supported "
    "with encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_BACKEND = (
    "XFormers is the only backend currently supported "
    "with encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = (
    "Prompt adapters are not currently supported "
    "with encoder/decoder models.")
|
||||
|
||||
# Registry of all enc/dec "not implemented" error strings, so that a
# caller can import this single mapping rather than every constant above.
#
# NOTE: the CUDAGraph entry is keyed "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH"
# (with an underscore between CUDA and GRAPH) although the constant is
# named STR_NOT_IMPL_ENC_DEC_CUDAGRAPH. The key is left unchanged because
# entries are looked up by key string.
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
    "STR_NOT_IMPL_ENC_DEC_SWA":
    STR_NOT_IMPL_ENC_DEC_SWA,
    "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE":
    STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
    "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
    STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
    "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP":
    STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
    "STR_NOT_IMPL_ENC_DEC_LORA":
    STR_NOT_IMPL_ENC_DEC_LORA,
    "STR_NOT_IMPL_ENC_DEC_PP":
    STR_NOT_IMPL_ENC_DEC_PP,
    "STR_NOT_IMPL_ENC_DEC_MM":
    STR_NOT_IMPL_ENC_DEC_MM,
    "STR_NOT_IMPL_ENC_DEC_SPEC_DEC":
    STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
    "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH":
    STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
    "STR_NOT_IMPL_ENC_DEC_BACKEND":
    STR_NOT_IMPL_ENC_DEC_BACKEND,
    "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER":
    STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}
|
||||
|
||||
# Constants related to forcing the attention backend selection

# Name of the setting (an environment variable, going by the constant's
# name) that may be set in order to force which attention backend the
# Attention wrapper selects.
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Possible string values of the STR_BACKEND_ENV_VAR setting, each
# corresponding to one possible backend.
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
|
||||
|
||||
STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
@@ -1029,3 +1112,50 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
|
||||
"""Utility function to run async task in a lock"""
|
||||
async with lock:
|
||||
return await task(*args, **kwargs)
|
||||
|
||||
|
||||
def is_encoder_decoder_model_config(model_config) -> bool:
    """Report whether a ModelConfig describes an encoder/decoder model.

    Reads the ``is_encoder_decoder`` attribute from
    ``model_config.hf_config``; a missing attribute counts as False.

    Args:
        model_config: The ModelConfig instance to inspect, or None.

    Returns:
        True when the flag is present and truthy. False when
        ``model_config`` is None or the flag is absent or falsy.
    """
    if model_config is None:
        return False
    # Coerce the flag so the result is always a real bool, as the return
    # annotation promises, even when the config stores None or some other
    # truthy/falsy value.
    return bool(getattr(model_config.hf_config, "is_encoder_decoder", False))
|
||||
|
||||
|
||||
def is_embedding_model_config(model_config) -> bool:
    """Report whether a ModelConfig has its embedding-mode flag set.

    Args:
        model_config: The ModelConfig instance to inspect, or None.

    Returns:
        The truth value of ``model_config.embedding_mode``, or False when
        ``model_config`` is None.
    """
    if model_config is None:
        return False
    # Coerce so callers always receive a real bool, matching the return
    # annotation, whatever value embedding_mode happens to hold.
    return bool(model_config.embedding_mode)
|
||||
|
||||
|
||||
def build_explicit_enc_dec_prompt(
    encoder_prompt: SingletonPromptInputs,
    decoder_prompt: SingletonPromptInputs,
) -> ExplicitEncoderDecoderPrompt:
    """Pair an encoder prompt with a decoder prompt.

    Args:
        encoder_prompt: Prompt stored under ``encoder_prompt``.
        decoder_prompt: Prompt stored under ``decoder_prompt``.

    Returns:
        An ExplicitEncoderDecoderPrompt built from the two prompts.
    """
    explicit_prompt = ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
    )
    return explicit_prompt
|
||||
|
||||
|
||||
def zip_enc_dec_prompt_lists(
    enc_prompt_list: List[SingletonPromptInputs],
    dec_prompt_list: List[SingletonPromptInputs],
) -> List[ExplicitEncoderDecoderPrompt]:
    """Combine parallel encoder and decoder prompt lists element-wise.

    The i-th encoder prompt is paired with the i-th decoder prompt via
    build_explicit_enc_dec_prompt. If the lists differ in length, pairing
    stops at the end of the shorter one.

    Args:
        enc_prompt_list: Encoder prompts.
        dec_prompt_list: Decoder prompts, in the same order.

    Returns:
        One ExplicitEncoderDecoderPrompt per paired position.
    """
    # map() over two iterables walks them in lockstep and, like zip(),
    # stops as soon as either one is exhausted.
    return list(
        map(build_explicit_enc_dec_prompt, enc_prompt_list, dec_prompt_list))
|
||||
|
||||
|
||||
def to_enc_dec_tuple_list(
    enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
) -> List[Tuple[PromptInputs, PromptInputs]]:
    """Unpack explicit encoder/decoder prompts into plain tuples.

    Args:
        enc_dec_prompts: Prompts, each indexable by the keys
            ``'encoder_prompt'`` and ``'decoder_prompt'``.

    Returns:
        A list of ``(encoder_prompt, decoder_prompt)`` tuples, in the same
        order as the input.
    """
    prompt_pairs = []
    for explicit_prompt in enc_dec_prompts:
        encoder_part = explicit_prompt['encoder_prompt']
        decoder_part = explicit_prompt['decoder_prompt']
        prompt_pairs.append((encoder_part, decoder_part))
    return prompt_pairs
|
||||
|
||||
Reference in New Issue
Block a user