[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: Andrew Feldman <afeld2012@gmail.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
@@ -8,24 +8,10 @@ from typing import Any, List, NamedTuple, Optional, Tuple, Union
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||
from vllm.attention.backends.xformers import XFormersBackend
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
|
||||
# String name of register which may be set in order to
|
||||
# force auto-selection of attention backend by Attention
|
||||
# wrapper
|
||||
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
|
||||
|
||||
# Possible string values of STR_BACKEND_ENV_VAR
|
||||
# register, corresponding to possible backends
|
||||
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
|
||||
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
|
||||
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
|
||||
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
|
||||
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
|
||||
STR_INVALID_VAL: str = "INVALID"
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
|
||||
make_tensor_with_pad)
|
||||
|
||||
|
||||
class QKVInputs(NamedTuple):
|
||||
|
||||
Reference in New Issue
Block a user