Support FP32 (#141)
@@ -10,7 +10,7 @@ from cacheflow import cache_ops
 from cacheflow import pos_encoding_ops
 from cacheflow.model_executor.input_metadata import InputMetadata

-_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]


 class GPTCacheFlowAttention(nn.Module):
@@ -49,10 +49,8 @@ class GPTCacheFlowAttention(nn.Module):
         self.attn_op = xops.fmha.cutlass.FwOp()

         if self.head_size not in _SUPPORTED_HEAD_SIZES:
-            raise ValueError(f'head_size ({self.head_size}) is not supported by '
-                             'the single_query_cached_kv_attention kernel. '
-                             'Use one of the following head sizes: '
-                             f'{_SUPPORTED_HEAD_SIZES}.')
+            raise ValueError(f"head_size ({self.head_size}) is not supported. "
+                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

     def multi_query_kv_attention(
         self,
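In effect, constructing the attention module with a head size outside the new list now fails with the shorter, double-quoted message. Below is a minimal standalone sketch of that check for illustration only; check_head_size is a hypothetical helper introduced here, while in the diff the check runs inside GPTCacheFlowAttention's constructor.

# Hypothetical standalone sketch of the head-size validation shown in the diff.
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]

def check_head_size(head_size: int) -> None:
    # Mirrors the updated error message from the diff above.
    if head_size not in _SUPPORTED_HEAD_SIZES:
        raise ValueError(f"head_size ({head_size}) is not supported. "
                         f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

check_head_size(128)  # passes silently
check_head_size(48)   # raises ValueError listing the supported sizes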