[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)

Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
afeldman-nm
2024-08-06 16:51:47 -04:00
committed by GitHub
parent 660470e5a3
commit fd95e026e0
33 changed files with 3957 additions and 333 deletions

View File

@@ -27,10 +27,93 @@ from typing_extensions import ParamSpec
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
SingletonPromptInputs)
from vllm.logger import enable_trace_function_call, init_logger
logger = init_logger(__name__)
# Exception strings for non-implemented encoder/decoder scenarios
STR_NOT_IMPL_ENC_DEC_SWA = \
"Sliding window attention for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
"Prefix caching for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
"Chunked prefill for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
"Models with logits_soft_cap "
"require FlashInfer backend, which is "
"currently not supported for encoder/decoder "
"models.")
STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently "
"supported with encoder/decoder "
"models.")
STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
"currently supported with "
"encoder/decoder models.")
STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
"supported with encoder/decoder "
"models.")
STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
"currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
"currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
"currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
"currently supported with encoder/"
"decoder models.")
# Efficiently import all enc/dec error strings
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
"STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
"STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
"STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
"STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
"STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
"STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
"STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
"STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
"STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}
# Constants related to forcing the attention backend selection
# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
@@ -1029,3 +1112,50 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)
def is_encoder_decoder_model_config(model_config) -> bool:
'''
Extract the HF encoder/decoder model flag from the ModelConfig instance.
Return False if model_config is None.
'''
return model_config is not None and \
getattr(model_config.hf_config,
"is_encoder_decoder",
False)
def is_embedding_model_config(model_config) -> bool:
'''
Extract the embedding model flag from the ModelConfig instance.
Return False if model_config is None.
'''
return model_config is not None and \
model_config.embedding_mode
def build_explicit_enc_dec_prompt(
encoder_prompt: SingletonPromptInputs,
decoder_prompt: SingletonPromptInputs,
) -> ExplicitEncoderDecoderPrompt:
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt)
def zip_enc_dec_prompt_lists(
enc_prompt_list: List[SingletonPromptInputs],
dec_prompt_list: List[SingletonPromptInputs],
) -> List[ExplicitEncoderDecoderPrompt]:
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
for (encoder_prompt,
decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
]
def to_enc_dec_tuple_list(
enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
) -> List[Tuple[PromptInputs, PromptInputs]]:
return [(enc_dec_prompt['encoder_prompt'],
enc_dec_prompt['decoder_prompt'])
for enc_dec_prompt in enc_dec_prompts]