[Bug] Fix "vLLM config is not set" error (#29999)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
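FusedMoEModularKernel used to look up the parallel config lazily in its profile-run path via get_current_vllm_config(). That helper only works while vLLM's config context is active (e.g. during engine construction); calling it during a later forward pass trips the "vLLM config is not set" error. This commit reads the config once in __init__ (or takes an explicitly passed ParallelConfig) and caches the DP+EP flag as self.is_dp_ep.

A minimal sketch of the before/after pattern, not the vLLM code itself; ModularKernel and its forward() are hypothetical stand-ins, while get_current_vllm_config is the real helper:

from vllm.config import get_current_vllm_config

class ModularKernel:  # hypothetical stand-in for FusedMoEModularKernel
    def __init__(self, parallel_config=None):
        # Read the global config once, while the config context set up
        # during engine construction is still active; an explicitly
        # passed config skips the global lookup entirely.
        if parallel_config is None:
            parallel_config = get_current_vllm_config().parallel_config
        self.is_dp_ep = (
            parallel_config.data_parallel_size > 1
            and parallel_config.enable_expert_parallel
        )

    def forward(self):
        # Use the cached flag instead of calling get_current_vllm_config()
        # here, where the config context is no longer set.
        if self.is_dp_ep:
            pass  # DP+EP-only work, e.g. pre-sizing workspaces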
@@ -10,7 +10,7 @@ from typing import final
 import torch
 
 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
+from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -716,6 +716,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
         shared_experts_stream: torch.cuda.Stream | None = None,
+        parallel_config: ParallelConfig | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
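For callers that build the kernel outside the config context, the new keyword makes the dependency explicit. A hedged usage sketch (FusedMoEModularKernel and ParallelConfig are the real names from this diff; the constructor arguments shown are illustrative, and it assumes ParallelConfig's fields can be set by keyword):

from vllm.config import ParallelConfig

# Illustrative only: given prepare_finalize and fused_experts built
# elsewhere, supplying the config up front means __init__ never has to
# consult the global config context.
kernel = FusedMoEModularKernel(
    prepare_finalize,
    fused_experts,
    parallel_config=ParallelConfig(
        data_parallel_size=2,
        enable_expert_parallel=True,
    ),
)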
@@ -723,6 +724,14 @@ class FusedMoEModularKernel(torch.nn.Module):
         self.shared_experts = shared_experts
         self.shared_experts_stream = shared_experts_stream
 
+        # cache whether this worker is using DP+EP
+        if parallel_config is None:
+            parallel_config = get_current_vllm_config().parallel_config
+        self.is_dp_ep = (
+            parallel_config.data_parallel_size > 1
+            and parallel_config.enable_expert_parallel
+        )
+
         self._post_init_setup()
         assert (
             prepare_finalize.activation_format == fused_experts.activation_formats[0]
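The init-time lookup is safe because vLLM activates the config context while the engine (and hence this module) is being constructed; it is gone by the time forward runs. A sketch of the failure mode, assuming vLLM's set_current_vllm_config context manager and a default-constructible VllmConfig (exact behavior when unset varies by version):

from vllm.config import VllmConfig, get_current_vllm_config, set_current_vllm_config

with set_current_vllm_config(VllmConfig()):
    cfg = get_current_vllm_config()  # fine: config context is active

get_current_vllm_config()  # fails here: "vLLM config is not set"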
@@ -811,33 +820,27 @@ class FusedMoEModularKernel(torch.nn.Module):
             is_forward_context_available()
             and get_forward_context().attn_metadata is None
         )
-        if is_profile_run and self.fused_experts.supports_chunking():
-            parallel_config = get_current_vllm_config().parallel_config
-            is_dp_ep = (
-                parallel_config.data_parallel_size > 1
-                and parallel_config.enable_expert_parallel
-            )
-            if is_dp_ep:
-                max_workspace_13, max_workspace_2, max_fused_out_shape = (
-                    self.fused_experts.workspace_shapes(
-                        envs.VLLM_FUSED_MOE_CHUNK_SIZE,
-                        N,
-                        K,
-                        top_k,
-                        global_num_experts,
-                        local_num_experts,
-                        expert_tokens_meta,
-                    )
-                )
-                buffers.workspace13.get(
-                    max_workspace_13, device=device, dtype=workspace_dtype
-                )
-                buffers.workspace2.get(
-                    max_workspace_2, device=device, dtype=workspace_dtype
-                )
-                buffers.fused_out.get(
-                    max_fused_out_shape, device=device, dtype=workspace_dtype
-                )
+        if is_profile_run and self.fused_experts.supports_chunking() and self.is_dp_ep:
+            max_workspace_13, max_workspace_2, max_fused_out_shape = (
+                self.fused_experts.workspace_shapes(
+                    envs.VLLM_FUSED_MOE_CHUNK_SIZE,
+                    N,
+                    K,
+                    top_k,
+                    global_num_experts,
+                    local_num_experts,
+                    expert_tokens_meta,
+                )
+            )
+            buffers.workspace13.get(
+                max_workspace_13, device=device, dtype=workspace_dtype
+            )
+            buffers.workspace2.get(
+                max_workspace_2, device=device, dtype=workspace_dtype
+            )
+            buffers.fused_out.get(
+                max_fused_out_shape, device=device, dtype=workspace_dtype
+            )
 
         # Get intermediate workspace shapes based off the chunked M size.
         workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
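With is_dp_ep cached on the module, the profile-run branch collapses into a single condition and drops an indentation level, and the DP+EP workspace pre-allocation no longer depends on a config context that is unavailable at forward time.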