[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)
This commit is contained in:
@@ -7,12 +7,18 @@ from typing import List, Optional, Tuple
|
||||
import pytest
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
from vllm.attention.selector import (_Backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ..conftest import DecoderPromptType
|
||||
from ..models.utils import check_logprobs_close
|
||||
|
||||
LIST_ENC_DEC_SUPPORTED_BACKENDS = [
|
||||
_Backend.XFORMERS, _Backend.FLASH_ATTN, None
|
||||
]
|
||||
|
||||
|
||||
def vllm_to_hf_output(
|
||||
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
|
||||
@@ -29,7 +35,8 @@ def vllm_to_hf_output(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
|
||||
@@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
|
||||
num_logprobs: int,
|
||||
decoder_prompt_type: DecoderPromptType,
|
||||
enforce_eager: bool,
|
||||
attn_backend: _Backend,
|
||||
) -> None:
|
||||
'''
|
||||
End-to-End (E2E) test for the encoder-decoder framework.
|
||||
@@ -56,43 +64,49 @@ def test_encoder_decoder_e2e(
|
||||
implementations to ensure that both implementations produce consistent
|
||||
and correct results.
|
||||
'''
|
||||
test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
if attn_backend == _Backend.FLASH_ATTN:
|
||||
# Flash Attention supports only half-precision dtypes (fp16/bf16), so the parametrized "float" dtype is overridden with bfloat16
|
||||
dtype = 'bfloat16'
|
||||
test_case_prompts = example_encoder_decoder_prompts[
|
||||
decoder_prompt_type]
|
||||
|
||||
# Configuration settings for HF baseline
|
||||
hf_kwargs = {
|
||||
"top_k": None,
|
||||
"num_beams": 1,
|
||||
"repetition_penalty": 1.0,
|
||||
"top_p": 1.0,
|
||||
"length_penalty": 1.0,
|
||||
"early_stopping": False,
|
||||
"no_repeat_ngram_size": None,
|
||||
"min_length": 0
|
||||
}
|
||||
# Configuration settings for HF baseline
|
||||
hf_kwargs = {
|
||||
"top_k": None,
|
||||
"num_beams": 1,
|
||||
"repetition_penalty": 1.0,
|
||||
"top_p": 1.0,
|
||||
"length_penalty": 1.0,
|
||||
"early_stopping": False,
|
||||
"no_repeat_ngram_size": None,
|
||||
"min_length": 0
|
||||
}
|
||||
|
||||
with hf_runner(model, dtype=dtype,
|
||||
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
|
||||
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
|
||||
test_case_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
**hf_kwargs,
|
||||
))
|
||||
with vllm_runner(model, dtype=dtype,
|
||||
enforce_eager=enforce_eager) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
|
||||
test_case_prompts, max_tokens, num_logprobs)
|
||||
with hf_runner(model, dtype=dtype,
|
||||
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
|
||||
hf_outputs = (
|
||||
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
|
||||
test_case_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
**hf_kwargs,
|
||||
))
|
||||
with vllm_runner(model, dtype=dtype,
|
||||
enforce_eager=enforce_eager) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
|
||||
test_case_prompts, max_tokens, num_logprobs)
|
||||
|
||||
hf_skip_tokens = (1
|
||||
if decoder_prompt_type == DecoderPromptType.NONE else 0)
|
||||
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
|
||||
else 0)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, decoder_prompt_type)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
num_outputs_0_skip_tokens=hf_skip_tokens,
|
||||
)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, decoder_prompt_type)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
num_outputs_0_skip_tokens=hf_skip_tokens,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user