[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE (#20762)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
@@ -75,6 +75,7 @@ def pplx_cutlass_moe(
|
||||
assert torch.cuda.current_device() == pgi.local_rank
|
||||
|
||||
num_tokens, hidden_dim = a.shape
|
||||
intermediate_dim = w2.shape[2]
|
||||
num_experts = w1.shape[0]
|
||||
block_size = hidden_dim # TODO support more cases
|
||||
device = pgi.device
|
||||
@@ -123,10 +124,31 @@ def pplx_cutlass_moe(
|
||||
num_local_experts=num_local_experts,
|
||||
num_dispatchers=num_dispatchers)
|
||||
|
||||
ab_strides1 = torch.full((num_local_experts, ),
|
||||
hidden_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_local_experts, ),
|
||||
intermediate_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_local_experts, ),
|
||||
2 * intermediate_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_local_experts, ),
|
||||
hidden_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
|
||||
experts = CutlassExpertsFp8(num_local_experts,
|
||||
out_dtype,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
num_dispatchers=num_dispatchers,
|
||||
use_batched_format=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user