[Bugfix][ROCm][GPT-OSS] Use old triton_kernels implementation on ROCm if the new API is not available (#34153)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
5e75a14a66
commit
c60f8e3b49
@@ -19,11 +19,14 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
use_legacy_triton_kernels = False
|
||||
|
||||
if has_triton_kernels():
|
||||
try:
|
||||
import triton_kernels.swiglu
|
||||
@@ -38,10 +41,20 @@ if has_triton_kernels():
|
||||
from triton_kernels.tensor import (
|
||||
BIT,
|
||||
Bitmatrix,
|
||||
SparseMatrix,
|
||||
make_ragged_tensor_metadata,
|
||||
)
|
||||
from triton_kernels.topk import topk
|
||||
|
||||
try:
|
||||
from triton_kernels.tensor import (
|
||||
SparseMatrix,
|
||||
make_ragged_tensor_metadata,
|
||||
)
|
||||
except ImportError:
|
||||
if current_platform.is_rocm():
|
||||
logger.warning_once("Using legacy triton_kernels on ROCm")
|
||||
use_legacy_triton_kernels = True
|
||||
else:
|
||||
raise
|
||||
except (AttributeError, ImportError) as e:
|
||||
logger.error(
|
||||
"Failed to import Triton kernels. Please make sure your triton "
|
||||
@@ -101,6 +114,12 @@ def legacy_routing_from_bitmatrix(
|
||||
Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
|
||||
Creates routing data from a bitmatrix representation.
|
||||
"""
|
||||
if use_legacy_triton_kernels:
|
||||
from triton_kernels.routing import routing_from_bitmatrix
|
||||
|
||||
return routing_from_bitmatrix(
|
||||
bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
|
||||
)
|
||||
sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
|
||||
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
|
||||
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
|
||||
@@ -130,6 +149,10 @@ def legacy_routing(
|
||||
Replacement for the removed triton_kernels.routing.routing function.
|
||||
Computes routing data from gating logits.
|
||||
"""
|
||||
if use_legacy_triton_kernels:
|
||||
from triton_kernels.routing import routing
|
||||
|
||||
return routing(logits, n_expts_act, sm_first=sm_first)
|
||||
if sm_first:
|
||||
logits = torch.softmax(logits, dim=-1)
|
||||
sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
|
||||
@@ -231,11 +254,22 @@ def triton_kernel_fused_experts(
|
||||
)
|
||||
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
|
||||
|
||||
act = FusedActivation(
|
||||
FnSpecs(
|
||||
"swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2
|
||||
),
|
||||
(swiglu_alpha, swiglu_limit),
|
||||
act = (
|
||||
FusedActivation(
|
||||
FnSpecs(
|
||||
"swiglu",
|
||||
triton_kernels.swiglu.swiglu_fn,
|
||||
("alpha", "limit"),
|
||||
reduction_n=2,
|
||||
),
|
||||
(swiglu_alpha, swiglu_limit),
|
||||
)
|
||||
if not use_legacy_triton_kernels
|
||||
else FusedActivation(
|
||||
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
|
||||
(swiglu_alpha, swiglu_limit),
|
||||
2,
|
||||
)
|
||||
)
|
||||
gammas = routing_data.gate_scal if routing_data else None
|
||||
|
||||
@@ -296,8 +330,17 @@ def make_routing_data(
|
||||
|
||||
bitmatrix_shape = [n_rows, bm_cols * 32]
|
||||
bitmatrix_shape_max = [n_rows, None]
|
||||
bitmatrix = Bitmatrix(
|
||||
bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
|
||||
bitmatrix = (
|
||||
Bitmatrix(
|
||||
bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
|
||||
)
|
||||
if not use_legacy_triton_kernels
|
||||
else Bitmatrix(
|
||||
bitmatrix,
|
||||
shape=bitmatrix_shape,
|
||||
shape_max=bitmatrix_shape_max,
|
||||
scratchpad=None,
|
||||
)
|
||||
)
|
||||
|
||||
# matmul_ogs expects invalid topk_weights to be -1s
|
||||
|
||||
Reference in New Issue
Block a user