[Bugfix]: Fix cross attention backend selection for Turing GPU (#31806)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -149,16 +149,20 @@ class CrossAttention(Attention):
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
underlying_attn_backend = get_attn_backend(
|
||||
head_size, dtype, kv_cache_dtype, block_size
|
||||
)
|
||||
attn_backend = create_cross_attention_backend(underlying_attn_backend)
|
||||
|
||||
if attn_type is not None:
|
||||
assert attn_type == AttentionType.ENCODER_DECODER, (
|
||||
"CrossAttention only supports AttentionType.ENCODER_DECODER"
|
||||
)
|
||||
|
||||
underlying_attn_backend = get_attn_backend(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
attn_type=AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
attn_backend = create_cross_attention_backend(underlying_attn_backend)
|
||||
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
|
||||
Reference in New Issue
Block a user