[V1] Make v1 more testable (#9888)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
@@ -16,7 +16,7 @@ from tests.kernels.utils import *
|
||||
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||
-from vllm.attention.selector import (_Backend, get_attn_backend,
|
||||
+from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
@@ -774,7 +774,7 @@ def set_reset_environment(attn_backend):
|
||||
default_dtype = torch.get_default_dtype()
|
||||
if attn_backend.name == 'FLASH_ATTN':
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
-get_attn_backend.cache_clear()
|
||||
+_cached_get_attn_backend.cache_clear()
|
||||
yield
|
||||
# Reset the torch datatype to what it was before the test
|
||||
# so as not to impact the remaining tests.
|
||||
|
||||
Reference in New Issue
Block a user