Aiter mha fp8 fix (#24991)
Signed-off-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
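The two hunks below replace hard-coded torch.float8_e4m3fnuz views of the fp8 KV cache with current_platform.fp8_dtype(), which resolves the correct fp8 format for the GPU actually in use. As a minimal sketch of the distinction the helper papers over (the selection logic here is an illustrative assumption, not vLLM's exact implementation):

    import torch

    def fp8_dtype(is_fnuz_platform: bool) -> torch.dtype:
        # Assumption for illustration: some ROCm GPUs (e.g. MI300-class)
        # use the non-standard "fnuz" e4m3 encoding, while other hardware
        # uses the OCP-standard e4m3fn encoding.
        return torch.float8_e4m3fnuz if is_fnuz_platform else torch.float8_e4m3fn

    # Reinterpret raw cache bytes as fp8, as the KV-cache .view() calls below do:
    raw_cache = torch.zeros(16, dtype=torch.uint8)
    key_cache = raw_cache.view(fp8_dtype(is_fnuz_platform=False))
    print(key_cache.dtype)  # torch.float8_e4m3fn

Hard-coding the fnuz variant breaks fp8 attention on platforms whose native fp8 format is e4m3fn, which is why both call sites are switched to the platform helper.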
@@ -81,8 +81,8 @@ class AITERPagedAttention(PagedAttention):
                 blocksparse_head_sliding_step=blocksparse_head_sliding_step)

         if "fp8" in kv_cache_dtype:
-            key_cache = key_cache.view(torch.float8_e4m3fnuz)
-            value_cache = value_cache.view(torch.float8_e4m3fnuz)
+            key_cache = key_cache.view(current_platform.fp8_dtype())
+            value_cache = value_cache.view(current_platform.fp8_dtype())

         if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
             # use blocksparse paged attention
@@ -479,8 +479,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
             )

             if self.kv_cache_dtype.startswith("fp8"):
-                key_cache = key_cache.view(torch.float8_e4m3fnuz)
-                value_cache = value_cache.view(torch.float8_e4m3fnuz)
+                key_cache = key_cache.view(current_platform.fp8_dtype())
+                value_cache = value_cache.view(current_platform.fp8_dtype())

         if not attn_metadata.use_cascade:
             cu_seqlens_q = attn_metadata.query_start_loc