[ROCm] Fix fused_moe_fake signature mismatch and other AITER bugs (#36100)

Signed-off-by: Li <chuali@amd.com>
This commit is contained in:
Chuan (Richard) Li
2026-03-23 00:48:31 -07:00
committed by GitHub
parent a16133a0f1
commit e99fb98867
4 changed files with 16 additions and 26 deletions

View File

@@ -157,13 +157,13 @@ if current_platform.is_rocm():
total_tokens: int,
):
assert kv_cache_layout in ["NHD", "SHUFFLE"], (
-        "kv_cache_layout only support NHD, SHUFFLE"
+        "kv_cache_layout only supports NHD, SHUFFLE"
)
head_dim = key.shape[2]
x = 16 // key_cache.element_size()
# assert dequant is True, "Currently, we only support "\
# "gather cache with dequant"
-    # For k cache layout: [num_blocks, num_heads, page_size, head_dim]
+    # For k cache layout: [num_blocks, page_size, num_heads, head_dim]
assert head_dim == key_cache.shape[3], (
"We assume your kv cache layout is [num_blocks, "
"page_size, num_heads, head_dim], but got otherwise"
@@ -832,7 +832,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
-                "Encoder self-attention is not implemented for FlashAttentionImpl"
+                "Encoder self-attention is not implemented for AiterFlashAttentionImpl"
)
def extend_for_sliding_window(
@@ -1047,7 +1047,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
-                "fused output quantization is not yet supported for FlashAttentionImpl"
+                "fused output quantization is not yet supported "
+                "for AiterFlashAttentionImpl"
)
if attn_metadata is None: