[Misc] Replace os.environ with monkeypatch in test suite (#14516)
Signed-off-by: sibi <85477603+t-sibiraj@users.noreply.github.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
@@ -5,13 +5,12 @@ from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import override_backend_env_variable
|
||||
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
from vllm.platforms.openvino import OpenVinoPlatform
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -25,87 +24,111 @@ def clear_cache():
|
||||
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
|
||||
@pytest.mark.parametrize("use_v1", [True, False])
|
||||
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
|
||||
def test_env(name: str, use_v1: bool, device: str, monkeypatch):
|
||||
def test_env(
|
||||
name: str,
|
||||
use_v1: bool,
|
||||
device: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
"""Test that the attention selector can be set via environment variable.
|
||||
Note that we do not test FlashAttn because it is the default backend.
|
||||
"""
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
override_backend_env_variable(monkeypatch, name)
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
m.setenv(STR_BACKEND_ENV_VAR, name)
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.get_name() == "TORCH_SDPA"
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
EXPECTED = "ROCM_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
|
||||
assert backend.get_name() == EXPECTED
|
||||
elif device == "openvino":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
OpenVinoPlatform()), patch.dict('sys.modules',
|
||||
{'openvino': Mock()}):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.get_name() == "OPENVINO"
|
||||
else:
|
||||
if name in ["XFORMERS", "FLASHINFER"]:
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()):
|
||||
CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
||||
16, False)
|
||||
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
|
||||
assert backend.get_name() == "TORCH_SDPA"
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
RocmPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
||||
16, False)
|
||||
EXPECTED = "ROCM_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
|
||||
assert backend.get_name() == EXPECTED
|
||||
elif device == "openvino":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
OpenVinoPlatform()), patch.dict('sys.modules',
|
||||
{'openvino': Mock()}):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
||||
16, False)
|
||||
assert backend.get_name() == "OPENVINO"
|
||||
else:
|
||||
if name in ["XFORMERS", "FLASHINFER"]:
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16,
|
||||
torch.float16, 16, False)
|
||||
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
|
||||
assert backend.get_name() == EXPECTED
|
||||
|
||||
|
||||
def test_flash_attn(monkeypatch):
|
||||
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test FlashAttn validation."""
|
||||
# TODO: When testing for v1, pipe in `use_v1` as an argument to
|
||||
# get_attn_backend
|
||||
|
||||
override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)
|
||||
|
||||
# Unsupported CUDA arch
|
||||
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
|
||||
# Unsupported CUDA arch
|
||||
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda:
|
||||
(7, 5))
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported data type
|
||||
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
# Reset the monkeypatch for subsequent tests
|
||||
monkeypatch.undo()
|
||||
|
||||
# Unsupported kv cache data type
|
||||
backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
# Unsupported data type
|
||||
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported block size
|
||||
backend = get_attn_backend(16, torch.float16, None, 8, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
# Unsupported kv cache data type
|
||||
backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# flash-attn is not installed
|
||||
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
|
||||
# Unsupported block size
|
||||
backend = get_attn_backend(16, torch.float16, None, 8, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# flash-attn is not installed
|
||||
import sys
|
||||
original_module = sys.modules.get('vllm_flash_attn')
|
||||
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported head size
|
||||
backend = get_attn_backend(17, torch.float16, None, 16, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
# Restore the original module if it existed
|
||||
if original_module is not None:
|
||||
monkeypatch.setitem(sys.modules, 'vllm_flash_attn',
|
||||
original_module)
|
||||
else:
|
||||
monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)
|
||||
|
||||
# Attention-free models should bypass env and use PlaceholderAttention
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
# Unsupported head size
|
||||
backend = get_attn_backend(17, torch.float16, None, 16, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Attention-free models should bypass env and use PlaceholderAttention
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_v1", [True, False])
|
||||
def test_invalid_env(use_v1: bool, monkeypatch):
|
||||
"""Ignore the invalid env variable if it is set."""
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
|
||||
def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
with monkeypatch.context() as m, patch(
|
||||
"vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
|
||||
|
||||
# Test with head size 32
|
||||
backend = get_attn_backend(32, torch.float16, None, 16, False)
|
||||
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
|
||||
assert backend.get_name() == EXPECTED
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -11,36 +9,38 @@ from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
|
||||
reason="AWQ is not supported on this GPU type.")
|
||||
def test_awq_dequantize_opcheck():
|
||||
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
|
||||
qweight = torch.randint(-2000000000,
|
||||
2000000000, (8192, 256),
|
||||
device='cuda',
|
||||
dtype=torch.int32)
|
||||
scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
|
||||
zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
|
||||
split_k_iters = 0
|
||||
thx = 0
|
||||
thy = 0
|
||||
opcheck(torch.ops._C.awq_dequantize,
|
||||
(qweight, scales, zeros, split_k_iters, thx, thy))
|
||||
def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_TRITON_AWQ", "0")
|
||||
qweight = torch.randint(-2000000000,
|
||||
2000000000, (8192, 256),
|
||||
device='cuda',
|
||||
dtype=torch.int32)
|
||||
scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
|
||||
zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
|
||||
split_k_iters = 0
|
||||
thx = 0
|
||||
thy = 0
|
||||
opcheck(torch.ops._C.awq_dequantize,
|
||||
(qweight, scales, zeros, split_k_iters, thx, thy))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not working; needs investigation.")
|
||||
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
|
||||
reason="AWQ is not supported on this GPU type.")
|
||||
def test_awq_gemm_opcheck():
|
||||
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
|
||||
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
|
||||
qweight = torch.randint(-2000000000,
|
||||
2000000000, (8192, 256),
|
||||
device='cuda',
|
||||
dtype=torch.int32)
|
||||
scales = torch.randint(-2000000000,
|
||||
2000000000, (64, 256),
|
||||
device='cuda',
|
||||
dtype=torch.int32)
|
||||
qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
|
||||
split_k_iters = 8
|
||||
opcheck(torch.ops._C.awq_gemm,
|
||||
(input, qweight, qzeros, scales, split_k_iters))
|
||||
def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_TRITON_AWQ", "0")
|
||||
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
|
||||
qweight = torch.randint(-2000000000,
|
||||
2000000000, (8192, 256),
|
||||
device='cuda',
|
||||
dtype=torch.int32)
|
||||
scales = torch.randint(-2000000000,
|
||||
2000000000, (64, 256),
|
||||
device='cuda',
|
||||
dtype=torch.int32)
|
||||
qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
|
||||
split_k_iters = 8
|
||||
opcheck(torch.ops._C.awq_gemm,
|
||||
(input, qweight, qzeros, scales, split_k_iters))
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import override_backend_env_variable
|
||||
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -17,15 +15,19 @@ def clear_cache():
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
def test_selector(monkeypatch):
|
||||
"""Test that the attention selector for ROCm.
|
||||
"""
|
||||
override_backend_env_variable(monkeypatch, "ROCM_FLASH")
|
||||
def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
|
||||
|
||||
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
|
||||
# Set the current platform to ROCm using monkeypatch
|
||||
monkeypatch.setattr("vllm.attention.selector.current_platform",
|
||||
RocmPlatform())
|
||||
|
||||
# Test standard ROCm attention
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
|
||||
assert (backend.get_name() == "ROCM_FLASH"
|
||||
or backend.get_name() == "ROCM_ATTN_VLLM_V1")
|
||||
|
||||
# mla test for deepseek related
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||
False, True)
|
||||
|
||||
@@ -36,12 +36,12 @@ ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = (
|
||||
|
||||
class QKVInputs(NamedTuple):
|
||||
'''
|
||||
Data structure for representing unpacked attention inputs,
|
||||
Data structure for representing unpacked attention inputs,
|
||||
query/key/values and their sequence lengths.
|
||||
|
||||
Attributes:
|
||||
|
||||
* {query,key,value}: unpacked (batch_size x padded_seq_len x
|
||||
* {query,key,value}: unpacked (batch_size x padded_seq_len x
|
||||
num_heads x head_size) attention inputs
|
||||
* q_seq_lens: query sequence lengths list
|
||||
* kv_seq_lens: shared key/value sequence lengths list
|
||||
@@ -56,14 +56,14 @@ class QKVInputs(NamedTuple):
|
||||
|
||||
class QKVO(NamedTuple):
|
||||
'''
|
||||
Data structure for representing unpacked attention inputs,
|
||||
Data structure for representing unpacked attention inputs,
|
||||
alongside unpacked known-correct attention output
|
||||
|
||||
Attributes:
|
||||
|
||||
* qkv: unpacked (batch_size x padded_seq_len x
|
||||
* qkv: unpacked (batch_size x padded_seq_len x
|
||||
num_heads x head_size) attention inputs
|
||||
* ideal_output: unpacked (batch_size x padded_seq_len x
|
||||
* ideal_output: unpacked (batch_size x padded_seq_len x
|
||||
num_heads x head_size) known-correct attention output
|
||||
'''
|
||||
|
||||
@@ -77,7 +77,7 @@ class PackedQKVInputs(NamedTuple):
|
||||
|
||||
Attributes:
|
||||
|
||||
* {query,key,value}: packed (number_of_tokens x num_heads
|
||||
* {query,key,value}: packed (number_of_tokens x num_heads
|
||||
x head_size) attention inputs
|
||||
* q_start_loc_list: list of query start locations within packed tensor
|
||||
* kv_start_loc_list: shared list of key/value start locations within
|
||||
@@ -97,14 +97,14 @@ class PackedQKVInputs(NamedTuple):
|
||||
|
||||
class PackedQKVO(NamedTuple):
|
||||
'''
|
||||
Data structure for representing packed attention inputs,
|
||||
Data structure for representing packed attention inputs,
|
||||
alongside packed known-correct attention output
|
||||
|
||||
Attributes:
|
||||
|
||||
* packed_qkv: packed (number_of_tokens x num_heads
|
||||
* packed_qkv: packed (number_of_tokens x num_heads
|
||||
x head_size) attention inputs
|
||||
* ideal_output: packed (number_of_tokens x num_heads
|
||||
* ideal_output: packed (number_of_tokens x num_heads
|
||||
x head_size) known-correct attention output
|
||||
'''
|
||||
|
||||
@@ -134,7 +134,7 @@ class PhaseTestParameters(NamedTuple):
|
||||
|
||||
Attributes:
|
||||
|
||||
* packed_qkvo: packed (number_of_tokens x num_heads
|
||||
* packed_qkvo: packed (number_of_tokens x num_heads
|
||||
x head_size) attention inputs & known-correct
|
||||
output
|
||||
* kv_mmap: KV cache memory mapping, specific to this test phase &
|
||||
@@ -195,7 +195,7 @@ def make_causal_mask(
|
||||
Create a q_max_seq_len x kv_max_seq_len causal mask
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
* q_max_seq_len: query max seq len
|
||||
* kv_max_seq_len: key/value max seq len
|
||||
|
||||
@@ -320,9 +320,9 @@ def make_qkv(
|
||||
* max_kv_seq_len: max key/value seq len
|
||||
* num_heads
|
||||
* head_size
|
||||
* is_encoder_decoder_attn: if True, query seqlen may differ from
|
||||
key/value seqlen (as is often the case for cross-attention);
|
||||
o/w, query/key/value seqlens match at each batch index
|
||||
* is_encoder_decoder_attn: if True, query seqlen may differ from
|
||||
key/value seqlen (as is often the case for cross-attention);
|
||||
o/w, query/key/value seqlens match at each batch index
|
||||
(max_kv_seq_len is unused)
|
||||
* force_kv_seq_lens: if not None, overrides kv sequence lengths
|
||||
* attn_type: encoder, decoder self, or enc/dec cross attention
|
||||
@@ -469,7 +469,7 @@ def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
|
||||
Individually pack each of Q, K and V, each with dimensions batch_size x
|
||||
padded_seq_len x num_heads x head_size, into respective number_of_tokens x
|
||||
num_heads x head_size tensors.
|
||||
|
||||
|
||||
For Q, number_of_tokens = sum(q_seq_lens).
|
||||
|
||||
For K and V, number_of_tokens = sum(kv_seq_lens)
|
||||
@@ -619,9 +619,9 @@ def make_kv_cache(num_blocks: int,
|
||||
Returns:
|
||||
|
||||
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
|
||||
* for backend 'XFORMERS'
|
||||
* for backend 'XFORMERS'
|
||||
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
|
||||
* for backend 'FLASH_ATTN'
|
||||
* for backend 'FLASH_ATTN'
|
||||
'''
|
||||
if backend == 'XFORMERS':
|
||||
kv_cache = torch.rand(
|
||||
@@ -662,20 +662,20 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int],
|
||||
Context:
|
||||
* Your goal is to test (1) prefill of N prompts, with prompt-lengths
|
||||
{K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
|
||||
for all N prompts (N tokens total); the resultant sequence lengths
|
||||
for all N prompts (N tokens total); the resultant sequence lengths
|
||||
after decode would be {K_i + 1 for i \\in [0,N)}
|
||||
* The test you want to do requires (1) having the prefill slot mapping
|
||||
for all tokens present during prefill, the number of which is
|
||||
M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
|
||||
* The test you want to do requires (1) having the prefill slot mapping
|
||||
for all tokens present during prefill, the number of which is
|
||||
M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
|
||||
decoded tokens
|
||||
|
||||
This function consumes a single 1D slot mapping, which is the
|
||||
|
||||
This function consumes a single 1D slot mapping, which is the
|
||||
concatenation of N slot mappings each of length K_i + 1 (corresponding
|
||||
to the sequence lengths after decode), with a total length of
|
||||
P = \\sum_i{K_i + 1} = M + N
|
||||
|
||||
The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
|
||||
from each of the N subsequences in the slot mapping (i.e. omitting the
|
||||
from each of the N subsequences in the slot mapping (i.e. omitting the
|
||||
decoded token's mapping.)
|
||||
|
||||
The N excised entries are appended to obtain the decode-phase slot mapping
|
||||
@@ -684,15 +684,15 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int],
|
||||
|
||||
* slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N
|
||||
post-decode sequences
|
||||
* seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
|
||||
* seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
|
||||
description above)
|
||||
* device: cuda, cpu, etc.
|
||||
|
||||
Returns:
|
||||
|
||||
* prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
|
||||
* prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
|
||||
reflecting all N prefill prompts
|
||||
* decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
|
||||
* decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
|
||||
all N decoded tokens
|
||||
'''
|
||||
|
||||
@@ -725,7 +725,7 @@ def make_block_tables_slot_mapping(
|
||||
|
||||
Then the minimum KV cache size in blocks is
|
||||
|
||||
total_cache_blocks = sum(num_blocks for all seqs)
|
||||
total_cache_blocks = sum(num_blocks for all seqs)
|
||||
|
||||
Then, the blocktable mapping counts downward from
|
||||
|
||||
@@ -734,7 +734,7 @@ def make_block_tables_slot_mapping(
|
||||
to
|
||||
|
||||
block_base_addr
|
||||
|
||||
|
||||
|
||||
The constructed block-tables and slot-mapping are sized to the
|
||||
lengths of the sequences in their entirety (as reflected by seq_lens),
|
||||
@@ -749,7 +749,7 @@ def make_block_tables_slot_mapping(
|
||||
|
||||
Return:
|
||||
|
||||
* block_tables_tensor: block table for sequence
|
||||
* block_tables_tensor: block table for sequence
|
||||
* slot_mapping_list: slot mapping for sequence
|
||||
* max_block_idx: the highest block address within this block table
|
||||
'''
|
||||
@@ -807,7 +807,7 @@ def make_test_metadata(
|
||||
encoder_test_params and cross_test_params arguments allow encoder
|
||||
attention and enc/dec cross-attention (respectively) to use distinct
|
||||
metadata values from decoder self-attention (decoder_test_params.)
|
||||
|
||||
|
||||
if encoder_test_params and cross_test_params are None, the attention
|
||||
metadata will support decoder-only scenario.
|
||||
|
||||
@@ -820,7 +820,7 @@ def make_test_metadata(
|
||||
* attn_backend_name: Backend for sourcing attention kernels
|
||||
* is_prompt: prefill if True, o/w decode
|
||||
* seq_lens: list of token counts for each sequence
|
||||
* decoder_test_params: decoder self-attention test params;
|
||||
* decoder_test_params: decoder self-attention test params;
|
||||
this function requires
|
||||
kv_mmap (memory mapping) field
|
||||
* device: CPU or CUDA device
|
||||
|
||||
Reference in New Issue
Block a user