[Bugfix] Fix FP8 MoE EP Weight Loading for ModelOpt Llama4 (#32886)

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
This commit is contained in:
baonudesifeizhai
2026-01-23 10:31:48 -05:00
committed by GitHub
parent 7e22309755
commit 1fb648bf10

View File

@@ -51,6 +51,8 @@ from vllm.model_executor.model_loader.weight_utils import (
)
from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (
@@ -504,7 +506,25 @@ class Llama4Model(LlamaModel):
.flatten()
.to(new_loaded_weight.device)
)
new_loaded_weight = new_loaded_weight[local_expert_indices]
# Workaround for FP8 CPU indexing on older PyTorch:
# https://github.com/vllm-project/vllm/issues/32862
is_fp8_dtype = new_loaded_weight.dtype == (
current_platform.fp8_dtype()
) or (
new_loaded_weight.dtype.is_floating_point
and new_loaded_weight.element_size() == 1
)
if (
new_loaded_weight.device.type == "cpu"
and is_fp8_dtype
and not is_torch_equal_or_newer("2.11.0")
):
# PyTorch < 2.11 doesn't support CPU float8 indexing.
new_loaded_weight = new_loaded_weight.to(torch.float16)[
local_expert_indices
].to(new_loaded_weight.dtype)
else:
new_loaded_weight = new_loaded_weight[local_expert_indices]
expert_id = local_expert_indices[0].item()
else:
# TODO: add EP support for non fused weights