[V0 deprecation] Remove no longer used get_metadata_cls (#28370)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -4,24 +4,21 @@
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
from collections.abc import Sequence
|
||||
from numbers import Number
|
||||
from typing import Any, NamedTuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch._prims_common import TensorLikeType
|
||||
|
||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils import (
|
||||
STR_BACKEND_ENV_VAR,
|
||||
STR_FLASH_ATTN_VAL,
|
||||
STR_XFORMERS_ATTN_VAL,
|
||||
)
|
||||
from vllm.utils.torch_utils import make_tensor_with_pad
|
||||
|
||||
@@ -512,50 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
|
||||
)
|
||||
|
||||
|
||||
def make_backend(backend_name: str) -> AttentionBackend:
|
||||
"""
|
||||
Construct the backend instance determined by the backend_name string
|
||||
argument.
|
||||
|
||||
Note: at time of writing the Attention wrapper automatically selects
|
||||
its own backend for Attention.forward(); so the backend instance which
|
||||
you generate with this function is not meant to be used for *running*
|
||||
inference, but rather for generating compatible metadata structures
|
||||
using backend.make_metadata()
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
* Backend instance
|
||||
"""
|
||||
if backend_name == STR_XFORMERS_ATTN_VAL:
|
||||
from vllm.v1.attention.backends.xformers import XFormersAttentionBackend
|
||||
|
||||
return XFormersAttentionBackend()
|
||||
if backend_name == STR_FLASH_ATTN_VAL:
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
|
||||
return FlashAttentionBackend()
|
||||
if backend_name == "TRITON_ATTN":
|
||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
|
||||
|
||||
return TritonAttentionBackend()
|
||||
if backend_name == "FLEX_ATTENTION":
|
||||
from vllm.v1.attention.backends.flex_attention import FlexAttentionBackend
|
||||
|
||||
return FlexAttentionBackend()
|
||||
if backend_name == "TORCH_SDPA":
|
||||
from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend
|
||||
|
||||
return TorchSDPABackend()
|
||||
if backend_name == "FLASHINFER":
|
||||
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
|
||||
|
||||
return FlashInferBackend()
|
||||
|
||||
raise AssertionError(f"Unrecognized backend_name {backend_name} for unit test")
|
||||
|
||||
|
||||
def make_alibi_bias(
|
||||
alibi_slopes: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
@@ -877,197 +830,6 @@ def make_block_tables_slot_mapping(
|
||||
return (block_tables_tensor, slot_mapping_list, max_block_idx)
|
||||
|
||||
|
||||
def make_test_metadata(
|
||||
attn_backend: _Backend,
|
||||
is_prompt: bool,
|
||||
seq_lens: list[int] | None,
|
||||
decoder_test_params: PhaseTestParameters | None,
|
||||
device: torch.device | str,
|
||||
encoder_test_params: PhaseTestParameters | None = None,
|
||||
cross_test_params: PhaseTestParameters | None = None,
|
||||
) -> AttentionMetadata:
|
||||
"""
|
||||
Construct fake attention metadata for a given test phase
|
||||
(prefill-phase or decode-phase).
|
||||
|
||||
encoder_test_params and cross_test_params arguments allow encoder
|
||||
attention and enc/dec cross-attention (respectively) to use distinct
|
||||
metadata values from decoder self-attention (decoder_test_params.)
|
||||
|
||||
if encoder_test_params and cross_test_params are None, the attention
|
||||
metadata will support decoder-only scenario.
|
||||
|
||||
Assumptions:
|
||||
|
||||
* No chunked prefill -> a batch is 100% prefill or 100% decode, never both
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_backend_name: Backend for sourcing attention kernels
|
||||
* is_prompt: prefill if True, o/w decode
|
||||
* seq_lens: list of token counts for each sequence
|
||||
* decoder_test_params: decoder self-attention test params;
|
||||
this function requires
|
||||
kv_mmap (memory mapping) field
|
||||
* device: CPU or CUDA device
|
||||
* encoder_test_params: encoder attention test params;
|
||||
this function requires encoder query
|
||||
sequence lengths field. If None,
|
||||
encoder query sequence lengths are
|
||||
treated as None
|
||||
* cross_test_params: enc/dec cross-attention test params;
|
||||
this function requires kv_mmap field.
|
||||
If None, KV cache memory map data
|
||||
structures are treated as None
|
||||
|
||||
Return:
|
||||
|
||||
* AttentionMetadata structure
|
||||
"""
|
||||
|
||||
# Decoder self-attention memory mapping
|
||||
# decoder_test_params is None signals encoder-only
|
||||
# scenario, so kv_mmap is None
|
||||
kv_mmap = None if decoder_test_params is None else decoder_test_params.kv_mmap
|
||||
|
||||
# This function constructs metadata assuming no chunked prefill,
|
||||
# i.e. 100% prefill tokens or 100% decode tokens
|
||||
#
|
||||
# - If is_prompt, num_prefills_or_decodes is the number of prefills
|
||||
# and num_prefill_or_decode_tokens is the number of prefill tokens
|
||||
# - If not is_prompt, num_prefills_or_decodes is the number of decodes
|
||||
# and num_prefill_or_decode_tokens is the number of decode tokens
|
||||
#
|
||||
# seq_lens is None signals encoder-only
|
||||
# scenario, in which case num_prefills_or_decodes and
|
||||
# num_prefill_or_decode_tokens are unused
|
||||
num_prefills_or_decodes = None if seq_lens is None else len(seq_lens)
|
||||
|
||||
num_prefill_or_decode_tokens = (
|
||||
None if seq_lens is None else (sum(seq_lens) if is_prompt else len(seq_lens))
|
||||
)
|
||||
|
||||
# Seems for non-prefix-caching scenarios context_lens
|
||||
# is never needed
|
||||
context_lens = None
|
||||
|
||||
if encoder_test_params is None:
|
||||
encoder_seq_lens = None
|
||||
num_encoder_tokens = None
|
||||
else:
|
||||
# Encoder/decoder or encoder-only models only:
|
||||
# * Extract encoder input sequence lengths
|
||||
assert encoder_test_params.packed_qkvo.packed_qkv is not None
|
||||
encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
|
||||
num_encoder_tokens = (
|
||||
None if encoder_seq_lens is None else (sum(encoder_seq_lens))
|
||||
)
|
||||
|
||||
# For encoder/decoder or encoder-only models only, extract *cross-attention*
|
||||
# slot_mapping and block table (kv_mmap)
|
||||
cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap
|
||||
|
||||
attn_backend_obj = make_backend(attn_backend.name)
|
||||
|
||||
if is_prompt:
|
||||
# Prefill-phase scenario
|
||||
|
||||
num_prefills = num_prefills_or_decodes
|
||||
num_prefill_tokens = num_prefill_or_decode_tokens
|
||||
num_decode_tokens = 0
|
||||
|
||||
(
|
||||
seq_lens_tensor,
|
||||
context_lens_tensor,
|
||||
_,
|
||||
_,
|
||||
seq_start_loc,
|
||||
encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc,
|
||||
max_encoder_seq_len,
|
||||
) = _make_metadata_tensors(
|
||||
seq_lens, context_lens, encoder_seq_lens, device=device
|
||||
)
|
||||
return attn_backend_obj.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
seq_start_loc=seq_start_loc,
|
||||
max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
|
||||
max_decode_seq_len=0,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
|
||||
use_cuda_graph=False,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=encoder_seq_start_loc,
|
||||
max_encoder_seq_len=max_encoder_seq_len,
|
||||
cross_slot_mapping=(
|
||||
None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping
|
||||
),
|
||||
cross_block_tables=(
|
||||
None if cross_kv_mmap is None else cross_kv_mmap.block_tables
|
||||
),
|
||||
)
|
||||
|
||||
else: # not is_prompt
|
||||
# Decode-phase scenario
|
||||
|
||||
assert kv_mmap is not None
|
||||
assert num_prefill_or_decode_tokens is not None
|
||||
assert seq_lens is not None
|
||||
|
||||
num_prefills = 0
|
||||
num_prefill_tokens = 0
|
||||
num_decode_tokens = num_prefill_or_decode_tokens
|
||||
|
||||
(
|
||||
seq_lens_tensor,
|
||||
context_lens_tensor,
|
||||
_,
|
||||
_,
|
||||
seq_start_loc,
|
||||
encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc,
|
||||
max_encoder_seq_len,
|
||||
) = _make_metadata_tensors(
|
||||
seq_lens, context_lens, encoder_seq_lens, device=device
|
||||
)
|
||||
|
||||
return attn_backend_obj.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=kv_mmap.slot_mapping,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
seq_start_loc=seq_start_loc,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=max(seq_lens),
|
||||
max_decode_query_len=1,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=kv_mmap.block_tables,
|
||||
use_cuda_graph=False,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=encoder_seq_start_loc,
|
||||
max_encoder_seq_len=max_encoder_seq_len,
|
||||
cross_slot_mapping=(
|
||||
None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping
|
||||
),
|
||||
cross_block_tables=(
|
||||
None if cross_kv_mmap is None else cross_kv_mmap.block_tables
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def assert_actual_matches_ideal(
|
||||
test_params: PhaseTestParameters, output_under_test: torch.Tensor, backend: str
|
||||
) -> None:
|
||||
@@ -1308,7 +1070,7 @@ def opcheck(
|
||||
raise_exception: bool = True,
|
||||
cond: bool = True,
|
||||
) -> dict[str, str]:
|
||||
with unittest.mock.patch("torch.allclose", new=fp8_allclose):
|
||||
with patch("torch.allclose", new=fp8_allclose):
|
||||
return (
|
||||
torch.library.opcheck(
|
||||
op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
|
||||
|
||||
Reference in New Issue
Block a user