[ROCm] Fix GPT-OSS import for triton 3.6 (#37453)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
0e9358c11d
commit
b8665383df
@@ -48,9 +48,16 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
|
||||
|
||||
value_layout = StridedLayout
|
||||
if on_gfx950():
|
||||
from triton_kernels.tensor_details.layout import GFX950MXScaleLayout
|
||||
try:
|
||||
# triton < 3.6
|
||||
from triton_kernels.tensor_details.layout import GFX950MXScaleLayout
|
||||
|
||||
scale_layout = GFX950MXScaleLayout
|
||||
scale_layout = GFX950MXScaleLayout
|
||||
except ImportError:
|
||||
# triton >= 3.6
|
||||
from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout
|
||||
|
||||
scale_layout = CDNA4MXScaleLayout
|
||||
else:
|
||||
scale_layout = StridedLayout
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user