Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
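The change below is mechanical: yapf's paren-aligned continuation lines are replaced by ruff's black-compatible output (4-space hanging indents, one argument per line when a call is split, a trailing "magic" comma, and the closing bracket on its own line). As a minimal illustrative sketch only — this snippet is not part of the commit and the call is arbitrary — the style shift looks like this:

    import torch

    # yapf (before): continuation arguments aligned under the opening parenthesis.
    #     x = torch.zeros((4, 8),
    #                     device="cpu",
    #                     dtype=torch.float16)
    # ruff format (after): the trailing comma keeps one argument per line, with a
    # 4-space hanging indent and the closing parenthesis on its own line.
    x = torch.zeros(
        (4, 8),
        device="cpu",
        dtype=torch.float16,
    )
    print(x.shape)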
@@ -8,20 +8,30 @@ import torch
 import vllm._custom_ops as ops
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
-from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
-                                                    FLOAT8_E4M3_MAX,
-                                                    dequantize_nvfp4_to_dtype)
+from tests.kernels.quantization.nvfp4_utils import (
+    FLOAT4_E2M1_MAX,
+    FLOAT8_E4M3_MAX,
+    dequantize_nvfp4_to_dtype,
+)
 from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig
 from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
+    FusedMoEConfig,
+    FusedMoEParallelConfig,
+    FusedMoEQuantConfig,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
 
-from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
-                         make_prepare_finalize, prepare_finalize_info)
+from .mk_objects import (
+    TestMoEQuantConfig,
+    expert_info,
+    make_fused_experts,
+    make_prepare_finalize,
+    prepare_finalize_info,
+)
 from .parallel_utils import ProcessGroupInfo
 
 
@@ -94,8 +104,7 @@ class Config:
 
     @property
     def is_per_tensor_act_quant(self) -> bool:
-        return (not self.is_per_act_token_quant
-                and self.quant_block_shape is None)
+        return not self.is_per_act_token_quant and self.quant_block_shape is None
 
     @property
     def is_per_out_ch_quant(self) -> bool:
@@ -134,23 +143,24 @@ class Config:
 
         if self.fused_moe_chunk_size is not None:
             env_dict.update(
-                {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
+                {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
+            )
 
         return vllm_config, env_dict
 
     def is_fp8_block_quantized(self):
-        return (self.quant_dtype == torch.float8_e4m3fn
-                and self.quant_block_shape is not None)
+        return (
+            self.quant_dtype == torch.float8_e4m3fn
+            and self.quant_block_shape is not None
+        )
 
     def is_batched_prepare_finalize(self):
         info = prepare_finalize_info(self.prepare_finalize_type)
-        return (mk.FusedMoEActivationFormat.BatchedExperts ==
-                info.activation_format)
+        return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
 
     def is_batched_fused_experts(self):
         info = expert_info(self.fused_experts_type)
-        return (mk.FusedMoEActivationFormat.BatchedExperts ==
-                info.activation_format)
+        return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
 
     def is_standard_fused_experts(self):
         info = expert_info(self.fused_experts_type)
@@ -190,8 +200,10 @@ class Config:
 
     def needs_deep_ep(self):
         info = prepare_finalize_info(self.prepare_finalize_type)
-        return (info.backend == "deepep_high_throughput"
-                or info.backend == "deepep_low_latency")
+        return (
+            info.backend == "deepep_high_throughput"
+            or info.backend == "deepep_low_latency"
+        )
 
     def all2all_backend(self):
         info = prepare_finalize_info(self.prepare_finalize_type)
@@ -211,20 +223,26 @@ class Config:
             return False
 
         # Check quantization sanity
-        if (int(self.is_per_act_token_quant) +
-                int(self.is_per_tensor_act_quant) +
-                int(self.quant_block_shape is not None)) > 1:
+        if (
+            int(self.is_per_act_token_quant)
+            + int(self.is_per_tensor_act_quant)
+            + int(self.quant_block_shape is not None)
+        ) > 1:
             # invalid quant config
             return False
 
         # check type support
         if self.quant_dtype is None:
-            if (self.dtype not in self.pf_supported_types()
-                    or self.dtype not in self.fe_supported_types()):
+            if (
+                self.dtype not in self.pf_supported_types()
+                or self.dtype not in self.fe_supported_types()
+            ):
                 return False
         else:
-            if (self.quant_dtype not in self.pf_supported_types()
-                    or self.quant_dtype not in self.fe_supported_types()):
+            if (
+                self.quant_dtype not in self.pf_supported_types()
+                or self.quant_dtype not in self.fe_supported_types()
+            ):
                 return False
 
         # Check block quanization support
@@ -261,18 +279,21 @@ class WeightTensors:
     def describe(self):
         s = ""
         s += "== Weight Tensors: \n"
-        s += f' - {_describe_tensor(self.w1, "w1")} \n'
-        s += f' - {_describe_tensor(self.w2, "w2")} \n'
-        s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
-        s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
-        s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n'
-        s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n'
+        s += f" - {_describe_tensor(self.w1, 'w1')} \n"
+        s += f" - {_describe_tensor(self.w2, 'w2')} \n"
+        s += f" - {_describe_tensor(self.w1_scale, 'w1_scale')} \n"
+        s += f" - {_describe_tensor(self.w2_scale, 'w2_scale')} \n"
+        s += f" - {_describe_tensor(self.w1_gs, 'w1_gs')} \n"
+        s += f" - {_describe_tensor(self.w2_gs, 'w2_gs')} \n"
         return s
 
     def is_quantized(self) -> bool:
         # or w1_scale is not None?
-        return (self.w1.dtype == torch.float8_e4m3fn
-                or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
+        return (
+            self.w1.dtype == torch.float8_e4m3fn
+            or self.w1.dtype == torch.uint8
+            or self.w1.dtype == torch.int8
+        )
 
     def to_current_device(self):
         device = torch.cuda.current_device()
@@ -289,16 +310,13 @@ class WeightTensors:
         if self.w2_gs is not None:
             self.w2_gs = self.w2_gs.to(device=device)
 
-    def slice_weights(self, rank: int,
-                      num_local_experts: int) -> "WeightTensors":
+    def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors":
         s = rank * num_local_experts
         e = s + num_local_experts
         w1 = self.w1[s:e, :, :]
         w2 = self.w2[s:e, :, :]
-        w1_scale = self.w1_scale[
-            s:e, :, :] if self.w1_scale is not None else None
-        w2_scale = self.w2_scale[
-            s:e, :, :] if self.w2_scale is not None else None
+        w1_scale = self.w1_scale[s:e, :, :] if self.w1_scale is not None else None
+        w2_scale = self.w2_scale[s:e, :, :] if self.w2_scale is not None else None
         w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
         w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
 
@@ -313,15 +331,11 @@ class WeightTensors:
             in_dtype=config.dtype,
             quant_dtype=config.quant_dtype,
             block_shape=config.quant_block_shape,
-            per_out_ch_quant=config.
-            is_per_act_token_quant,  # or config.is_per_out_ch_quant
+            per_out_ch_quant=config.is_per_act_token_quant,  # or config.is_per_out_ch_quant
         )
-        return WeightTensors(w1=w1,
-                             w2=w2,
-                             w1_scale=w1_scale,
-                             w2_scale=w2_scale,
-                             w1_gs=w1_gs,
-                             w2_gs=w2_gs)
+        return WeightTensors(
+            w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale, w1_gs=w1_gs, w2_gs=w2_gs
+        )
 
 
 @dataclass
@@ -336,22 +350,22 @@ class RankTensors:
     def describe(self):
         s = ""
         s += "== Rank Tensors: \n"
-        s += f' - {_describe_tensor(self.hidden_states, "HS")} \n'
-        s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n'
-        s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n'
-        s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n'
-        s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n'
+        s += f" - {_describe_tensor(self.hidden_states, 'HS')} \n"
+        s += f" - {_describe_tensor(self.hidden_states_scale, 'HS_scale')} \n"
+        s += f" - {_describe_tensor(self.topk_weights, 'topk_weights')} \n"
+        s += f" - {_describe_tensor(self.topk_ids, 'topk_ids')} \n"
+        s += f" - {_describe_tensor(self.expert_map, 'expert_map')} \n"
         return s
 
     @staticmethod
     def make_hidden_states(
-            config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        config: Config,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
         Return hidden_states
         """
         m, k, dtype = (config.M, config.K, config.dtype)
-        a = (torch.randn(
-            (m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0)
+        a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0
 
         if config.quant_dtype is None:
             return a, None
@@ -362,36 +376,29 @@ class RankTensors:
         # first - so further quantize and dequantize will yield the same
         # values.
         if config.is_per_tensor_act_quant:
-            a_q, a_scales = ops.scaled_fp8_quant(
-                a, use_per_token_if_dynamic=False)
+            a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False)
             return a_q.float().mul(a_scales).to(dtype), a_scales
 
         if config.is_per_act_token_quant:
-            a_q, a_scales = ops.scaled_fp8_quant(a,
-                                                 use_per_token_if_dynamic=True)
+            a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
             return a_q.float().mul(a_scales).to(dtype), None
 
         assert config.quant_block_shape is not None
         block_k = config.quant_block_shape[1]
         a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
-        return a_q.float().view(
-            (-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None
+        return a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(
+            dtype
+        ), None
 
     @staticmethod
     def make(config: Config, pgi: ProcessGroupInfo):
-
         dtype = config.dtype
         topk, m, _ = (config.topk, config.M, config.K)
-        hidden_states, hidden_states_scale = RankTensors.make_hidden_states(
-            config)
+        hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config)
 
-        num_local_experts, global_num_experts = (config.num_local_experts,
-                                                 config.E)
-        score = torch.randn((m, global_num_experts),
-                            device="cuda",
-                            dtype=dtype)
-        topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
-                                               False)
+        num_local_experts, global_num_experts = (config.num_local_experts, config.E)
+        score = torch.randn((m, global_num_experts), device="cuda", dtype=dtype)
+        topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)
 
         # distribute topk_ids evenly
         for mi in range(m):
@@ -400,14 +407,15 @@ class RankTensors:
 
         expert_map = None
         if config.world_size > 1 and config.supports_expert_map():
-            expert_map = torch.full((global_num_experts, ),
-                                    fill_value=-1,
-                                    dtype=torch.int32)
+            expert_map = torch.full(
+                (global_num_experts,), fill_value=-1, dtype=torch.int32
+            )
             s = pgi.rank * num_local_experts
             e = s + num_local_experts
             expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
-            expert_map = expert_map.to(device=torch.cuda.current_device(),
-                                       dtype=torch.int32)
+            expert_map = expert_map.to(
+                device=torch.cuda.current_device(), dtype=torch.int32
+            )
 
         return RankTensors(
             hidden_states=hidden_states,
@@ -418,9 +426,9 @@ class RankTensors:
         )
 
 
-def reference_moe_impl(config: Config, weights: WeightTensors,
-                       rank_tensors: RankTensors) -> torch.Tensor:
-
+def reference_moe_impl(
+    config: Config, weights: WeightTensors, rank_tensors: RankTensors
+) -> torch.Tensor:
     if config.quant_dtype == "nvfp4":
         quant_blocksize = 16
         dtype = config.dtype
@@ -433,8 +441,10 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
         w2_blockscale = weights.w2_scale
         w2_gs = weights.w2_gs
 
-        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(
-            rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32)
+        a_global_scale = (
+            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
+            / torch.amax(rank_tensors.hidden_states.flatten(), dim=-1)
+        ).to(torch.float32)
 
         assert w1_gs is not None
         assert w2_gs is not None
@@ -447,14 +457,17 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
         assert w2_blockscale.shape[2] % 4 == 0
 
         a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
-            rank_tensors.hidden_states, a_global_scale)
+            rank_tensors.hidden_states, a_global_scale
+        )
 
-        a = dequantize_nvfp4_to_dtype(a_fp4,
-                                      a_scale_interleaved,
-                                      a_global_scale,
-                                      dtype=dtype,
-                                      device=a_fp4.device,
-                                      block_size=quant_blocksize)
+        a = dequantize_nvfp4_to_dtype(
+            a_fp4,
+            a_scale_interleaved,
+            a_global_scale,
+            dtype=dtype,
+            device=a_fp4.device,
+            block_size=quant_blocksize,
+        )
 
         e = w1_q.shape[0]
         n = w1_q.shape[1] // 2
@@ -464,18 +477,22 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
         w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
 
         for idx in range(0, e):
-            w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
-                                                w1_blockscale[idx],
-                                                w1_gs[idx],
-                                                dtype=dtype,
-                                                device=w1_q.device,
-                                                block_size=quant_blocksize)
-            w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
-                                                w2_blockscale[idx],
-                                                w2_gs[idx],
-                                                dtype=dtype,
-                                                device=w2_q.device,
-                                                block_size=quant_blocksize)
+            w1[idx] = dequantize_nvfp4_to_dtype(
+                w1_q[idx],
+                w1_blockscale[idx],
+                w1_gs[idx],
+                dtype=dtype,
+                device=w1_q.device,
+                block_size=quant_blocksize,
+            )
+            w2[idx] = dequantize_nvfp4_to_dtype(
+                w2_q[idx],
+                w2_blockscale[idx],
+                w2_gs[idx],
+                dtype=dtype,
+                device=w2_q.device,
+                block_size=quant_blocksize,
+            )
         a_scale = None
         w1_scale = None
         w2_scale = None
@@ -493,27 +510,29 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
     per_act_token_quant = config.is_per_act_token_quant
     block_shape = config.quant_block_shape
 
-    return torch_experts(a=a,
-                         w1=w1,
-                         w2=w2,
-                         topk_weight=rank_tensors.topk_weights,
-                         topk_ids=rank_tensors.topk_ids,
-                         global_num_experts=config.E,
-                         expert_map=None,
-                         w1_scale=w1_scale,
-                         w2_scale=w2_scale,
-                         a1_scale=a_scale,
-                         quant_dtype=quant_dtype,
-                         per_act_token_quant=per_act_token_quant,
-                         block_shape=block_shape,
-                         apply_router_weights_on_input=config.topk == 1
-                         and config.supports_apply_weight_on_input())
+    return torch_experts(
+        a=a,
+        w1=w1,
+        w2=w2,
+        topk_weight=rank_tensors.topk_weights,
+        topk_ids=rank_tensors.topk_ids,
+        global_num_experts=config.E,
+        expert_map=None,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a_scale,
+        quant_dtype=quant_dtype,
+        per_act_token_quant=per_act_token_quant,
+        block_shape=block_shape,
+        apply_router_weights_on_input=config.topk == 1
+        and config.supports_apply_weight_on_input(),
+    )
 
 
 def _make_gscale(num_experts: int) -> torch.Tensor:
-    return torch.ones((num_experts, ),
-                      device=torch.cuda.current_device(),
-                      dtype=torch.float32)
+    return torch.ones(
+        (num_experts,), device=torch.cuda.current_device(), dtype=torch.float32
+    )
 
 
 def make_modular_kernel(
@@ -521,12 +540,12 @@ def make_modular_kernel(
     vllm_config: VllmConfig,
     quant_config: FusedMoEQuantConfig,
 ) -> mk.FusedMoEModularKernel:
-
     def next_power_of_2(x):
         import math
+
         if x == 0:
             return 1
-        return 2**math.ceil(math.log2(x))
+        return 2 ** math.ceil(math.log2(x))
 
     # make moe config
     moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
@@ -546,9 +565,9 @@ def make_modular_kernel(
     )
 
     # make modular kernel
-    prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
-                                             config.all2all_backend(), moe,
-                                             quant_config)
+    prepare_finalize = make_prepare_finalize(
+        config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
+    )
 
     fused_experts = make_fused_experts(
         config.fused_experts_type,
@@ -559,7 +578,8 @@ def make_modular_kernel(
     )
 
     modular_kernel = mk.FusedMoEModularKernel(
-        prepare_finalize=prepare_finalize, fused_experts=fused_experts)
+        prepare_finalize=prepare_finalize, fused_experts=fused_experts
+    )
 
     return modular_kernel
 
@@ -587,10 +607,8 @@ def run_modular_kernel(
         w1_scale=rank_weights.w1_scale,
         w2_scale=rank_weights.w2_scale,
         a1_scale=rank_tensors.hidden_states_scale,
-        g1_alphas=(1 / rank_weights.w1_gs)
-        if rank_weights.w1_gs is not None else None,
-        g2_alphas=(1 / rank_weights.w2_gs)
-        if rank_weights.w2_gs is not None else None,
+        g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None,
+        g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None,
         a1_gscale=gscale,
         a2_gscale=gscale,
         block_shape=config.quant_block_shape,
@@ -603,38 +621,30 @@ def run_modular_kernel(
     # impls might update the tensor in place
     hidden_states = rank_tensors.hidden_states.clone()
 
-    topk_ids = rank_tensors.topk_ids.to(
-        mk.prepare_finalize.topk_indices_dtype())
+    topk_ids = rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype())
 
     mk_kwargs = {
-        "hidden_states":
-        hidden_states,
-        "w1":
-        rank_weights.w1,
-        "w2":
-        rank_weights.w2,
-        "topk_weights":
-        rank_tensors.topk_weights,
-        "topk_ids":
-        topk_ids,
-        "expert_map":
-        rank_tensors.expert_map,
-        "global_num_experts":
-        config.E,
-        "apply_router_weight_on_input":
-        config.topk == 1 and config.supports_apply_weight_on_input(),
+        "hidden_states": hidden_states,
+        "w1": rank_weights.w1,
+        "w2": rank_weights.w2,
+        "topk_weights": rank_tensors.topk_weights,
+        "topk_ids": topk_ids,
+        "expert_map": rank_tensors.expert_map,
+        "global_num_experts": config.E,
+        "apply_router_weight_on_input": config.topk == 1
+        and config.supports_apply_weight_on_input(),
     }
 
     num_tokens = rank_tensors.hidden_states.shape[0]
-    num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size,
-                                        device="cuda",
-                                        dtype=torch.int)
+    num_tokens_across_dp = torch.tensor(
+        [num_tokens] * config.world_size, device="cuda", dtype=torch.int
+    )
 
     with set_forward_context(
-            None,
-            vllm_config,
-            num_tokens=num_tokens,
-            num_tokens_across_dp=num_tokens_across_dp,
+        None,
+        vllm_config,
+        num_tokens=num_tokens,
+        num_tokens_across_dp=num_tokens_across_dp,
     ):
         out = mk.forward(**mk_kwargs)
 