Revert "[Release 2.10] Update to Torch 2.10 - final release (#30525)"
This reverts commit f97ca67176.
@@ -3,7 +3,6 @@ steps:
- label: ":docker: Build image"
  key: image-build
  depends_on: []
  timeout_in_minutes: 600
  commands:
    - if [[ "$BUILDKITE_BRANCH" != "main" ]]; then .buildkite/image_build/image_build.sh $REGISTRY $REPO $BUILDKITE_COMMIT $BRANCH $VLLM_USE_PRECOMPILED $VLLM_MERGE_BASE_COMMIT $IMAGE_TAG; fi
    - if [[ "$BUILDKITE_BRANCH" == "main" ]]; then .buildkite/image_build/image_build.sh $REGISTRY $REPO $BUILDKITE_COMMIT $BRANCH $VLLM_USE_PRECOMPILED $VLLM_MERGE_BASE_COMMIT $IMAGE_TAG $IMAGE_TAG_LATEST; fi

@@ -56,8 +56,8 @@ endif()
 # requirements.txt files and should be kept consistent. The ROCm torch
 # versions are derived from docker/Dockerfile.rocm
 #
-set(TORCH_SUPPORTED_VERSION_CUDA "2.10.0")
-set(TORCH_SUPPORTED_VERSION_ROCM "2.10.0")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.9.1")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.9.1")

 #
 # Try to find python package with an executable that exactly matches

@@ -1,6 +1,6 @@
 # Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels

-set(DEFAULT_TRITON_KERNELS_TAG "v3.6.0")
+set(DEFAULT_TRITON_KERNELS_TAG "v3.5.0")

 # Set TRITON_KERNELS_SRC_DIR for use with local development with vLLM. We expect TRITON_KERNELS_SRC_DIR to
 # be directly set to the triton_kernels python directory.

@@ -6,7 +6,7 @@ requires = [
     "packaging>=24.2",
     "setuptools>=77.0.3,<81.0.0",
     "setuptools-scm>=8.0",
-    "torch == 2.10.0",
+    "torch == 2.9.1",
     "wheel",
     "jinja2",
     "grpcio-tools==1.78.0",

@@ -4,7 +4,7 @@ ninja
 packaging>=24.2
 setuptools>=77.0.3,<81.0.0
 setuptools-scm>=8
-torch==2.10.0
+torch==2.9.1
 wheel
 jinja2>=3.1.6
 regex

@@ -5,9 +5,9 @@ numba == 0.61.2 # Required for N-gram speculative decoding

 # Dependencies for NVIDIA GPUs
 ray[cgraph]>=2.48.0
-torch==2.10.0
-torchaudio==2.10.0
+torch==2.9.1
+torchaudio==2.9.1
 # These must be updated alongside torch
-torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
+torchvision==0.24.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 # FlashInfer should be updated together with the Dockerfile
 flashinfer-python==0.6.3

@@ -1,11 +1,12 @@
 # Common dependencies
 -r common.txt

---extra-index-url https://download.pytorch.org/whl/test/rocm7.0
-torch==2.10.0
-torchvision==0.25.0
-torchaudio==2.10.0
-triton==3.6.0
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch==2.9.1
+torchvision==0.24.1
+torchaudio==2.9.1
+
+triton==3.5.1
 cmake>=3.26.1,<4
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0

@@ -24,10 +24,10 @@ sentence-transformers>=5.2.0 # required for embedding tests
 soundfile # required for audio tests
 jiwer # required for audio tests
 tblib # for pickling test exceptions
-timm >=1.0.17 # required for internvl and gemma3n-mm test
-torch==2.10.0
-torchaudio==2.10.0
-torchvision==0.25.0
+timm==1.0.17 # required for internvl and gemma3n-mm test
+torch==2.9.1
+torchaudio==2.9.1
+torchvision==0.24.1
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
 mistral_common[image,audio] >= 1.9.0 # required for voxtral test

@@ -155,10 +155,6 @@ coverage==7.10.6
     # via pytest-cov
 cramjam==2.9.0
     # via fastparquet
-cuda-bindings==12.9.4
-    # via torch
-cuda-pathfinder==1.3.3
-    # via cuda-bindings
 cupy-cuda12x==13.6.0
     # via ray
 cycler==0.12.1

@@ -635,7 +631,7 @@ nvidia-nvjitlink-cu12==12.9.86
     #   nvidia-cusolver-cu12
     #   nvidia-cusparse-cu12
     #   torch
-nvidia-nvshmem-cu12==3.4.5
+nvidia-nvshmem-cu12==3.3.20
     # via torch
 nvidia-nvtx-cu12==12.9.79
     # via torch

@@ -1167,7 +1163,7 @@ tomli==2.2.1
     # via schemathesis
 tomli-w==1.2.0
     # via schemathesis
-torch==2.10.0+cu129
+torch==2.9.1+cu129
     # via
     #   -r requirements/test.in
     #   accelerate

@@ -1195,7 +1191,7 @@ torch==2.10.0+cu129
     #   torchvision
     #   vector-quantize-pytorch
     #   vocos
-torchaudio==2.10.0+cu129
+torchaudio==2.9.1+cu129
     # via
     #   -r requirements/test.in
     #   encodec

@@ -1208,7 +1204,7 @@ torchmetrics==1.7.4
     #   pytorch-lightning
     #   terratorch
     #   torchgeo
-torchvision==0.25.0+cu129
+torchvision==0.24.1+cu129
     # via
     #   -r requirements/test.in
     #   lightly

@@ -1250,7 +1246,7 @@ transformers==4.57.5
     #   transformers-stream-generator
 transformers-stream-generator==0.0.5
     # via -r requirements/test.in
-triton==3.6.0
+triton==3.5.1
     # via torch
 tritonclient==2.64.0
     # via -r requirements/test.in

@@ -90,7 +90,9 @@ def use_vllm_config(vllm_config: VllmConfig):
     yield


-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
         vllm_config = make_vllm_config()

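Note on the gate restored in this and the following hunks: under PEP 440, a .dev pre-release sorts before its final release, so a check against "2.10.0" rejects torch nightlies such as 2.10.0.dev20251015, while "2.10.0.dev" admits both nightlies and the final release. A minimal sketch of that ordering using the standard packaging library (illustrative only; the same PEP 440 ordering applies inside vLLM's is_torch_equal_or_newer helper):

from packaging.version import Version

# A .dev pre-release is ordered before the corresponding final release,
assert Version("2.10.0.dev20251015") < Version("2.10.0")
# so a nightly fails a ">= 2.10.0" gate but passes ">= 2.10.0.dev",
assert Version("2.10.0.dev20251015") >= Version("2.10.0.dev")
# and the final release passes the ".dev" gate as well.
assert Version("2.10.0") >= Version("2.10.0.dev")
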
@@ -114,7 +116,9 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
         assert torch.allclose(actual, expected)


-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
     with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
         args = (torch.randn(10, 10),)

@@ -128,7 +132,9 @@ def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
         CompiledMod(vllm_config=vllm_config)(*args)


-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
         args = (torch.randn(10, 10),)

@@ -156,7 +162,9 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
         assert torch.allclose(ret, expected)


-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 def test_cache_load_returns_tuple_consistency(monkeypatch: pytest.MonkeyPatch):
     """
     Test that cache loading correctly handles the returns_tuple logic.

@@ -215,7 +223,9 @@ def test_cache_load_returns_tuple_consistency(monkeypatch: pytest.MonkeyPatch):
     )


-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 def test_cache_load_returns_tuple_consistency_tuple_output(
     monkeypatch: pytest.MonkeyPatch,
 ):

@@ -284,7 +294,9 @@ def test_cache_load_returns_tuple_consistency_tuple_output(
     )


-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 def test_shape_env(monkeypatch: pytest.MonkeyPatch):
     """
     Test that the shape environment is correctly serialized and preserved

@@ -321,7 +333,9 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
         assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"


-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 def test_partition_wrapper_applied_on_aot_load(
     monkeypatch: pytest.MonkeyPatch, vllm_tmp_cache: Path, mocker
 ):

@@ -412,7 +426,9 @@ def test_partition_wrapper_applied_on_aot_load(
     )


-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 @create_new_process_for_each_test("spawn")
 def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
     """

@@ -476,7 +492,9 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
         symbolic_shapes_module.make_symbol = original_make_symbol


-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 class TestStandaloneCompiledArtifacts:
     def test_init(self):
         cache = StandaloneCompiledArtifacts()

@@ -650,7 +668,9 @@ class TestStandaloneCompiledArtifacts:
         assert len(restored_cache.loaded_submodule_store) == 0


-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 class TestStandaloneCompiledArtifactsIntegration:
     def test_add_pickle_unpickle(self):
         cache = StandaloneCompiledArtifacts()

@@ -39,7 +39,9 @@ def get_test_models():
 @pytest.mark.parametrize("use_aot_compile", ["0", "1"])
 @pytest.mark.parametrize("use_bytecode_hook", [True, False])
 @pytest.mark.parametrize("evaluate_guards", [False, True])
-@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+@pytest.mark.skipif(
+    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
 def test_dynamic_shapes_compilation(
     monkeypatch,
     model_name,

@@ -14,7 +14,6 @@ import torch.nn as nn
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
-from vllm.utils.torch_utils import is_torch_equal_or_newer


 class SimpleLinear(nn.Module):

@@ -61,10 +60,6 @@ def setup_cuda():
 @pytest.mark.parametrize("num_tokens", [1, 32])
 @pytest.mark.parametrize("hidden_size,latent_size", [(256, 128), (128, 64)])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.skipif(
-    is_torch_equal_or_newer("2.10.0"),
-    reason="Test fails with PyTorch 2.10.0 see: https://github.com/vllm-project/vllm/issues/33995",
-)
 def test_routed_input_transform_inside_vs_outside(
     num_tokens: int,
     hidden_size: int,

@@ -233,7 +233,7 @@ class InductorStandaloneAdaptor(CompilerInterface):

         from torch._inductor import standalone_compile

-        supports_aot = is_torch_equal_or_newer("2.10.0")
+        supports_aot = is_torch_equal_or_newer("2.10.0.dev")

         if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
             logger.error(

@@ -333,7 +333,7 @@ def _support_torch_compile(
 ) -> None:
     def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
         if ds_type == DynamicShapesType.UNBACKED:
-            if is_torch_equal_or_newer("2.10.0"):
+            if is_torch_equal_or_newer("2.10.0.dev"):
                 for dim in dims:
                     torch._dynamo.decorators.mark_unbacked(
                         arg, dim, hint_override=arg.size()[dim]

@@ -373,7 +373,7 @@ def _support_torch_compile(
         if isinstance(arg, torch.Tensor):
             # In case dims is specified with negative indexing
             dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
-            if is_torch_equal_or_newer("2.10.0"):
+            if is_torch_equal_or_newer("2.10.0.dev"):
                 for dim in dims:
                     torch._dynamo.decorators.mark_unbacked(
                         arg, dim, hint_override=arg.size()[dim]

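Both hunks above gate torch._dynamo.decorators.mark_unbacked behind the version check because its hint_override parameter is a torch 2.10-era addition. For context, mark_unbacked treats a dimension as fully unbacked (no 0/1 specialization from the example input), while the long-standing mark_dynamic keeps a backed symbolic dimension. A rough sketch of the gated fallback, as a hypothetical helper simplified from the surrounding code:

import torch
from vllm.utils.torch_utils import is_torch_equal_or_newer

def mark_dims_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
    # Normalize negative indices first, as the surrounding code does.
    dims = [arg.ndim + d if d < 0 else d for d in dims]
    if is_torch_equal_or_newer("2.10.0.dev"):
        for dim in dims:
            # Unbacked symbol, but still give the compiler a size hint.
            torch._dynamo.decorators.mark_unbacked(
                arg, dim, hint_override=arg.size()[dim]
            )
    else:
        # Older torch: fall back to backed dynamic dims (hypothetical
        # fallback branch; the real code differs).
        torch._dynamo.mark_dynamic(arg, dims)
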
@@ -525,9 +525,9 @@ def _support_torch_compile(
             fx_config_patches["backed_size_oblivious"] = True

         # Prepare inductor config patches
-        # assume_32bit_indexing is only available in torch 2.10.0+
+        # assume_32bit_indexing is only available in torch 2.10.0.dev+
         inductor_config_patches = {}
-        if is_torch_equal_or_newer("2.10.0"):
+        if is_torch_equal_or_newer("2.10.0.dev"):
             inductor_config_patches["assume_32bit_indexing"] = (
                 self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
             )

@@ -271,7 +271,7 @@ def use_aot_compile() -> bool:

     default_value = (
         "1"
-        if is_torch_equal_or_newer("2.11.0.dev") and not disable_compile_cache()
+        if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache()
         else "0"
     )

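The hunk above lowers the threshold for defaulting AOT compile back on to torch 2.10 nightlies. The underlying pattern is an environment flag whose default is computed from the runtime, with an explicit setting always taking precedence; a stripped-down sketch (hypothetical names, not vLLM's exact code):

import os

def env_flag(name: str, default_on: bool) -> bool:
    # An explicit environment setting wins; otherwise use the
    # default computed from runtime conditions.
    return os.environ.get(name, "1" if default_on else "0") == "1"

# e.g. env_flag("VLLM_USE_AOT_COMPILE", version_ok and not cache_disabled)
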
@@ -974,7 +974,7 @@ def enable_batch_invariant_mode():
     )

     reduced_precision_val = (
-        (False, False) if is_torch_equal_or_newer("2.10.0") else False
+        (False, False) if is_torch_equal_or_newer("2.10.0.dev") else False
     )
     torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
         reduced_precision_val

@@ -27,21 +27,9 @@ logger = init_logger(__name__)
 if has_triton_kernels():
     try:
         import triton_kernels.swiglu
-        from triton_kernels.matmul_ogs import (
-            FnSpecs,
-            FusedActivation,
-            GatherIndx,
-            RoutingData,
-            ScatterIndx,
-            matmul_ogs,
-        )
-        from triton_kernels.tensor import (
-            BIT,
-            Bitmatrix,
-            SparseMatrix,
-            make_ragged_tensor_metadata,
-        )
-        from triton_kernels.topk import topk
+        from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
+        from triton_kernels.routing import RoutingData, routing, routing_from_bitmatrix
+        from triton_kernels.tensor import Bitmatrix
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "

@@ -90,58 +78,6 @@ def pack_bitmatrix(
     tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)


-def legacy_routing_from_bitmatrix(
-    bitmatrix: "Bitmatrix",
-    expt_scal: torch.Tensor,
-    expt_indx: torch.Tensor,
-    n_expts_tot: int,
-    n_expts_act: int,
-) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
-    """
-    Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
-    Creates routing data from a bitmatrix representation.
-    """
-    sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
-    dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
-    combine_indx = sparse_logits.mask_metadata.col_sorted_indx
-    ragged_batch_metadata = make_ragged_tensor_metadata(
-        sparse_logits.mask_metadata.col_sum,
-        dispatch_indx.shape[0],
-    )
-    gate_scal = sparse_logits.vals.flatten()[combine_indx]
-    routing_data = RoutingData(
-        gate_scal,
-        ragged_batch_metadata.block_sizes,
-        n_expts_tot,
-        n_expts_act,
-        ragged_batch_metadata,
-    )
-    gather_idx = GatherIndx(combine_indx, dispatch_indx)
-    scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
-    return routing_data, gather_idx, scatter_idx
-
-
-def legacy_routing(
-    logits: torch.Tensor,
-    n_expts_act: int,
-    sm_first: bool = False,
-) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
-    """
-    Replacement for the removed triton_kernels.routing.routing function.
-    Computes routing data from gating logits.
-    """
-    if sm_first:
-        logits = torch.softmax(logits, dim=-1)
-    sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
-    return legacy_routing_from_bitmatrix(
-        sparse_logits.mask,
-        sparse_logits.vals,
-        sparse_logits.indx,
-        logits.shape[-1],
-        n_expts_act,
-    )
-
-
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1, # Tensor or triton_kernels.Tensor

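The legacy_* helpers removed above were a compatibility shim: triton_kernels v3.6 dropped the triton_kernels.routing module, so the 2.10 update reimplemented routing on top of SparseMatrix and make_ragged_tensor_metadata. With the pin back at v3.5.0 the upstream functions are importable again, and the shim goes away. Had both library versions needed support at once, the usual approach would be a conditional import; a sketch under that assumption (not code from this PR, and it presumes the legacy_* definitions shown above):

try:
    # triton_kernels <= v3.5.x ships the routing helpers directly.
    from triton_kernels.routing import routing, routing_from_bitmatrix
except ImportError:
    # Later releases removed them; fall back to the local shims above.
    routing = legacy_routing
    routing_from_bitmatrix = legacy_routing_from_bitmatrix
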
@@ -155,7 +91,7 @@
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    routing_data, gather_idx, scatter_idx = legacy_routing(
+    routing_data, gather_idx, scatter_idx = routing(
         gating_output, topk, sm_first=not renormalize
     )

@@ -232,10 +168,9 @@
         output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))

     act = FusedActivation(
-        FnSpecs(
-            "swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2
-        ),
+        FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
         (swiglu_alpha, swiglu_limit),
+        2,
     )
     gammas = routing_data.gate_scal if routing_data else None

@@ -297,12 +232,12 @@ def make_routing_data(
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
     bitmatrix = Bitmatrix(
-        bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
+        bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=None
     )

     # matmul_ogs expects invalid topk_weights to be -1s
     topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
-    routing_data, gather_indx, scatter_indx = legacy_routing_from_bitmatrix(
+    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
         bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
     )

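For intuition about the Bitmatrix constructed here: row r has bit e set exactly when expert e appears in row r's top-k selection, packed 32 experts per 32-bit word, which is why the logical shape is [n_rows, bm_cols * 32]. A plain-PyTorch reference for that packing (illustrative only; the real pack_bitmatrix above is a Triton kernel that packs on-device):

import torch

def pack_bits_reference(topk_ids: torch.Tensor, n_experts: int) -> torch.Tensor:
    """Pack top-k expert ids into per-row bitmasks, 32 experts per word."""
    n_rows = topk_ids.shape[0]
    bm_cols = (n_experts + 31) // 32  # ceil(n_experts / 32)
    # int64 avoids overflow when setting bit 31 with a Python int.
    out = torch.zeros(n_rows, bm_cols, dtype=torch.int64)
    for r in range(n_rows):
        for e in topk_ids[r].tolist():
            if e >= 0:  # matmul_ogs marks invalid slots with -1
                out[r, e // 32] |= 1 << (e % 32)
    return out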