[TPU] Avoid Triton Import (#15589)

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
Robert Shaw
2025-03-27 02:43:02 -04:00
committed by GitHub
parent df8d3d1287
commit e1e0fd7543
2 changed files with 8 additions and 6 deletions

View File

@@ -16,8 +16,6 @@ from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
@@ -119,7 +117,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(