[Attention] MLA - Flashinfer Ragged Prefill (#20034)
This commit is contained in:
committed by
GitHub
parent
922f316441
commit
5b032352cc
@@ -4,14 +4,17 @@ import abc
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.utils import cdiv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionImpl
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
@@ -98,39 +101,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
return False
|
||||
|
||||
|
||||
def validate_kv_sharing_target(current_layer_name, target_layer_name,
|
||||
static_forward_context):
|
||||
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
|
||||
f"is not valid: target layer {target_layer_name} ")
|
||||
|
||||
if current_layer_name == target_layer_name:
|
||||
raise ValueError(error_msg +
|
||||
"cannot be the same as the current layer.")
|
||||
|
||||
if target_layer_name not in static_forward_context:
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
|
||||
# If target layer name is not in the static fwd context, it means either
|
||||
# a) the target layer does not come BEFORE the current layer, or
|
||||
# b) the target layer is not an Attention layer that exists in the model
|
||||
current_layer_idx = extract_layer_index(current_layer_name)
|
||||
target_layer_idx = extract_layer_index(target_layer_name)
|
||||
if current_layer_idx <= target_layer_idx:
|
||||
raise ValueError(error_msg + "must come before the current layer.")
|
||||
else:
|
||||
raise ValueError(error_msg +
|
||||
"is not a valid Attention layer in the model.")
|
||||
|
||||
# Currently KV sharing is only supported between layers of the same type
|
||||
target_layer_attn_type = static_forward_context[
|
||||
target_layer_name].attn_type
|
||||
expected = static_forward_context[current_layer_name].attn_type
|
||||
if target_layer_attn_type != expected:
|
||||
raise ValueError(
|
||||
error_msg +
|
||||
f"must be the same type as the current layer ({expected}).")
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_kv_cache_layout():
|
||||
# Override with format specified by the user.
|
||||
@@ -144,6 +114,71 @@ def get_kv_cache_layout():
|
||||
return cache_layout
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerLayerParameters:
|
||||
"""
|
||||
Currently, FlashInfer backend only support models in which all layers share
|
||||
the same values for the following hyperparameters.
|
||||
"""
|
||||
|
||||
window_left: int
|
||||
logits_soft_cap: Optional[float]
|
||||
sm_scale: float
|
||||
|
||||
|
||||
def get_per_layer_parameters(
|
||||
vllm_config: VllmConfig,
|
||||
cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
|
||||
"""
|
||||
Scan all attention layers and determine some hyperparameters
|
||||
to use during `plan`.
|
||||
"""
|
||||
|
||||
layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
per_layer_params: dict[str, PerLayerParameters] = {}
|
||||
|
||||
for key, layer in layers.items():
|
||||
impl = layer.impl
|
||||
assert isinstance(impl, cls_)
|
||||
|
||||
# Infer hyperparameters from the attention layer
|
||||
window_size = getattr(impl, "sliding_window", None)
|
||||
window_left = window_size[0] if window_size is not None else -1
|
||||
logits_soft_cap = getattr(impl, "logits_soft_cap", None)
|
||||
sm_scale = impl.scale
|
||||
|
||||
per_layer_params[key] = PerLayerParameters(window_left,
|
||||
logits_soft_cap, sm_scale)
|
||||
|
||||
return per_layer_params
|
||||
|
||||
|
||||
def infer_global_hyperparameters(
|
||||
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
|
||||
"""
|
||||
Currently, FlashInfer backend only support models in which all layers share
|
||||
the same values for the following hyperparameters:
|
||||
- `window_left`
|
||||
- `logits_soft_cap`
|
||||
- `sm_scale`
|
||||
|
||||
So this function asserts that all layers share the same values for these
|
||||
hyperparameters and returns the global values.
|
||||
"""
|
||||
|
||||
assert len(per_layer_params) > 0, "No attention layers found in the model."
|
||||
|
||||
param_sets = list(per_layer_params.values())
|
||||
global_params = param_sets[0]
|
||||
for params in param_sets:
|
||||
assert params == global_params, (
|
||||
"FlashInfer backend currently only supports models in which all "
|
||||
"layers share the same values for the following hyperparameters: "
|
||||
"`window_left`, `logits_soft_cap`, `sm_scale`.")
|
||||
|
||||
return global_params
|
||||
|
||||
|
||||
#
|
||||
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
||||
# local attention blocks, where each block is passed to the attention kernel
|
||||
|
||||
Reference in New Issue
Block a user