[Attention] MLA - Flashinfer Ragged Prefill (#20034)
This commit is contained in:
committed by
GitHub
parent
922f316441
commit
5b032352cc
@@ -14,13 +14,14 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
get_kv_cache_layout)
|
||||
PerLayerParameters,
|
||||
get_kv_cache_layout,
|
||||
get_per_layer_parameters,
|
||||
infer_global_hyperparameters)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
@@ -93,70 +94,6 @@ class FlashInferBackend(AttentionBackend):
|
||||
return stride_order
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerLayerParameters:
|
||||
"""
|
||||
Currently, FlashInfer backend only support models in which all layers share
|
||||
the same values for the following hyperparameters.
|
||||
"""
|
||||
|
||||
window_left: int
|
||||
logits_soft_cap: Optional[float]
|
||||
sm_scale: float
|
||||
|
||||
|
||||
def get_per_layer_parameters(
|
||||
vllm_config: VllmConfig) -> dict[str, PerLayerParameters]:
|
||||
"""
|
||||
Scan all attention layers and determine some hyperparameters
|
||||
to use during `plan`.
|
||||
"""
|
||||
|
||||
layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
per_layer_params: dict[str, PerLayerParameters] = {}
|
||||
|
||||
for key, layer in layers.items():
|
||||
impl = layer.impl
|
||||
assert isinstance(impl, FlashInferImpl)
|
||||
|
||||
# Infer hyperparameters from the attention layer
|
||||
window_size = impl.sliding_window
|
||||
window_left = window_size[0] if window_size is not None else -1
|
||||
logits_soft_cap = impl.logits_soft_cap
|
||||
sm_scale = impl.scale
|
||||
|
||||
per_layer_params[key] = PerLayerParameters(window_left,
|
||||
logits_soft_cap, sm_scale)
|
||||
|
||||
return per_layer_params
|
||||
|
||||
|
||||
def infer_global_hyperparameters(
|
||||
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
|
||||
"""
|
||||
Currently, FlashInfer backend only support models in which all layers share
|
||||
the same values for the following hyperparameters:
|
||||
- `window_left`
|
||||
- `logits_soft_cap`
|
||||
- `sm_scale`
|
||||
|
||||
So this function asserts that all layers share the same values for these
|
||||
hyperparameters and returns the global values.
|
||||
"""
|
||||
|
||||
assert len(per_layer_params) > 0, "No attention layers found in the model."
|
||||
|
||||
param_sets = list(per_layer_params.values())
|
||||
global_params = param_sets[0]
|
||||
for params in param_sets:
|
||||
assert params == global_params, (
|
||||
"FlashInfer backend currently only supports models in which all "
|
||||
"layers share the same values for the following hyperparameters: "
|
||||
"`window_left`, `logits_soft_cap`, `sm_scale`.")
|
||||
|
||||
return global_params
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashInferMetadata:
|
||||
|
||||
@@ -336,7 +273,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
def _plan(self, attn_metadata: FlashInferMetadata):
|
||||
if self.global_hyperparameters is None:
|
||||
self.global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(self.vllm_config))
|
||||
get_per_layer_parameters(self.vllm_config, FlashInferImpl))
|
||||
if attn_metadata.use_cascade:
|
||||
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
|
||||
attn_metadata.cascade_wrapper.plan(
|
||||
|
||||
Reference in New Issue
Block a user