Support FP32 (#141)

This commit is contained in:
Woosuk Kwon
2023-06-07 00:40:21 -07:00
committed by GitHub
parent 376725ce74
commit e38074b1e6
8 changed files with 65 additions and 54 deletions

View File

@@ -10,7 +10,7 @@ from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.model_executor.input_metadata import InputMetadata
-_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
class GPTCacheFlowAttention(nn.Module):
@@ -49,10 +49,8 @@ class GPTCacheFlowAttention(nn.Module):
self.attn_op = xops.fmha.cutlass.FwOp()
if self.head_size not in _SUPPORTED_HEAD_SIZES:
-raise ValueError(f'head_size ({self.head_size}) is not supported by '
-'the single_query_cached_kv_attention kernel. '
-'Use one of the following head sizes: '
-f'{_SUPPORTED_HEAD_SIZES}.')
+raise ValueError(f"head_size ({self.head_size}) is not supported. "
+f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
def multi_query_kv_attention(
self,