[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. (#18864)
Signed-off-by: Bill Nell <bnell@redhat.com>
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
+from vllm.utils import cdiv
 
 from .parallel_utils import ProcessGroupInfo, parallel_launch
 
@@ -112,18 +113,21 @@ def pplx_cutlass_moe(
     w2_scale = w2_scale.to(device)
     a1_scale = a1_scale.to(device)
 
+    assert num_experts % world_size == 0
+    num_local_experts = cdiv(num_experts, world_size)
+    num_dispatchers = pgi.world_size // dp_size
+
     prepare_finalize = PplxPrepareAndFinalize(
         ata,
-        max_num_tokens,
-        pgi.world_size,
-        rank,
-        dp_size,
-    )
+        max_num_tokens=max_num_tokens,
+        num_local_experts=num_local_experts,
+        num_dispatchers=num_dispatchers)
 
-    experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
+    experts = CutlassExpertsFp8(num_local_experts,
                                 out_dtype,
                                 per_act_token,
                                 per_out_ch,
+                                num_dispatchers=num_dispatchers,
                                 use_batched_format=True)
 
     fused_cutlass_experts = FusedMoEModularKernel(
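Aside (not part of the patch): cdiv from vllm.utils is ceiling division, so cdiv(num_experts, world_size) yields the same value as the (num_experts + world_size - 1) // world_size expression the hunk above removes; with the new assert num_experts % world_size == 0 guard it reduces to exact division. A minimal standalone sketch, where ceil_div is a local stand-in for illustration rather than the vLLM helper itself:

    # Local stand-in for vllm.utils.cdiv, used only for this illustration.
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    for num_experts, world_size in [(8, 2), (9, 4), (40, 8)]:
        # Matches the inline expression replaced in the hunk above.
        inline = (num_experts + world_size - 1) // world_size
        assert ceil_div(num_experts, world_size) == inline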
@@ -181,35 +185,40 @@ def _pplx_moe(
     per_out_ch: bool,
     use_internode: bool,
 ):
-    if use_internode:
-        uid = nvshmem_get_unique_id(
-        ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
-        torch.distributed.broadcast(uid, src=0)
-        nvshmem_init(uid, pgi.rank, pgi.world_size)
-    else:
-        group_ranks = list(range(pgi.world_size))
-        cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
-        group_name = cpu_group.group_name
+    try:
+        if use_internode:
+            uid = nvshmem_get_unique_id(
+            ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
+            torch.distributed.broadcast(uid, src=0)
+            nvshmem_init(uid, pgi.rank, pgi.world_size)
+        else:
+            group_ranks = list(range(pgi.world_size))
+            cpu_group = torch.distributed.new_group(group_ranks,
+                                                    backend="gloo")
+            group_name = cpu_group.group_name
 
-    with set_current_vllm_config(vllm_config):
-        torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights,
-                                     topk_ids)
-    pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
-                                   w2_scale, topk_weights, topk_ids,
-                                   a1_scale, out_dtype, per_act_token,
-                                   per_out_ch, group_name)
+        with set_current_vllm_config(vllm_config):
+            torch_output = torch_experts(a_full, w1_full, w2_full,
+                                         topk_weights, topk_ids)
+        pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
+                                       w2_scale, topk_weights, topk_ids,
+                                       a1_scale, out_dtype, per_act_token,
+                                       per_out_ch, group_name)
 
-    torch_output = chunk_by_rank(torch_output, pgi.rank,
-                                 pgi.world_size).to(pplx_output.device)
+        torch_output = chunk_by_rank(torch_output, pgi.rank,
+                                     pgi.world_size).to(pplx_output.device)
 
-    # Uncomment if more debugging is needed
-    # print("PPLX OUT:", pplx_output)
-    # print("TORCH OUT:", torch_output)
+        # Uncomment if more debugging is needed
+        # print("PPLX OUT:", pplx_output)
+        # print("TORCH OUT:", torch_output)
 
-    torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
-
-    if use_internode:
-        nvshmem_finalize()
+        torch.testing.assert_close(pplx_output,
+                                   torch_output,
+                                   atol=0.05,
+                                   rtol=0)
+    finally:
+        if use_internode:
+            nvshmem_finalize()
 
 
 @pytest.mark.parametrize("m", [2, 224])
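Aside (not part of the patch): the structural change in _pplx_moe above is wrapping the test body in try/finally, so nvshmem_finalize() still runs when torch.testing.assert_close raises on an internode run. A stripped-down sketch of that pattern; init_comm, teardown_comm, and the tensors are placeholders, not vLLM APIs:

    import torch

    def init_comm():      # placeholder for nvshmem_init(...) in the test
        pass

    def teardown_comm():  # placeholder for nvshmem_finalize()
        pass

    def run_check(use_internode: bool) -> None:
        try:
            if use_internode:
                init_comm()
            out = torch.ones(4)  # placeholder for the pplx_cutlass_moe output
            ref = torch.ones(4)  # placeholder for the torch_experts reference
            # Same tolerance the test uses for the fp8 comparison.
            torch.testing.assert_close(out, ref, atol=0.05, rtol=0)
        finally:
            # Cleanup now runs even if assert_close raises.
            if use_internode:
                teardown_comm()

    run_check(use_internode=False)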