[ROCm] Fix GPT-OSS import for triton 3.6 (#37453)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg
2026-03-27 13:00:57 -05:00
committed by GitHub
parent 0e9358c11d
commit b8665383df

View File

@@ -48,9 +48,16 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
value_layout = StridedLayout
if on_gfx950():
from triton_kernels.tensor_details.layout import GFX950MXScaleLayout
try:
# triton < 3.6
from triton_kernels.tensor_details.layout import GFX950MXScaleLayout
scale_layout = GFX950MXScaleLayout
scale_layout = GFX950MXScaleLayout
except ImportError:
# triton >= 3.6
from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout
scale_layout = CDNA4MXScaleLayout
else:
scale_layout = StridedLayout
else: