[Bugfix][ROCm][GPT-OSS] Use old triton_kernels implementation on ROCm if the new API is not available (#34153)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Author: Gregory Shtrasberg
Date: 2026-02-09 17:38:54 -06:00
Committed by: GitHub
parent 5e75a14a66
commit c60f8e3b49

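In outline: recent triton_kernels refactors moved SparseMatrix and make_ragged_tensor_metadata into triton_kernels.tensor, but the triton build shipped with ROCm may predate that refactor. The commit therefore probes the new API once at import time and sets a module-level flag that every later call site branches on. A condensed sketch of just that probe, grounded in the diff below:

from vllm.platforms import current_platform

use_legacy_triton_kernels = False
try:
    # Present only in post-refactor triton_kernels builds.
    from triton_kernels.tensor import SparseMatrix, make_ragged_tensor_metadata
except ImportError:
    if current_platform.is_rocm():
        # Older ROCm triton wheels still ship the legacy routing API,
        # so fall back to it instead of failing at import time.
        use_legacy_triton_kernels = True
    else:
        raise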

@@ -19,11 +19,14 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
 )
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.import_utils import has_triton_kernels
 
 logger = init_logger(__name__)
 
+use_legacy_triton_kernels = False
+
 if has_triton_kernels():
     try:
         import triton_kernels.swiglu
@@ -38,10 +41,20 @@ if has_triton_kernels():
         from triton_kernels.tensor import (
             BIT,
             Bitmatrix,
-            SparseMatrix,
-            make_ragged_tensor_metadata,
         )
         from triton_kernels.topk import topk
+
+        try:
+            from triton_kernels.tensor import (
+                SparseMatrix,
+                make_ragged_tensor_metadata,
+            )
+        except ImportError:
+            if current_platform.is_rocm():
+                logger.warning_once("Using legacy triton_kernels on ROCm")
+                use_legacy_triton_kernels = True
+            else:
+                raise
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
@@ -101,6 +114,12 @@ def legacy_routing_from_bitmatrix(
     Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
     Creates routing data from a bitmatrix representation.
     """
+    if use_legacy_triton_kernels:
+        from triton_kernels.routing import routing_from_bitmatrix
+
+        return routing_from_bitmatrix(
+            bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
+        )
     sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
     dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
     combine_indx = sparse_logits.mask_metadata.col_sorted_indx
@@ -130,6 +149,10 @@ def legacy_routing(
     Replacement for the removed triton_kernels.routing.routing function.
     Computes routing data from gating logits.
     """
+    if use_legacy_triton_kernels:
+        from triton_kernels.routing import routing
+
+        return routing(logits, n_expts_act, sm_first=sm_first)
     if sm_first:
         logits = torch.softmax(logits, dim=-1)
     sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
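Both wrappers keep the calling convention of the removed triton_kernels.routing helpers, so downstream MoE code is identical on either path. A hedged usage sketch: the three-value return (routing data plus gather/scatter indices) follows the legacy routing() contract, and the shapes are illustrative assumptions:

import torch

num_tokens, n_expts_tot, n_expts_act = 128, 32, 4  # illustrative sizes
logits = torch.randn(num_tokens, n_expts_tot, device="cuda", dtype=torch.bfloat16)

# Dispatches to triton_kernels.routing.routing on legacy ROCm builds,
# or to the topk/SparseMatrix reimplementation otherwise.
routing_data, gather_indx, scatter_indx = legacy_routing(
    logits, n_expts_act, sm_first=False
)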
@@ -231,11 +254,22 @@ def triton_kernel_fused_experts(
     )
     output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
 
-    act = FusedActivation(
-        FnSpecs(
-            "swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2
-        ),
-        (swiglu_alpha, swiglu_limit),
+    act = (
+        FusedActivation(
+            FnSpecs(
+                "swiglu",
+                triton_kernels.swiglu.swiglu_fn,
+                ("alpha", "limit"),
+                reduction_n=2,
+            ),
+            (swiglu_alpha, swiglu_limit),
+        )
+        if not use_legacy_triton_kernels
+        else FusedActivation(
+            FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
+            (swiglu_alpha, swiglu_limit),
+            2,
+        )
     )
 
     gammas = routing_data.gate_scal if routing_data else None
@@ -296,8 +330,17 @@ def make_routing_data(
 
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
-    bitmatrix = Bitmatrix(
-        bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
+    bitmatrix = (
+        Bitmatrix(
+            bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
+        )
+        if not use_legacy_triton_kernels
+        else Bitmatrix(
+            bitmatrix,
+            shape=bitmatrix_shape,
+            shape_max=bitmatrix_shape_max,
+            scratchpad=None,
+        )
     )
 
     # matmul_ogs expects invalid topk_weights to be -1s
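Taken together, the flag bridges two generations of the triton_kernels API: post-refactor, the swiglu reduction factor lives inside FnSpecs and Bitmatrix takes an explicit dtype=BIT; pre-refactor, the factor is a trailing positional FusedActivation argument and Bitmatrix takes a scratchpad instead. A standalone sketch of the activation half, assuming the usual triton_kernels.matmul_ogs import path for FnSpecs/FusedActivation:

import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FnSpecs, FusedActivation

swiglu_alpha, swiglu_limit = 1.702, 7.0  # illustrative values

if not use_legacy_triton_kernels:
    # New API: reduction_n is an FnSpecs keyword.
    act = FusedActivation(
        FnSpecs(
            "swiglu",
            triton_kernels.swiglu.swiglu_fn,
            ("alpha", "limit"),
            reduction_n=2,
        ),
        (swiglu_alpha, swiglu_limit),
    )
else:
    # Legacy API: the reduction factor is passed positionally.
    act = FusedActivation(
        FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
        (swiglu_alpha, swiglu_limit),
        2,
    )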