diff --git a/.buildkite/image_build/image_build.yaml b/.buildkite/image_build/image_build.yaml index 3026467bf..163dd68c8 100644 --- a/.buildkite/image_build/image_build.yaml +++ b/.buildkite/image_build/image_build.yaml @@ -3,7 +3,6 @@ steps: - label: ":docker: Build image" key: image-build depends_on: [] - timeout_in_minutes: 600 commands: - if [[ "$BUILDKITE_BRANCH" != "main" ]]; then .buildkite/image_build/image_build.sh $REGISTRY $REPO $BUILDKITE_COMMIT $BRANCH $VLLM_USE_PRECOMPILED $VLLM_MERGE_BASE_COMMIT $IMAGE_TAG; fi - if [[ "$BUILDKITE_BRANCH" == "main" ]]; then .buildkite/image_build/image_build.sh $REGISTRY $REPO $BUILDKITE_COMMIT $BRANCH $VLLM_USE_PRECOMPILED $VLLM_MERGE_BASE_COMMIT $IMAGE_TAG $IMAGE_TAG_LATEST; fi @@ -42,7 +41,7 @@ steps: limit: 2 - exit_status: -10 # Agent was lost limit: 2 - + - label: ":docker: Build CPU arm64 image" key: cpu-arm64-image-build depends_on: [] diff --git a/CMakeLists.txt b/CMakeLists.txt index 168376ca1..0000b6d32 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,8 +56,8 @@ endif() # requirements.txt files and should be kept consistent. The ROCm torch # versions are derived from docker/Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.10.0") -set(TORCH_SUPPORTED_VERSION_ROCM "2.10.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.9.1") +set(TORCH_SUPPORTED_VERSION_ROCM "2.9.1") # # Try to find python package with an executable that exactly matches @@ -433,7 +433,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC}) endif() - if (MARLIN_SM75_ARCHS) + if (MARLIN_SM75_ARCHS) file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/marlin/sm75_kernel_*.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}" @@ -445,7 +445,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC}) endif() - if (MARLIN_FP8_ARCHS) + if (MARLIN_FP8_ARCHS) file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/marlin/sm89_kernel_*.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}" @@ -1042,7 +1042,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC}) endif() - if (MARLIN_MOE_SM75_ARCHS) + if (MARLIN_MOE_SM75_ARCHS) file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_MOE_SM75_SRC}" diff --git a/cmake/external_projects/triton_kernels.cmake b/cmake/external_projects/triton_kernels.cmake index 1d8b9779c..d35ad123d 100644 --- a/cmake/external_projects/triton_kernels.cmake +++ b/cmake/external_projects/triton_kernels.cmake @@ -1,9 +1,9 @@ # Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels -set(DEFAULT_TRITON_KERNELS_TAG "v3.6.0") +set(DEFAULT_TRITON_KERNELS_TAG "v3.5.0") # Set TRITON_KERNELS_SRC_DIR for use with local development with vLLM. We expect TRITON_KERNELS_SRC_DIR to -# be directly set to the triton_kernels python directory. +# be directly set to the triton_kernels python directory. if (DEFINED ENV{TRITON_KERNELS_SRC_DIR}) message(STATUS "[triton_kernels] Fetch from $ENV{TRITON_KERNELS_SRC_DIR}") FetchContent_Declare( @@ -24,7 +24,7 @@ else() ) endif() -# Fetch content +# Fetch content FetchContent_MakeAvailable(triton_kernels) if (NOT triton_kernels_SOURCE_DIR) @@ -47,7 +47,7 @@ install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/third_party/tr ## Copy .py files to install directory. install(DIRECTORY ${TRITON_KERNELS_PYTHON_DIR} - DESTINATION + DESTINATION vllm/third_party/triton_kernels/ COMPONENT triton_kernels FILES_MATCHING PATTERN "*.py") diff --git a/pyproject.toml b/pyproject.toml index b64254bf5..5fd603fe4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<81.0.0", "setuptools-scm>=8.0", - "torch == 2.10.0", + "torch == 2.9.1", "wheel", "jinja2", "grpcio-tools==1.78.0", diff --git a/requirements/build.txt b/requirements/build.txt index 6c6c9fc8a..6d4376f15 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -4,7 +4,7 @@ ninja packaging>=24.2 setuptools>=77.0.3,<81.0.0 setuptools-scm>=8 -torch==2.10.0 +torch==2.9.1 wheel jinja2>=3.1.6 regex diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 15e4ebbf4..960b4252f 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -5,9 +5,9 @@ numba == 0.61.2 # Required for N-gram speculative decoding # Dependencies for NVIDIA GPUs ray[cgraph]>=2.48.0 -torch==2.10.0 -torchaudio==2.10.0 +torch==2.9.1 +torchaudio==2.9.1 # These must be updated alongside torch -torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +torchvision==0.24.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # FlashInfer should be updated together with the Dockerfile flashinfer-python==0.6.3 diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index 01a71c2da..54af9d995 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -1,11 +1,12 @@ # Common dependencies -r common.txt ---extra-index-url https://download.pytorch.org/whl/test/rocm7.0 -torch==2.10.0 -torchvision==0.25.0 -torchaudio==2.10.0 -triton==3.6.0 +--extra-index-url https://download.pytorch.org/whl/rocm6.4 +torch==2.9.1 +torchvision==0.24.1 +torchaudio==2.9.1 + +triton==3.5.1 cmake>=3.26.1,<4 packaging>=24.2 setuptools>=77.0.3,<80.0.0 diff --git a/requirements/test.in b/requirements/test.in index 8a97c0e88..3ac97432e 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -24,10 +24,10 @@ sentence-transformers>=5.2.0 # required for embedding tests soundfile # required for audio tests jiwer # required for audio tests tblib # for pickling test exceptions -timm >=1.0.17 # required for internvl and gemma3n-mm test -torch==2.10.0 -torchaudio==2.10.0 -torchvision==0.25.0 +timm==1.0.17 # required for internvl and gemma3n-mm test +torch==2.9.1 +torchaudio==2.9.1 +torchvision==0.24.1 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test mistral_common[image,audio] >= 1.9.0 # required for voxtral test diff --git a/requirements/test.txt b/requirements/test.txt index fbe3228d2..566b0e926 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -155,10 +155,6 @@ coverage==7.10.6 # via pytest-cov cramjam==2.9.0 # via fastparquet -cuda-bindings==12.9.4 - # via torch -cuda-pathfinder==1.3.3 - # via cuda-bindings cupy-cuda12x==13.6.0 # via ray cycler==0.12.1 @@ -635,7 +631,7 @@ nvidia-nvjitlink-cu12==12.9.86 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvshmem-cu12==3.4.5 +nvidia-nvshmem-cu12==3.3.20 # via torch nvidia-nvtx-cu12==12.9.79 # via torch @@ -1167,7 +1163,7 @@ tomli==2.2.1 # via schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.10.0+cu129 +torch==2.9.1+cu129 # via # -r requirements/test.in # accelerate @@ -1195,7 +1191,7 @@ torch==2.10.0+cu129 # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.10.0+cu129 +torchaudio==2.9.1+cu129 # via # -r requirements/test.in # encodec @@ -1208,7 +1204,7 @@ torchmetrics==1.7.4 # pytorch-lightning # terratorch # torchgeo -torchvision==0.25.0+cu129 +torchvision==0.24.1+cu129 # via # -r requirements/test.in # lightly @@ -1250,7 +1246,7 @@ transformers==4.57.5 # transformers-stream-generator transformers-stream-generator==0.0.5 # via -r requirements/test.in -triton==3.6.0 +triton==3.5.1 # via torch tritonclient==2.64.0 # via -r requirements/test.in diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index fbacbb6bf..3a7dab9fd 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -90,7 +90,9 @@ def use_vllm_config(vllm_config: VllmConfig): yield -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: vllm_config = make_vllm_config() @@ -114,7 +116,9 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch): assert torch.allclose(actual, expected) -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) def test_force_aot_load(monkeypatch: pytest.MonkeyPatch): with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m: args = (torch.randn(10, 10),) @@ -128,7 +132,9 @@ def test_force_aot_load(monkeypatch: pytest.MonkeyPatch): CompiledMod(vllm_config=vllm_config)(*args) -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) def test_save_and_load(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: args = (torch.randn(10, 10),) @@ -156,7 +162,9 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch): assert torch.allclose(ret, expected) -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) def test_cache_load_returns_tuple_consistency(monkeypatch: pytest.MonkeyPatch): """ Test that cache loading correctly handles the returns_tuple logic. @@ -215,7 +223,9 @@ def test_cache_load_returns_tuple_consistency(monkeypatch: pytest.MonkeyPatch): ) -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) def test_cache_load_returns_tuple_consistency_tuple_output( monkeypatch: pytest.MonkeyPatch, ): @@ -284,7 +294,9 @@ def test_cache_load_returns_tuple_consistency_tuple_output( ) -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) def test_shape_env(monkeypatch: pytest.MonkeyPatch): """ Test that the shape environment is correctly serialized and preserved @@ -321,7 +333,9 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) def test_partition_wrapper_applied_on_aot_load( monkeypatch: pytest.MonkeyPatch, vllm_tmp_cache: Path, mocker ): @@ -412,7 +426,9 @@ def test_partition_wrapper_applied_on_aot_load( ) -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) @create_new_process_for_each_test("spawn") def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch): """ @@ -476,7 +492,9 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch): symbolic_shapes_module.make_symbol = original_make_symbol -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) class TestStandaloneCompiledArtifacts: def test_init(self): cache = StandaloneCompiledArtifacts() @@ -650,7 +668,9 @@ class TestStandaloneCompiledArtifacts: assert len(restored_cache.loaded_submodule_store) == 0 -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) class TestStandaloneCompiledArtifactsIntegration: def test_add_pickle_unpickle(self): cache = StandaloneCompiledArtifacts() diff --git a/tests/compile/test_dynamic_shapes_compilation.py b/tests/compile/test_dynamic_shapes_compilation.py index 6dec603a5..1fda21dea 100644 --- a/tests/compile/test_dynamic_shapes_compilation.py +++ b/tests/compile/test_dynamic_shapes_compilation.py @@ -39,7 +39,9 @@ def get_test_models(): @pytest.mark.parametrize("use_aot_compile", ["0", "1"]) @pytest.mark.parametrize("use_bytecode_hook", [True, False]) @pytest.mark.parametrize("evaluate_guards", [False, True]) -@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) def test_dynamic_shapes_compilation( monkeypatch, model_name, diff --git a/tests/kernels/moe/test_shared_fused_moe_routed_transform.py b/tests/kernels/moe/test_shared_fused_moe_routed_transform.py index b6ef19dda..3be1d9974 100644 --- a/tests/kernels/moe/test_shared_fused_moe_routed_transform.py +++ b/tests/kernels/moe/test_shared_fused_moe_routed_transform.py @@ -14,7 +14,6 @@ import torch.nn as nn from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE -from vllm.utils.torch_utils import is_torch_equal_or_newer class SimpleLinear(nn.Module): @@ -61,10 +60,6 @@ def setup_cuda(): @pytest.mark.parametrize("num_tokens", [1, 32]) @pytest.mark.parametrize("hidden_size,latent_size", [(256, 128), (128, 64)]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.skipif( - is_torch_equal_or_newer("2.10.0"), - reason="Test fails with PyTorch 2.10.0 see: https://github.com/vllm-project/vllm/issues/33995", -) def test_routed_input_transform_inside_vs_outside( num_tokens: int, hidden_size: int, diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 1d5adb185..606503539 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -233,7 +233,7 @@ class InductorStandaloneAdaptor(CompilerInterface): from torch._inductor import standalone_compile - supports_aot = is_torch_equal_or_newer("2.10.0") + supports_aot = is_torch_equal_or_newer("2.10.0.dev") if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT: logger.error( diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 3651c835f..8bdb9882b 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -333,7 +333,7 @@ def _support_torch_compile( ) -> None: def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None: if ds_type == DynamicShapesType.UNBACKED: - if is_torch_equal_or_newer("2.10.0"): + if is_torch_equal_or_newer("2.10.0.dev"): for dim in dims: torch._dynamo.decorators.mark_unbacked( arg, dim, hint_override=arg.size()[dim] @@ -373,7 +373,7 @@ def _support_torch_compile( if isinstance(arg, torch.Tensor): # In case dims is specified with negative indexing dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] - if is_torch_equal_or_newer("2.10.0"): + if is_torch_equal_or_newer("2.10.0.dev"): for dim in dims: torch._dynamo.decorators.mark_unbacked( arg, dim, hint_override=arg.size()[dim] @@ -525,9 +525,9 @@ def _support_torch_compile( fx_config_patches["backed_size_oblivious"] = True # Prepare inductor config patches - # assume_32bit_indexing is only available in torch 2.10.0+ + # assume_32bit_indexing is only available in torch 2.10.0.dev+ inductor_config_patches = {} - if is_torch_equal_or_newer("2.10.0"): + if is_torch_equal_or_newer("2.10.0.dev"): inductor_config_patches["assume_32bit_indexing"] = ( self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing ) diff --git a/vllm/envs.py b/vllm/envs.py index 19464f2f2..caddf0b76 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -271,7 +271,7 @@ def use_aot_compile() -> bool: default_value = ( "1" - if is_torch_equal_or_newer("2.11.0.dev") and not disable_compile_cache() + if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache() else "0" ) diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index dbe8e8ef2..fcfadd60f 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -974,7 +974,7 @@ def enable_batch_invariant_mode(): ) reduced_precision_val = ( - (False, False) if is_torch_equal_or_newer("2.10.0") else False + (False, False) if is_torch_equal_or_newer("2.10.0.dev") else False ) torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = ( reduced_precision_val diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 3801814d9..b209820cd 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -27,21 +27,9 @@ logger = init_logger(__name__) if has_triton_kernels(): try: import triton_kernels.swiglu - from triton_kernels.matmul_ogs import ( - FnSpecs, - FusedActivation, - GatherIndx, - RoutingData, - ScatterIndx, - matmul_ogs, - ) - from triton_kernels.tensor import ( - BIT, - Bitmatrix, - SparseMatrix, - make_ragged_tensor_metadata, - ) - from triton_kernels.topk import topk + from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs + from triton_kernels.routing import RoutingData, routing, routing_from_bitmatrix + from triton_kernels.tensor import Bitmatrix except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton " @@ -90,58 +78,6 @@ def pack_bitmatrix( tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) -def legacy_routing_from_bitmatrix( - bitmatrix: "Bitmatrix", - expt_scal: torch.Tensor, - expt_indx: torch.Tensor, - n_expts_tot: int, - n_expts_act: int, -) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]: - """ - Replacement for the removed triton_kernels.routing.routing_from_bitmatrix. - Creates routing data from a bitmatrix representation. - """ - sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix) - dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx - combine_indx = sparse_logits.mask_metadata.col_sorted_indx - ragged_batch_metadata = make_ragged_tensor_metadata( - sparse_logits.mask_metadata.col_sum, - dispatch_indx.shape[0], - ) - gate_scal = sparse_logits.vals.flatten()[combine_indx] - routing_data = RoutingData( - gate_scal, - ragged_batch_metadata.block_sizes, - n_expts_tot, - n_expts_act, - ragged_batch_metadata, - ) - gather_idx = GatherIndx(combine_indx, dispatch_indx) - scatter_idx = ScatterIndx(dispatch_indx, combine_indx) - return routing_data, gather_idx, scatter_idx - - -def legacy_routing( - logits: torch.Tensor, - n_expts_act: int, - sm_first: bool = False, -) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]: - """ - Replacement for the removed triton_kernels.routing.routing function. - Computes routing data from gating logits. - """ - if sm_first: - logits = torch.softmax(logits, dim=-1) - sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first) - return legacy_routing_from_bitmatrix( - sparse_logits.mask, - sparse_logits.vals, - sparse_logits.indx, - logits.shape[-1], - n_expts_act, - ) - - def triton_kernel_moe_forward( hidden_states: torch.Tensor, w1, # Tensor or triton_kernels.Tensor @@ -155,7 +91,7 @@ def triton_kernel_moe_forward( global_num_experts: int = -1, expert_map: torch.Tensor | None = None, ) -> torch.Tensor: - routing_data, gather_idx, scatter_idx = legacy_routing( + routing_data, gather_idx, scatter_idx = routing( gating_output, topk, sm_first=not renormalize ) @@ -232,10 +168,9 @@ def triton_kernel_fused_experts( output_tensor = _resize_cache(output_tensor, (batch_dim, M, K)) act = FusedActivation( - FnSpecs( - "swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2 - ), + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (swiglu_alpha, swiglu_limit), + 2, ) gammas = routing_data.gate_scal if routing_data else None @@ -297,12 +232,12 @@ def make_routing_data( bitmatrix_shape = [n_rows, bm_cols * 32] bitmatrix_shape_max = [n_rows, None] bitmatrix = Bitmatrix( - bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max + bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=None ) # matmul_ogs expects invalid topk_weights to be -1s topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights) - routing_data, gather_indx, scatter_indx = legacy_routing_from_bitmatrix( + routing_data, gather_indx, scatter_indx = routing_from_bitmatrix( bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk )