[Bugfix] Fix use_cascade_attention handling for ALiBi-based models on vllm/v1 (#15211)
Signed-off-by: h-sugi <h.sugi@ieee.org>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
diff --git a/vllm/utils.py b/vllm/utils.py
@@ -61,7 +61,7 @@ import vllm.envs as envs
 from vllm.logger import enable_trace_function_call, init_logger
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 
 logger = init_logger(__name__)
 
@@ -2498,6 +2498,18 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
     return decorator
 
 
+# Only relevant for models using ALiBi (e.g, MPT)
+def check_use_alibi(model_config: ModelConfig) -> bool:
+    return (getattr(model_config.hf_text_config, "alibi", False)  # Falcon
+            or ("BloomForCausalLM" in getattr(model_config.hf_config,
+                                              "architectures", []))  # Bloom
+            or getattr(model_config.hf_text_config, "position_encoding_type",
+                       "") == "alibi"  # codellm_1b_alibi
+            or
+            (hasattr(model_config.hf_text_config, "attn_config")  # MPT
+             and model_config.hf_text_config.attn_config.get("alibi", False)))
+
+
 def sha256(input) -> int:
     """Hash any picklable Python object using SHA-256.
 
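As a sanity check of the detection logic added above, here is a minimal sketch that exercises `check_use_alibi` with `SimpleNamespace` stand-ins for the Hugging Face config objects. The `fake_model_config` helper and the mock configs are hypothetical test scaffolding, not part of vLLM; only `check_use_alibi` comes from this commit:

```python
# Sketch only: mock configs mimicking the attributes check_use_alibi reads.
from types import SimpleNamespace

from vllm.utils import check_use_alibi


def fake_model_config(hf_config: SimpleNamespace) -> SimpleNamespace:
    # vLLM's ModelConfig exposes both hf_config and hf_text_config; for
    # text-only models they typically refer to the same config object.
    return SimpleNamespace(hf_config=hf_config, hf_text_config=hf_config)


# MPT-style config: the ALiBi flag lives inside the attn_config dict.
mpt = fake_model_config(SimpleNamespace(attn_config={"alibi": True}))
assert check_use_alibi(mpt)

# BLOOM is detected by its architecture name rather than a config flag.
bloom = fake_model_config(SimpleNamespace(architectures=["BloomForCausalLM"]))
assert check_use_alibi(bloom)

# A RoPE-based model (e.g. Llama) has none of the ALiBi markers.
llama = fake_model_config(SimpleNamespace(architectures=["LlamaForCausalLM"]))
assert not check_use_alibi(llama)
```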
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
@@ -25,7 +25,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, LazyLoader, cdiv,
+                        LayerBlockType, LazyLoader, cdiv, check_use_alibi,
                         is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -223,6 +223,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             device="cpu",
             pin_memory=self.pin_memory)
 
+        # Only relevant for models using ALiBi (e.g, MPT)
+        self.use_alibi = check_use_alibi(model_config)
+
         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
             dtype=self.dtype,
@@ -689,7 +692,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             query_lens=num_scheduled_tokens,
             num_query_heads=self.num_query_heads,
             num_kv_heads=self.num_kv_heads,
-            use_alibi=False,  # FIXME
+            use_alibi=self.use_alibi,
             use_sliding_window=self.window_size is not None,
             num_sms=self.num_sms,
         )
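For context on the one-line fix above: these keyword arguments feed the backend's `use_cascade_attention` decision, and hardcoding `use_alibi=False` let ALiBi models opt into a cascade path they do not support. Below is a hedged sketch of what such a gate can look like, not vLLM's actual implementation; parameter names mirror the call site in the hunk, and `common_prefix_len` is an assumed additional parameter preceding the arguments visible there (the real heuristic also weighs prefix length against GPU occupancy):

```python
# Sketch of a cascade-attention gate under the assumptions stated above.
def use_cascade_attention(
    common_prefix_len: int,
    query_lens: list[int],
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
) -> bool:
    # Cascade attention computes attention over the shared prefix once and
    # merges it with per-request suffix attention. The cascade kernels do
    # not support ALiBi's per-head positional biases or sliding-window
    # masking, so such models must take the normal attention path.
    # (Passing use_alibi=False here, as before this fix, silently enabled
    # the cascade path for ALiBi models.)
    if use_alibi or use_sliding_window:
        return False
    # No benefit without a shared prefix or with a single request.
    if common_prefix_len == 0 or len(query_lens) < 2:
        return False
    return True
```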