[attn][tiny fix] fix attn backend in MultiHeadAttention (#11463)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
@@ -191,6 +191,7 @@ class MultiHeadAttention(nn.Module):
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
+        attn_backend = backend_name_to_enum(attn_backend.get_name())
         if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
             attn_backend = _Backend.XFORMERS
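For context, a minimal sketch of why the added line matters, under the assumption (consistent with the diff) that `get_attn_backend` returns a backend *class* rather than a `_Backend` enum member: without the conversion, the membership test against enum members can never match, so the FLASH_ATTN to XFORMERS remap never fires. The stub classes below (`FlashAttentionBackend`, `select_mha_backend`) are hypothetical stand-ins, not vLLM's real registry.

```python
from enum import Enum, auto


class _Backend(Enum):
    FLASH_ATTN = auto()
    FLASH_ATTN_VLLM_V1 = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()


class FlashAttentionBackend:
    """Hypothetical stand-in for the backend class get_attn_backend() returns."""

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"


def backend_name_to_enum(name: str) -> _Backend:
    # Map a backend's reported name onto the _Backend enum.
    return _Backend[name]


def select_mha_backend() -> _Backend:
    attn_backend = FlashAttentionBackend  # what get_attn_backend() yields

    # Before the fix: attn_backend is a class, so the set-membership test
    # against enum members was always False and the remap never ran.
    # The added line converts the class to its enum first.
    attn_backend = backend_name_to_enum(attn_backend.get_name())
    if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
        # MultiHeadAttention cannot use the flash-attn backend here,
        # so fall back to xFormers.
        attn_backend = _Backend.XFORMERS
    return attn_backend


print(select_mha_backend())  # _Backend.XFORMERS
```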