diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index a8d810244..1bd83f08b 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -319,6 +319,13 @@ class DefaultModelLoader(BaseModelLoader):
             and parallel_config.enable_ep_weight_filter
         ):
            return
+
+        # When EPLB is enabled, redundant physical expert slots may map to
+        # logical experts that belong to other ranks in the default partition.
+        # The weight loader needs to see ALL logical expert weights so it can
+        # populate these redundant slots. Skip the filter entirely.
+        if parallel_config.enable_eplb:
+            return
 
        num_experts = model_config.get_num_experts()
        if num_experts <= 0:
diff --git a/vllm/model_executor/model_loader/ep_weight_filter.py b/vllm/model_executor/model_loader/ep_weight_filter.py
index 1ef7f0174..190842379 100644
--- a/vllm/model_executor/model_loader/ep_weight_filter.py
+++ b/vllm/model_executor/model_loader/ep_weight_filter.py
@@ -73,4 +73,9 @@ def should_skip_weight(
     if eid is None:
         # Not an expert weight (dense / shared-expert / embedding) → keep.
         return False
+    # Only skip heavy weight tensors, never scale/metadata tensors.
+    # Scale tensors are tiny and some backends need them from ALL experts
+    # (e.g. FlashInfer NVFP4 computes a global max of activation scales).
+    if not weight_name.endswith(".weight"):
+        return False
     return eid not in local_expert_ids
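
For context, a minimal standalone sketch of the filter semantics after this patch. The exact signature of should_skip_weight and how the expert id is parsed are not fully shown in the diff, so the _parse_expert_id helper and the flat (weight_name, local_expert_ids) signature below are assumptions for illustration, not the real vllm plumbing.

    import re

    def _parse_expert_id(weight_name: str) -> int | None:
        # Hypothetical parser: assumes checkpoint names like
        # "model.layers.0.mlp.experts.<eid>.w2.weight".
        m = re.search(r"\.experts\.(\d+)\.", weight_name)
        return int(m.group(1)) if m else None

    def should_skip_weight(weight_name: str, local_expert_ids: set[int]) -> bool:
        eid = _parse_expert_id(weight_name)
        if eid is None:
            # Not an expert weight (dense / shared-expert / embedding) -> keep.
            return False
        # Only skip heavy weight tensors, never scale/metadata tensors,
        # since some backends need scales from ALL experts.
        if not weight_name.endswith(".weight"):
            return False
        return eid not in local_expert_ids

    local = {0, 1}
    assert not should_skip_weight("model.layers.0.mlp.gate.weight", local)                 # dense -> keep
    assert not should_skip_weight("model.layers.0.mlp.experts.7.w2.weight_scale", local)   # scale -> keep
    assert should_skip_weight("model.layers.0.mlp.experts.7.w2.weight", local)             # remote expert -> skip

The last two asserts capture the second hunk's point: a remote expert's heavy .weight tensor is filtered, but its tiny .weight_scale tensor is always loaded.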