[Encoder decoder] Add cuda graph support during decoding for encoder-decoder models (#7631)
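This change drops the ENFORCE_EAGER-based test gating (CUDA graphs were previously unsupported for encoder/decoder models, so every test forced eager mode) and adds a new test, test_prepare_decode_cuda_graph, which verifies that decode-phase inputs are padded correctly when CUDA Graph capture and replay are enabled. The padding target comes from _get_graph_batch_size. As context for the diff below, here is a minimal, hypothetical sketch of the rounding that helper is assumed to perform (small batches padded up to 1, 2, or 4; larger ones up to the next multiple of 8); the real implementation lives in vllm/worker/model_runner.py and may differ in detail:

    # Hypothetical sketch, not the actual vLLM helper.
    _BATCH_SIZE_ALIGNMENT = 8  # assumed CUDA-graph capture alignment

    def graph_batch_size_sketch(batch_size: int) -> int:
        # Pad a runtime batch size up to the nearest batch size for which
        # a CUDA graph is assumed to have been captured (1, 2, 4, 8, 16, ...).
        if batch_size <= 2:
            return batch_size
        if batch_size <= 4:
            return 4
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

    assert graph_batch_size_sketch(3) == 4
    assert graph_batch_size_sketch(5) == 8
    assert graph_batch_size_sketch(256) == 256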
@@ -1,3 +1,4 @@
+import itertools
 from array import array
 from typing import List
 
@@ -7,13 +8,9 @@ import torch
 from vllm.engine.arg_utils import EngineArgs
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
                            SequenceData, SequenceGroupMetadata)
-from vllm.utils import is_cpu
+from vllm.utils import is_cpu, make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
-
-# CUDA graph scenarios to test
-#
-# Currently CUDA graph is not supported
-ENFORCE_EAGER = [True]
+from vllm.worker.model_runner import _get_graph_batch_size
 
 BATCH_SIZES = [1, 4, 16, 64, 256]
 
@@ -40,8 +37,7 @@ def _create_model_runner(model: str, *args,
 
                     reason="CPU backend is currently "
                     "unsupported for encoder/ "
                     "decoder models")
-@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
-def test_empty_seq_group(enforce_eager, ):
+def test_empty_seq_group():
     """Verify prepare prompt and decode returns empty output
        for empty seq group list"""
@@ -52,7 +48,7 @@ def test_empty_seq_group(enforce_eager, ):
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
-        enforce_eager=enforce_eager,
+        enforce_eager=True,
     )
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
     model_input = model_runner._prepare_model_input_tensors(
@@ -85,11 +81,7 @@ def test_empty_seq_group(enforce_eager, ):
                     "unsupported for encoder/ "
                     "decoder models")
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
-def test_prepare_prompt(
-    batch_size,
-    enforce_eager,
-):
+def test_prepare_prompt(batch_size):
     '''
     Test the ability of the encoder/decoder model runner subclass to
     produce prefill-phase model inputs & attention metadata.
@@ -115,7 +107,7 @@ def test_prepare_prompt(
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
-        enforce_eager=enforce_eager,
+        enforce_eager=True,
     )
 
     seq_lens: List[int] = []
@@ -281,11 +273,7 @@ def test_prepare_prompt(
                     "unsupported for encoder/ "
                     "decoder models")
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
-def test_prepare_decode(
-    batch_size,
-    enforce_eager,
-):
+def test_prepare_decode(batch_size):
     '''
     Test the ability of the encoder/decoder model runner subclass to
     produce decode-phase model inputs & attention metadata.
@@ -311,7 +299,7 @@ def test_prepare_decode(
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
-        enforce_eager=enforce_eager,
+        enforce_eager=True,
     )
 
     seq_lens: List[int] = []
@@ -428,7 +416,8 @@ def test_prepare_decode(
         expected,
     )
 
-    # Cuda graph should is currently not supported for encoder/decoer.
+    # Model runner's CUDAGraph setting should be propagated to attention
+    # metadata.
     assert attn_metadata.use_cuda_graph is False
 
     # Verify the lengths of input tokens & positions
@@ -484,3 +473,152 @@ def test_prepare_decode(
         dtype=actual.dtype,
     )
     assert torch.equal(actual, expected)
+
+
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_decode_cuda_graph(batch_size):
+    """
+    Tests that for encoder-decoder models with CUDA Graph capture and replay
+    enabled, the tensors used during the decode phase are correctly padded
+    for varying input batch sizes.
+    """
+    model_runner = _create_model_runner(
+        "facebook/bart-base",
+        seed=0,
+        dtype="float16",
+        max_num_batched_tokens=100000,
+        max_num_seqs=100000,
+        enable_chunked_prefill=False,
+        enforce_eager=False,
+    )
+
+    seq_lens: List[int] = []
+    encoder_seq_lens: List[int] = []
+    seq_group_metadata_list: List[SequenceGroupMetadata] = []
+    block_tables = {0: [1]}
+    cross_block_table = [2]
+    for i in range(batch_size):
+        # make sure all tokens fit into one block
+        seq_len = i % (model_runner.block_size - 1) + 1
+        seq_lens.append(seq_len)
+        seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
+        encoder_seq_lens.append(encoder_seq_len)
+        encoder_seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        seq_group_metadata = SequenceGroupMetadata(
+            request_id=f"test_{i}",
+            is_prompt=False,
+            seq_data={0: seq_data},
+            sampling_params=SamplingParams(temperature=0),
+            block_tables=block_tables,
+            encoder_seq_data=encoder_seq_data,
+            cross_block_table=cross_block_table,
+        )
+        assert seq_group_metadata.token_chunk_size == 1
+        seq_group_metadata_list.append(seq_group_metadata)
+
+    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    attn_metadata = model_input.attn_metadata
+    return_seq_lens = model_input.seq_lens
+    slot_mapping = attn_metadata.slot_mapping
+    encoder_input_tokens = model_input.encoder_input_tokens
+    encoder_input_positions = model_input.encoder_input_positions
+    cross_slot_mapping = attn_metadata.cross_slot_mapping
+
+    # With CUDA Graph capture and replay enabled, the decoder and encoder
+    # input sequences will be padded. Create the expected padded tensors
+    # accordingly.
+    graph_batch_size = _get_graph_batch_size(batch_size)
+    cuda_graph_pad_size = graph_batch_size - batch_size
+    padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
+    padded_encoder_seq_lens = encoder_seq_lens + list(
+        itertools.repeat(1, cuda_graph_pad_size))
+
+    assert return_seq_lens == padded_seq_lens
+    assert len(slot_mapping) == len(input_tokens)
+    assert len(cross_slot_mapping) == len(encoder_input_tokens)
+
+    # Verify attention metadata
+    device = model_runner.device
+    assert attn_metadata.num_prefills == 0
+    assert attn_metadata.num_decode_tokens > 0
+    assert torch.equal(
+        attn_metadata.seq_lens_tensor,
+        torch.tensor(padded_seq_lens, device=device, dtype=torch.int))
+    assert attn_metadata.seq_lens == padded_seq_lens
+    assert attn_metadata.max_prefill_seq_len == 0
+    assert attn_metadata.max_decode_seq_len == max(seq_lens)
+    # - Encoder attention metadata
+    assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens
+    assert torch.equal(
+        attn_metadata.encoder_seq_lens_tensor,
+        torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int))
+    assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens)
+    assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens)
+
+    # Verify block tables are correct for the decode phase
+    # - Decoder self-attention. Pad the block tables as expected.
+    expected = [block_tables[0] for _ in range(batch_size)]
+    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    expected = make_tensor_with_pad(
+        expected,
+        max_len=64,
+        pad=0,
+        dtype=torch.int32,
+        device=model_runner.device,
+    )
+    assert torch.equal(
+        attn_metadata.block_tables,
+        expected,
+    )
+    # - Encoder/decoder cross-attention. Pad the cross-attention block tables
+    #   as expected.
+    expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
+    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    expected = make_tensor_with_pad(
+        expected,
+        max_len=64,
+        pad=0,
+        dtype=torch.int32,
+        device=model_runner.device,
+    )
+    assert torch.equal(
+        attn_metadata.cross_block_tables,
+        expected,
+    )
+
+    # Model runner's CUDAGraph setting should be propagated to attention
+    # metadata.
+    assert attn_metadata.use_cuda_graph is True
+
+    # Verify the lengths of input tokens & positions
+    # - Decoder
+    assert len(input_tokens) == len(padded_seq_lens)
+    assert len(input_positions) == len(padded_seq_lens)
+    # -- An indirect check that model_input.input_tokens
+    #    and model_input.input_positions are correct -
+    #    by design of the test, the input tokens are
+    #    equal to the input position values, so if
+    #    the model_input data structure has the correct
+    #    values then these two should be equal
+    assert torch.equal(
+        input_tokens,
+        input_positions,
+    )
+    # - Encoder
+    assert len(encoder_input_tokens) == 0
+    assert len(encoder_input_positions) == 0
+    # -- An indirect check that model_input.encoder_input_tokens
+    #    and model_input.encoder_input_positions are correct -
+    #    by design of the test, the input tokens are
+    #    equal to the input position values, so if
+    #    the model_input data structure has the correct
+    #    values then these two should be equal
+    assert torch.equal(
+        encoder_input_tokens,
+        encoder_input_positions,
+    )
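For reference, a worked example (not part of the commit) of the padding the new test asserts, assuming a block_size of 16 and the rounding sketched above, so a batch of 5 is padded to a graph batch of 8 with 3 dummy sequences of length 1:

    import itertools

    # Hypothetical values for batch_size = 5; the test's formulas
    # seq_len = i % (block_size - 1) + 1 and
    # encoder_seq_len = (i + 1) % (block_size - 1) + 1 give:
    batch_size = 5
    seq_lens = [1, 2, 3, 4, 5]
    encoder_seq_lens = [2, 3, 4, 5, 6]

    graph_batch_size = 8  # assumed result of _get_graph_batch_size(5)
    cuda_graph_pad_size = graph_batch_size - batch_size  # 3

    padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
    assert padded_seq_lens == [1, 2, 3, 4, 5, 1, 1, 1]

    padded_encoder_seq_lens = encoder_seq_lens + list(
        itertools.repeat(1, cuda_graph_pad_size))
    assert padded_encoder_seq_lens == [2, 3, 4, 5, 6, 1, 1, 1]

    # The padded slots get empty block tables, which make_tensor_with_pad
    # then zero-fills out to max_len columns.
    expected_block_tables = ([[1]] * batch_size +
                             [[] for _ in range(cuda_graph_pad_size)])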