[Misc] Replace os environ to monkeypatch in test suite (#14516)
Signed-off-by: sibi <85477603+t-sibiraj@users.noreply.github.com> Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
@@ -36,12 +36,12 @@ ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = (
|
||||
|
||||
class QKVInputs(NamedTuple):
|
||||
'''
|
||||
Data structure for representing unpacked attention inputs,
|
||||
Data structure for representing unpacked attention inputs,
|
||||
query/key/values and their sequence lengths.
|
||||
|
||||
Attributes:
|
||||
|
||||
* {query,key,value}: unpacked (batch_size x padded_seq_len x
|
||||
* {query,key,value}: unpacked (batch_size x padded_seq_len x
|
||||
num_heads x head_size) attention inputs
|
||||
* q_seq_lens: query sequence lengths list
|
||||
* kv_seq_lens: shared key/value sequence lengths list
|
||||
@@ -56,14 +56,14 @@ class QKVInputs(NamedTuple):
|
||||
|
||||
class QKVO(NamedTuple):
|
||||
'''
|
||||
Data structure for representing unpacked attention inputs,
|
||||
Data structure for representing unpacked attention inputs,
|
||||
alongside unpacked known-correct attention output
|
||||
|
||||
Attributes:
|
||||
|
||||
* qkv: unpacked (batch_size x padded_seq_len x
|
||||
* qkv: unpacked (batch_size x padded_seq_len x
|
||||
num_heads x head_size) attention inputs
|
||||
* ideal_output: unpacked (batch_size x padded_seq_len x
|
||||
* ideal_output: unpacked (batch_size x padded_seq_len x
|
||||
num_heads x head_size) known-correct attention output
|
||||
'''
|
||||
|
||||
@@ -77,7 +77,7 @@ class PackedQKVInputs(NamedTuple):
|
||||
|
||||
Attributes:
|
||||
|
||||
* {query,key,value}: packed (number_of_tokens x num_heads
|
||||
* {query,key,value}: packed (number_of_tokens x num_heads
|
||||
x head_size) attention inputs
|
||||
* q_start_loc_list: list of query start locations within packed tensor
|
||||
* kv_start_loc_list: shared list of key/value start locations within
|
||||
@@ -97,14 +97,14 @@ class PackedQKVInputs(NamedTuple):
|
||||
|
||||
class PackedQKVO(NamedTuple):
|
||||
'''
|
||||
Data structure for representing packed attention inputs,
|
||||
Data structure for representing packed attention inputs,
|
||||
alongside packed known-correct attention output
|
||||
|
||||
Attributes:
|
||||
|
||||
* packed_qkv: packed (number_of_tokens x num_heads
|
||||
* packed_qkv: packed (number_of_tokens x num_heads
|
||||
x head_size) attention inputs
|
||||
* ideal_output: packed (number_of_tokens x num_heads
|
||||
* ideal_output: packed (number_of_tokens x num_heads
|
||||
x head_size) known-correct attention output
|
||||
'''
|
||||
|
||||
@@ -134,7 +134,7 @@ class PhaseTestParameters(NamedTuple):
|
||||
|
||||
Attributes:
|
||||
|
||||
* packed_qkvo: packed (number_of_tokens x num_heads
|
||||
* packed_qkvo: packed (number_of_tokens x num_heads
|
||||
x head_size) attention inputs & known-correct
|
||||
output
|
||||
* kv_mmap: KV cache memory mapping, specific to this test phase &
|
||||
@@ -195,7 +195,7 @@ def make_causal_mask(
|
||||
Create a q_max_seq_len x kv_max_seq_len causal mask
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
* q_max_seq_len: query max seq len
|
||||
* kv_max_seq_len: key/value max seq len
|
||||
|
||||
@@ -320,9 +320,9 @@ def make_qkv(
|
||||
* max_kv_seq_len: max key/value seq len
|
||||
* num_heads
|
||||
* head_size
|
||||
* is_encoder_decoder_attn: if True, query seqlen may differ from
|
||||
key/value seqlen (as is often the case for cross-attention);
|
||||
o/w, query/key/value seqlens match at each batch index
|
||||
* is_encoder_decoder_attn: if True, query seqlen may differ from
|
||||
key/value seqlen (as is often the case for cross-attention);
|
||||
o/w, query/key/value seqlens match at each batch index
|
||||
(max_kv_seq_len is unused)
|
||||
* force_kv_seq_lens: if not None, overrides kv sequence lengths
|
||||
* attn_type: encoder, decoder self, or enc/dec cross attention
|
||||
@@ -469,7 +469,7 @@ def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
|
||||
Individually pack each of Q, K and V, each with dimensions batch_size x
|
||||
padded_seq_len x num_heads x head_size, into respective number_of_tokens x
|
||||
num_heads x head_size tensors.
|
||||
|
||||
|
||||
For Q, number_of_tokens = sum(q_seq_lens).
|
||||
|
||||
For K and V, number_of_tokens = sum(kv_seq_lens)
|
||||
@@ -619,9 +619,9 @@ def make_kv_cache(num_blocks: int,
|
||||
Returns:
|
||||
|
||||
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
|
||||
* for backend 'XFORMERS'
|
||||
* for backend 'XFORMERS'
|
||||
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
|
||||
* for backend 'FLASH_ATTN'
|
||||
* for backend 'FLASH_ATTN'
|
||||
'''
|
||||
if backend == 'XFORMERS':
|
||||
kv_cache = torch.rand(
|
||||
@@ -662,20 +662,20 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int],
|
||||
Context:
|
||||
* Your goal is to test (1) prefill of N prompts, with prompt-lengths
|
||||
{K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
|
||||
for all N prompts (N tokens total); the resultant sequence lengths
|
||||
for all N prompts (N tokens total); the resultant sequence lengths
|
||||
after decode would be {K_i + 1 for i \\in [0,N)}
|
||||
* The test you want to do requires (1) having the prefill slot mapping
|
||||
for all tokens present during prefill, the number of which is
|
||||
M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
|
||||
* The test you want to do requires (1) having the prefill slot mapping
|
||||
for all tokens present during prefill, the number of which is
|
||||
M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
|
||||
decoded tokens
|
||||
|
||||
This function consumes a single 1D slot mapping, which is the
|
||||
|
||||
This function consumes a single 1D slot mapping, which is the
|
||||
concatenation of N slot mappings each of length K_i + 1 (corresponding
|
||||
to the sequence lengths after decode), with a total length of
|
||||
P = \\sum_i{K_i + 1} = M + N
|
||||
|
||||
The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
|
||||
from each of the N subsequences in the slot mapping (i.e. omitting the
|
||||
from each of the N subsequences in the slot mapping (i.e. omitting the
|
||||
decoded token's mapping.)
|
||||
|
||||
The N excised entries are appended to obtain the decode-phase slot mapping
|
||||
@@ -684,15 +684,15 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int],
|
||||
|
||||
* slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N
|
||||
post-decode sequences
|
||||
* seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
|
||||
* seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
|
||||
description above)
|
||||
* device: cuda, cpu, etc.
|
||||
|
||||
Returns:
|
||||
|
||||
* prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
|
||||
* prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
|
||||
reflecting all N prefill prompts
|
||||
* decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
|
||||
* decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
|
||||
all N decoded tokens
|
||||
'''
|
||||
|
||||
@@ -725,7 +725,7 @@ def make_block_tables_slot_mapping(
|
||||
|
||||
Then the minimum KV cache size in blocks is
|
||||
|
||||
total_cache_blocks = sum(num_blocks for all seqs)
|
||||
total_cache_blocks = sum(num_blocks for all seqs)
|
||||
|
||||
Then, the blocktable mapping counts downward from
|
||||
|
||||
@@ -734,7 +734,7 @@ def make_block_tables_slot_mapping(
|
||||
to
|
||||
|
||||
block_base_addr
|
||||
|
||||
|
||||
|
||||
The constructed block-tables and slot-mapping are sized to the
|
||||
lengths of the sequences in their entirety (as reflected by seq_lens),
|
||||
@@ -749,7 +749,7 @@ def make_block_tables_slot_mapping(
|
||||
|
||||
Return:
|
||||
|
||||
* block_tables_tensor: block table for sequence
|
||||
* block_tables_tensor: block table for sequence
|
||||
* slot_mapping_list: slot mapping for sequence
|
||||
* max_block_idx: the highest block address within this block table
|
||||
'''
|
||||
@@ -807,7 +807,7 @@ def make_test_metadata(
|
||||
encoder_test_params and cross_test_params arguments allow encoder
|
||||
attention and enc/dec cross-attention (respectively) to use distinct
|
||||
metadata values from decoder self-attention (decoder_test_params.)
|
||||
|
||||
|
||||
if encoder_test_params and cross_test_params are None, the attention
|
||||
metadata will support decoder-only scenario.
|
||||
|
||||
@@ -820,7 +820,7 @@ def make_test_metadata(
|
||||
* attn_backend_name: Backend for sourcing attention kernels
|
||||
* is_prompt: prefill if True, o/w decode
|
||||
* seq_lens: list of token counts for each sequence
|
||||
* decoder_test_params: decoder self-attention test params;
|
||||
* decoder_test_params: decoder self-attention test params;
|
||||
this function requires
|
||||
kv_mmap (memory mapping) field
|
||||
* device: CPU or CUDA device
|
||||
|
||||
Reference in New Issue
Block a user