[Performance][Model Loader] Skip non-local expert weights during EP model loading (#37136)

Signed-off-by: esmeetu <jasonailu87@gmail.com>
This commit is contained in:
Roy Wang
2026-03-16 16:33:36 +08:00
committed by GitHub
parent a2956a0f8e
commit 821eb80c0d
4 changed files with 513 additions and 2 deletions

View File

@@ -0,0 +1,361 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for EP weight filtering during model loading."""
import glob
import tempfile
import huggingface_hub.constants
import pytest
import torch
from vllm.model_executor.model_loader.ep_weight_filter import (
compute_local_expert_ids,
parse_expert_id,
should_skip_weight,
)
from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator,
)
# ---------------------------------------------------------------------------
# Unit tests for parse_expert_id
# ---------------------------------------------------------------------------
class TestParseExpertId:
def test_routed_expert(self):
name = "model.layers.0.mlp.experts.42.gate_proj.weight"
assert parse_expert_id(name) == 42
def test_large_expert_id(self):
name = "model.layers.60.mlp.experts.383.down_proj.weight"
assert parse_expert_id(name) == 383
def test_shared_expert(self):
# Shared experts use a different naming convention in most models
name = "model.layers.0.mlp.shared_experts.gate_proj.weight"
assert parse_expert_id(name) is None
def test_attention_weight(self):
name = "model.layers.0.self_attn.q_proj.weight"
assert parse_expert_id(name) is None
def test_embedding(self):
name = "model.embed_tokens.weight"
assert parse_expert_id(name) is None
def test_layernorm(self):
name = "model.layers.0.input_layernorm.weight"
assert parse_expert_id(name) is None
def test_fused_3d_expert(self):
# 3D fused-expert tensors (e.g. gpt-oss) have no numeric expert id.
# They must NOT be filtered — slicing happens later in weight_loader.
name = "model.layers.0.mlp.experts.gate_proj.weight"
assert parse_expert_id(name) is None
def test_fused_3d_expert_down_proj(self):
name = "model.layers.10.mlp.experts.down_proj.weight"
assert parse_expert_id(name) is None
def test_expert_scale(self):
# NVFP4 quantized models have scale tensors for experts
name = "model.layers.5.mlp.experts.100.gate_proj.weight_scale"
assert parse_expert_id(name) == 100
def test_expert_zero_id(self):
name = "model.layers.0.mlp.experts.0.up_proj.weight"
assert parse_expert_id(name) == 0
# ---------------------------------------------------------------------------
# Unit tests for compute_local_expert_ids
# ---------------------------------------------------------------------------
class TestComputeLocalExpertIds:
def test_ep_disabled(self):
assert compute_local_expert_ids(64, ep_size=1, ep_rank=0) is None
def test_even_split(self):
# 64 experts, EP=8 → 8 per rank
ids = compute_local_expert_ids(64, ep_size=8, ep_rank=0)
assert ids == set(range(0, 8))
ids = compute_local_expert_ids(64, ep_size=8, ep_rank=7)
assert ids == set(range(56, 64))
def test_uneven_split(self):
# 10 experts, EP=3 → ranks get 4, 3, 3
ids_0 = compute_local_expert_ids(10, ep_size=3, ep_rank=0)
ids_1 = compute_local_expert_ids(10, ep_size=3, ep_rank=1)
ids_2 = compute_local_expert_ids(10, ep_size=3, ep_rank=2)
assert len(ids_0) == 4
assert len(ids_1) == 3
assert len(ids_2) == 3
# All experts covered, no overlap
assert ids_0 | ids_1 | ids_2 == set(range(10))
assert ids_0.isdisjoint(ids_1)
assert ids_1.isdisjoint(ids_2)
def test_384_experts_ep8(self):
# Kimi-K2.5 config: 384 experts, EP=8
for rank in range(8):
ids = compute_local_expert_ids(384, ep_size=8, ep_rank=rank)
assert len(ids) == 48
# All experts covered
all_ids = set()
for rank in range(8):
ids = compute_local_expert_ids(384, ep_size=8, ep_rank=rank)
all_ids |= ids
assert all_ids == set(range(384))
def test_384_experts_ep16(self):
for rank in range(16):
ids = compute_local_expert_ids(384, ep_size=16, ep_rank=rank)
assert len(ids) == 24
def test_384_experts_ep24(self):
# 384 / 24 = 16 exactly
for rank in range(24):
ids = compute_local_expert_ids(384, ep_size=24, ep_rank=rank)
assert len(ids) == 16
# round_robin placement tests
def test_round_robin_basic(self):
# 8 experts, EP=2: rank 0 → {0,2,4,6}, rank 1 → {1,3,5,7}
rr = "round_robin"
ids_0 = compute_local_expert_ids(8, 2, 0, placement=rr)
ids_1 = compute_local_expert_ids(8, 2, 1, placement=rr)
assert ids_0 == {0, 2, 4, 6}
assert ids_1 == {1, 3, 5, 7}
def test_round_robin_full_coverage(self):
# 384 experts, EP=8: all experts covered, no overlap
rr = "round_robin"
all_ids: set[int] = set()
for rank in range(8):
ids = compute_local_expert_ids(384, 8, rank, placement=rr)
assert ids is not None and len(ids) == 48
assert all_ids.isdisjoint(ids)
all_ids |= ids
assert all_ids == set(range(384))
def test_round_robin_uneven(self):
# 10 experts, EP=3: rank 0→{0,3,6,9}, rank 1→{1,4,7}, rank 2→{2,5,8}
rr = "round_robin"
ids_0 = compute_local_expert_ids(10, 3, 0, placement=rr)
ids_1 = compute_local_expert_ids(10, 3, 1, placement=rr)
ids_2 = compute_local_expert_ids(10, 3, 2, placement=rr)
assert ids_0 == {0, 3, 6, 9}
assert ids_1 == {1, 4, 7}
assert ids_2 == {2, 5, 8}
assert ids_0 | ids_1 | ids_2 == set(range(10))
# ---------------------------------------------------------------------------
# Unit tests for should_skip_weight
# ---------------------------------------------------------------------------
class TestShouldSkipWeight:
def setup_method(self):
# Simulate EP=8, rank=0 → experts 0-47
self.local_ids = compute_local_expert_ids(384, ep_size=8, ep_rank=0)
def test_no_filter(self):
assert not should_skip_weight("anything", None)
def test_dense_not_skipped(self):
assert not should_skip_weight(
"model.layers.0.self_attn.q_proj.weight", self.local_ids
)
def test_local_expert_not_skipped(self):
assert not should_skip_weight(
"model.layers.0.mlp.experts.10.gate_proj.weight", self.local_ids
)
def test_remote_expert_skipped(self):
assert should_skip_weight(
"model.layers.0.mlp.experts.200.gate_proj.weight", self.local_ids
)
def test_boundary_expert(self):
# Expert 47 is local (last one), 48 is not
assert not should_skip_weight(
"model.layers.0.mlp.experts.47.gate_proj.weight", self.local_ids
)
assert should_skip_weight(
"model.layers.0.mlp.experts.48.gate_proj.weight", self.local_ids
)
def test_shared_expert_not_skipped(self):
assert not should_skip_weight(
"model.layers.0.mlp.shared_experts.gate_proj.weight", self.local_ids
)
def test_embedding_not_skipped(self):
assert not should_skip_weight("model.embed_tokens.weight", self.local_ids)
def test_fused_3d_expert_not_skipped(self):
# 3D fused-expert tensors (gpt-oss style) have no numeric id.
# Must not be skipped — weight_loader handles slicing later.
assert not should_skip_weight(
"model.layers.0.mlp.experts.gate_proj.weight", self.local_ids
)
# ---------------------------------------------------------------------------
# Integration test: safetensors_weights_iterator with EP filtering
# ---------------------------------------------------------------------------
class TestSafetensorsWeightsIteratorWithEpFilter:
"""Verify that EP filtering produces a strict subset of unfiltered loading
and that all expected dense + local expert weights are present."""
@pytest.fixture(scope="class")
def gpt2_files(self):
"""Download GPT-2 safetensors to a temp dir (shared across class)."""
with tempfile.TemporaryDirectory() as tmpdir:
huggingface_hub.constants.HF_HUB_OFFLINE = False
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf,
)
download_weights_from_hf(
"openai-community/gpt2",
allow_patterns=["*.safetensors"],
cache_dir=tmpdir,
)
files = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
assert len(files) > 0
yield files
def test_no_filter_returns_all(self, gpt2_files):
"""With local_expert_ids=None, all weights are returned (no MoE)."""
all_weights = dict(safetensors_weights_iterator(gpt2_files, False))
filtered_weights = dict(
safetensors_weights_iterator(gpt2_files, False, local_expert_ids=None)
)
assert set(all_weights.keys()) == set(filtered_weights.keys())
def test_empty_filter_skips_experts_only(self, gpt2_files):
"""GPT-2 has no expert weights, so even an empty local_expert_ids
set should return all weights (all are dense)."""
all_weights = dict(safetensors_weights_iterator(gpt2_files, False))
filtered_weights = dict(
safetensors_weights_iterator(gpt2_files, False, local_expert_ids=set())
)
# GPT-2 has no experts, so nothing should be filtered
assert set(all_weights.keys()) == set(filtered_weights.keys())
class TestEpFilterOnSyntheticMoeWeights:
"""Create synthetic safetensors files with expert-like naming and verify
that the filter correctly skips non-local experts."""
@pytest.fixture
def synthetic_moe_files(self, tmp_path):
"""Create synthetic safetensors with expert-patterned tensor names."""
from safetensors.torch import save_file
tensors = {}
# Dense weights
tensors["model.embed_tokens.weight"] = torch.randn(100, 64)
tensors["model.layers.0.self_attn.q_proj.weight"] = torch.randn(64, 64)
tensors["model.layers.0.input_layernorm.weight"] = torch.randn(64)
# Expert weights: 8 experts
for expert_id in range(8):
tensors[f"model.layers.0.mlp.experts.{expert_id}.gate_proj.weight"] = (
torch.randn(128, 64)
)
tensors[f"model.layers.0.mlp.experts.{expert_id}.up_proj.weight"] = (
torch.randn(128, 64)
)
tensors[f"model.layers.0.mlp.experts.{expert_id}.down_proj.weight"] = (
torch.randn(64, 128)
)
# Shared expert (should never be filtered)
tensors["model.layers.0.mlp.shared_experts.gate_proj.weight"] = torch.randn(
128, 64
)
filepath = str(tmp_path / "model-00001-of-00001.safetensors")
save_file(tensors, filepath)
return [filepath], tensors
def test_no_filter_returns_all(self, synthetic_moe_files):
files, expected = synthetic_moe_files
loaded = dict(safetensors_weights_iterator(files, False))
assert set(loaded.keys()) == set(expected.keys())
def test_ep2_rank0_gets_half_experts(self, synthetic_moe_files):
files, expected = synthetic_moe_files
# EP=2, rank=0 → experts 0-3
local_ids = compute_local_expert_ids(8, ep_size=2, ep_rank=0)
loaded = dict(
safetensors_weights_iterator(files, False, local_expert_ids=local_ids)
)
# Should have all dense + shared + experts 0-3 only
for name in loaded:
eid = parse_expert_id(name)
if eid is not None:
assert eid in local_ids, f"Non-local expert {eid} was loaded"
# Check expert count: 4 experts × 3 weights = 12
expert_names = [n for n in loaded if parse_expert_id(n) is not None]
assert len(expert_names) == 4 * 3
# Check all dense weights present
assert "model.embed_tokens.weight" in loaded
assert "model.layers.0.self_attn.q_proj.weight" in loaded
assert "model.layers.0.input_layernorm.weight" in loaded
assert "model.layers.0.mlp.shared_experts.gate_proj.weight" in loaded
def test_ep2_rank1_gets_other_half(self, synthetic_moe_files):
files, expected = synthetic_moe_files
local_ids = compute_local_expert_ids(8, ep_size=2, ep_rank=1)
loaded = dict(
safetensors_weights_iterator(files, False, local_expert_ids=local_ids)
)
expert_names = [n for n in loaded if parse_expert_id(n) is not None]
assert len(expert_names) == 4 * 3
for name in expert_names:
assert parse_expert_id(name) in local_ids
def test_ep8_each_rank_gets_one_expert(self, synthetic_moe_files):
files, _ = synthetic_moe_files
all_expert_names = set()
for rank in range(8):
local_ids = compute_local_expert_ids(8, ep_size=8, ep_rank=rank)
loaded = dict(
safetensors_weights_iterator(files, False, local_expert_ids=local_ids)
)
expert_names = {n for n in loaded if parse_expert_id(n) is not None}
# 1 expert × 3 weights
assert len(expert_names) == 3
all_expert_names |= expert_names
# All 8 experts × 3 weights covered across ranks
assert len(all_expert_names) == 24
def test_tensor_values_match(self, synthetic_moe_files):
"""Filtered tensors have identical values to unfiltered ones."""
files, _ = synthetic_moe_files
all_weights = dict(safetensors_weights_iterator(files, False))
local_ids = compute_local_expert_ids(8, ep_size=2, ep_rank=0)
filtered = dict(
safetensors_weights_iterator(files, False, local_expert_ids=local_ids)
)
for name, tensor in filtered.items():
assert torch.equal(tensor, all_weights[name]), f"Tensor mismatch for {name}"

View File

@@ -16,6 +16,9 @@ from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.ep_weight_filter import (
compute_local_expert_ids,
)
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
@@ -70,6 +73,7 @@ class DefaultModelLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
self.local_expert_ids: set[int] | None = None
extra_config = load_config.model_loader_extra_config
allowed_keys = {"enable_multithread_load", "num_threads"}
@@ -243,6 +247,7 @@ class DefaultModelLoader(BaseModelLoader):
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.safetensors_load_strategy,
local_expert_ids=self.local_expert_ids,
)
else:
if extra_config.get("enable_multithread_load"):
@@ -296,6 +301,58 @@ class DefaultModelLoader(BaseModelLoader):
allow_patterns_overrides=None,
)
def _init_ep_weight_filter(self, model_config: ModelConfig) -> None:
"""Compute local expert ids for EP weight filtering.
When expert parallelism is active, each rank only needs a subset of
expert weights. By computing the set upfront we can skip non-local
expert tensors *before* reading them from disk.
"""
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
if not (model_config.is_moe and parallel_config.enable_expert_parallel):
return
num_experts = model_config.get_num_experts()
if num_experts <= 0:
return
# EP size/rank computation mirrors FusedMoEParallelConfig.make():
# ep_size = dp_size * pcp_size * tp_size (flattened)
# ep_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
from vllm.distributed import (
get_dp_group,
get_pcp_group,
get_tensor_model_parallel_rank,
)
dp_size = parallel_config.data_parallel_size
tp_size = parallel_config.tensor_parallel_size
pcp_size = parallel_config.prefill_context_parallel_size
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_rank = get_tensor_model_parallel_rank() if tp_size > 1 else 0
pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
ep_size = dp_size * pcp_size * tp_size
ep_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
self.local_expert_ids = compute_local_expert_ids(
num_experts,
ep_size,
ep_rank,
placement=parallel_config.expert_placement_strategy,
)
if self.local_expert_ids is not None:
logger.info_once(
"EP weight filter: ep_size=%d, ep_rank=%d, loading %d/%d experts",
ep_size,
ep_rank,
len(self.local_expert_ids),
num_experts,
)
@instrument(span_name="Load weights")
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao":
@@ -307,6 +364,8 @@ class DefaultModelLoader(BaseModelLoader):
):
self.load_config.safetensors_load_strategy = "torchao"
self._init_ep_weight_filter(model_config)
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))

View File

@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Filter out non-local expert weights during loading to avoid redundant I/O.
In DP+EP deployments each rank only needs its own expert shard. Skipping
non-local expert tensors *before* they are read from disk eliminates the
majority of storage I/O for MoE models (experts typically account for
~85-90 % of total weight bytes).
"""
import regex as re
# Matches per-expert weight names like ".experts.42.gate_proj.weight".
# Does NOT match 3D fused-expert names like ".experts.gate_proj.weight"
# (no numeric id) — those are intentionally left unfiltered so the full
# tensor is loaded and sliced later by FusedMoE.weight_loader.
_EXPERT_ID_RE = re.compile(r"\.experts\.(\d+)\.")
def parse_expert_id(weight_name: str) -> int | None:
"""Return the expert id embedded in *weight_name*, or ``None`` if it is
not an per-expert weight.
Returns ``None`` for dense weights (attention, layernorm, embedding),
shared experts, and 3D fused-expert tensors where all experts are stored
in a single tensor without a numeric expert id in the name."""
m = _EXPERT_ID_RE.search(weight_name)
return int(m.group(1)) if m else None
def compute_local_expert_ids(
num_experts: int,
ep_size: int,
ep_rank: int,
placement: str = "linear",
) -> set[int] | None:
"""Compute the set of global expert ids owned by *ep_rank*.
Returns ``None`` when EP is not active (``ep_size <= 1``), meaning all
experts are local and no filtering should be performed.
The distribution logic mirrors
:func:`vllm.model_executor.layers.fused_moe.layer.determine_expert_map`.
Args:
placement: ``"linear"`` for contiguous assignment,
``"round_robin"`` for interleaved assignment.
"""
if ep_size <= 1:
return None
if placement == "linear":
base = num_experts // ep_size
remainder = num_experts % ep_size
start = ep_rank * base + min(ep_rank, remainder)
local_count = base + (1 if ep_rank < remainder else 0)
return set(range(start, start + local_count))
elif placement == "round_robin":
return set(range(ep_rank, num_experts, ep_size))
else:
raise ValueError(f"Unknown expert placement strategy: {placement}")
def should_skip_weight(
weight_name: str,
local_expert_ids: set[int] | None,
) -> bool:
"""Return ``True`` if *weight_name* is an expert weight that does not
belong to the local rank and should be skipped during loading."""
if local_expert_ids is None:
return False
eid = parse_expert_id(weight_name)
if eid is None:
# Not an expert weight (dense / shared-expert / embedding) → keep.
return False
return eid not in local_expert_ids

View File

@@ -35,6 +35,9 @@ from vllm.model_executor.layers.quantization import (
QuantizationConfig,
get_quantization_config,
)
from vllm.model_executor.model_loader.ep_weight_filter import (
should_skip_weight,
)
from vllm.platforms import current_platform
from vllm.tracing import instrument
from vllm.utils.import_utils import PlaceholderModule
@@ -721,8 +724,14 @@ def safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
safetensors_load_strategy: str = "lazy",
local_expert_ids: set[int] | None = None,
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
"""Iterate over the weights in the model safetensor files.
When *local_expert_ids* is provided, expert weights not belonging to
this rank are skipped **before** reading from disk, which drastically
reduces storage I/O for MoE models under EP.
"""
loading_desc = "Loading safetensors checkpoint shards"
if safetensors_load_strategy == "eager":
loading_desc += " (eager)"
@@ -737,7 +746,9 @@ def safetensors_weights_iterator(
if safetensors_load_strategy == "eager":
with open(st_file, "rb") as f:
state_dict = load(f.read())
yield from state_dict.items()
for name, param in state_dict.items():
if not should_skip_weight(name, local_expert_ids):
yield name, param
elif safetensors_load_strategy == "torchao":
# we can't load flattened torchao tensor subclasses directly into the model
# instead we reconstruct the subclasses here before returning
@@ -753,6 +764,8 @@ def safetensors_weights_iterator(
with safe_open(st_file, framework="pt") as f:
state_dict = {}
for name in f.keys(): # noqa: SIM118
if should_skip_weight(name, local_expert_ids):
continue
state_dict[name] = f.get_tensor(name)
# update with leftover tensor data from previous iteration, if any
@@ -769,6 +782,8 @@ def safetensors_weights_iterator(
else:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
if should_skip_weight(name, local_expert_ids):
continue
param = f.get_tensor(name)
yield name, param