[Bugfix][ROCm] Fix the power-of-2 exception in triton_unified_attention.py when running Llama 4 models, and fix the unit test (#18100)

Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Hongxia Yang authored 2025-05-28 19:21:46 -04:00 (committed by GitHub)
parent 7951d78738
commit 269d901734
2 changed files with 54 additions and 56 deletions

@@ -13,7 +13,9 @@ HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
 DTYPES = [torch.float16, torch.bfloat16]
-QDTYPES = [None, torch.float8_e4m3fn]
+QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
+    None, torch.float8_e4m3fnuz
+]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
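
For context, a minimal sketch of the two constraints this change touches; the helper names below are hypothetical and not part of the commit. ROCm GPUs such as MI300 implement the fnuz fp8 variant rather than the OCP float8_e4m3fn format, which is why the QDTYPES parametrization above branches on the platform. Separately, Triton requires block extents (tl.arange bounds) to be powers of two, which is the constraint behind the exception named in the commit title.

# Hypothetical helpers illustrating the two constraints; not vLLM code.
import torch
import triton


def fp8_qdtype(is_rocm: bool) -> torch.dtype:
    # ROCm uses the fnuz fp8 variant; CUDA uses the OCP e4m3fn format.
    # Mirrors the QDTYPES branch in the test parametrization above.
    return torch.float8_e4m3fnuz if is_rocm else torch.float8_e4m3fn


def padded_block_size(n: int) -> int:
    # Triton block extents must be powers of two; rounding a size up with
    # triton.next_power_of_2 is the usual way to avoid the exception when
    # a model configuration yields a non-power-of-2 size.
    return triton.next_power_of_2(n)


assert padded_block_size(80) == 128
assert fp8_qdtype(is_rocm=True) is torch.float8_e4m3fnuz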