[Bugfix] Fix FP8 MoE EP Weight Loading for ModelOpt Llama4 (#32886)
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
@@ -51,6 +51,8 @@ from vllm.model_executor.model_loader.weight_utils import (
 )
 from vllm.model_executor.models.interfaces import MixtureOfExperts
 from vllm.model_executor.models.utils import sequence_parallel_chunk
+from vllm.platforms import current_platform
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 
 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
 from .utils import (
@@ -504,7 +506,25 @@ class Llama4Model(LlamaModel):
                     .flatten()
                     .to(new_loaded_weight.device)
                 )
-                new_loaded_weight = new_loaded_weight[local_expert_indices]
+                # Workaround for FP8 CPU indexing on older PyTorch:
+                # https://github.com/vllm-project/vllm/issues/32862
+                is_fp8_dtype = new_loaded_weight.dtype == (
+                    current_platform.fp8_dtype()
+                ) or (
+                    new_loaded_weight.dtype.is_floating_point
+                    and new_loaded_weight.element_size() == 1
+                )
+                if (
+                    new_loaded_weight.device.type == "cpu"
+                    and is_fp8_dtype
+                    and not is_torch_equal_or_newer("2.11.0")
+                ):
+                    # PyTorch < 2.11 doesn't support CPU float8 indexing.
+                    new_loaded_weight = new_loaded_weight.to(torch.float16)[
+                        local_expert_indices
+                    ].to(new_loaded_weight.dtype)
+                else:
+                    new_loaded_weight = new_loaded_weight[local_expert_indices]
                 expert_id = local_expert_indices[0].item()
             else:
                 # TODO: add EP support for non fused weights