[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
""" CUTLASS based Fused MoE kernels."""
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -12,11 +12,10 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
|
||||
_fp8_quantize,
|
||||
_resize_cache,
|
||||
extract_required_args)
|
||||
_resize_cache)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -213,19 +212,14 @@ def run_cutlass_moe_fp8(
|
||||
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
|
||||
|
||||
|
||||
# TODO (bnell): split class batched vs. non-batched?
|
||||
# maybe remove need for passing aq to workspace_shapes
|
||||
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_experts_per_worker: int,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
num_dispatchers: Optional[int] = None,
|
||||
use_batched_format: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig(
|
||||
@@ -234,33 +228,84 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert max_experts_per_worker > 0
|
||||
assert not use_batched_format or num_dispatchers is not None
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.num_dispatchers = num_dispatchers
|
||||
self.out_dtype = out_dtype
|
||||
self.use_batched_format = use_batched_format
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||
|
||||
expert_num_tokens = None
|
||||
if expert_tokens_meta is not None:
|
||||
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
||||
|
||||
activation_callable = lambda o, i: self.activation(activation, o, i)
|
||||
|
||||
use_batched_format = self.activation_formats[
|
||||
0] == mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
in_dtype = hidden_states.dtype
|
||||
run_cutlass_moe_fp8(
|
||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, workspace13, workspace2, expert_num_tokens,
|
||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||
self.per_act_token_quant, self.per_out_ch_quant,
|
||||
use_batched_format)
|
||||
|
||||
|
||||
class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
out_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
if self.use_batched_format:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
else:
|
||||
return (mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard)
|
||||
return (mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return not self.use_batched_format
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return not self.use_batched_format
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
@@ -274,54 +319,69 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
workspace1: tuple[int, ...] = ()
|
||||
workspace2: tuple[int, ...] = ()
|
||||
output: tuple[int, ...] = ()
|
||||
if self.use_batched_format:
|
||||
padded_M = aq.size(1)
|
||||
num_dp = self.num_dispatchers
|
||||
assert num_dp is not None
|
||||
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
|
||||
max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
|
||||
(N // 2))
|
||||
output = (self.max_experts_per_worker, padded_M, K)
|
||||
else:
|
||||
workspace1 = (M * topk, max(N, K))
|
||||
workspace2 = (M * topk, N // 2)
|
||||
output = (M * topk, K)
|
||||
workspace1 = (M * topk, max(N, K))
|
||||
workspace2 = (M * topk, N // 2)
|
||||
output = (M * topk, K)
|
||||
return (workspace1, workspace2, output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype)
|
||||
|
||||
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]]):
|
||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||
|
||||
expert_num_tokens = None
|
||||
if expert_tokens_meta is not None:
|
||||
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
||||
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
|
||||
activation_callable = lambda o, i: self.activation(activation, o, i)
|
||||
def __init__(
|
||||
self,
|
||||
max_experts_per_worker: int,
|
||||
num_dispatchers: int,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
out_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
block_shape,
|
||||
)
|
||||
assert max_experts_per_worker > 0
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
in_dtype = hidden_states.dtype
|
||||
run_cutlass_moe_fp8(
|
||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, workspace13, workspace2, expert_num_tokens,
|
||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||
self.per_act_token_quant, self.per_out_ch_quant,
|
||||
self.use_batched_format)
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
padded_M = aq.size(1)
|
||||
num_dp = self.num_dispatchers
|
||||
assert num_dp is not None
|
||||
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
|
||||
max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
|
||||
output = (self.max_experts_per_worker, padded_M, K)
|
||||
return (workspace1, workspace2, output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype)
|
||||
|
||||
|
||||
def cutlass_moe_fp8(
|
||||
@@ -387,11 +447,9 @@ def cutlass_moe_fp8(
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp8(
|
||||
max_experts_per_worker=num_experts,
|
||||
out_dtype=a.dtype,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
use_batched_format=False,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -476,8 +534,9 @@ def run_cutlass_moe_fp4(
|
||||
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
|
||||
e_w2, k_w2, half_n_w2 = w2_fp4.shape
|
||||
|
||||
assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
|
||||
" between weights.")
|
||||
assert (e_w1 == e_w2
|
||||
and e_w1 == e), ("Number of experts must match",
|
||||
f" between weights. {e_w1}, {e_w2}, {e}")
|
||||
assert (k_a == half_k_w1 * 2
|
||||
and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
|
||||
assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
|
||||
@@ -554,6 +613,10 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
g1_alphas: torch.Tensor,
|
||||
g2_alphas: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
max_experts_per_worker: int,
|
||||
out_dtype: torch.dtype,
|
||||
per_act_token_quant: bool,
|
||||
@@ -562,8 +625,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_batched_format: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
# NVFP4 requires two levels of quantization, which involves
|
||||
# computing some scaling factors dynamically. This makes it
|
||||
# incompatible with the typical prepare -> MoE -> finalize
|
||||
# pipeline. Move the quantization logic into the MoE body.
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=torch.uint8,
|
||||
quant_dtype=None, # skip quantization in prepare/finalize
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
@@ -572,6 +639,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.out_dtype = out_dtype
|
||||
self.use_batched_format = use_batched_format
|
||||
|
||||
# TODO(bnell): put this stuff into quant config?
|
||||
self.g1_alphas = g1_alphas
|
||||
self.g2_alphas = g2_alphas
|
||||
self.a1_gscale = a1_gscale
|
||||
self.a2_gscale = a2_gscale
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
@@ -590,8 +663,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
@@ -620,34 +692,42 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (workspace1, workspace2, output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype)
|
||||
|
||||
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]]):
|
||||
required_keys = [
|
||||
"g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k",
|
||||
"e", "device"
|
||||
]
|
||||
(g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e,
|
||||
device) = extract_required_args(extra_expert_args, required_keys)
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: torch.Tensor,
|
||||
workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
|
||||
n = w2.shape[2] * 2
|
||||
|
||||
run_cutlass_moe_fp4(
|
||||
output=output,
|
||||
a=hidden_states,
|
||||
a1_gscale=a1_gscale,
|
||||
a1_gscale=self.a1_gscale,
|
||||
w1_fp4=w1,
|
||||
w1_blockscale=w1_scale,
|
||||
w1_alphas=g1_alphas,
|
||||
a2_gscale=a2_gscale,
|
||||
w1_alphas=self.g1_alphas,
|
||||
a2_gscale=self.a2_gscale,
|
||||
w2_fp4=w2,
|
||||
w2_blockscale=w2_scale,
|
||||
w2_alphas=g2_alphas,
|
||||
w2_alphas=self.g2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
workspace13=workspace13,
|
||||
@@ -656,7 +736,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
n=n,
|
||||
k=k,
|
||||
e=e,
|
||||
device=device,
|
||||
device=hidden_states.device,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@@ -677,7 +757,6 @@ def cutlass_moe_fp4(
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
device: torch.device,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False) -> torch.Tensor:
|
||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||
@@ -686,6 +765,10 @@ def cutlass_moe_fp4(
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp4(
|
||||
g1_alphas,
|
||||
g2_alphas,
|
||||
a1_gscale,
|
||||
a2_gscale,
|
||||
max_experts_per_worker=e,
|
||||
out_dtype=a.dtype,
|
||||
per_act_token_quant=False,
|
||||
@@ -693,29 +776,7 @@ def cutlass_moe_fp4(
|
||||
use_batched_format=False,
|
||||
),
|
||||
)
|
||||
extra_expert_args = {
|
||||
'g1_alphas': g1_alphas,
|
||||
'g2_alphas': g2_alphas,
|
||||
'a1_gscale': a1_gscale,
|
||||
'a2_gscale': a2_gscale,
|
||||
'm': m,
|
||||
'n': n,
|
||||
'k': k,
|
||||
'e': e,
|
||||
'device': device,
|
||||
}
|
||||
|
||||
# NVFP4 requires two levels of quantization, which involves computing some
|
||||
# scaling factors dynamically. This makes it incompatible with the typical
|
||||
# prepare -> MoE -> finalize pipeline. Move the quantization logic into the
|
||||
# MoE body.
|
||||
extra_prepare_args = {
|
||||
'skip_quant': True,
|
||||
}
|
||||
# Similar reason as above.
|
||||
extra_finalize_args = {
|
||||
'skip_weight_reduce': True,
|
||||
}
|
||||
return fn(
|
||||
hidden_states=a,
|
||||
w1=w1_fp4,
|
||||
@@ -731,9 +792,6 @@ def cutlass_moe_fp4(
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=extra_expert_args,
|
||||
extra_prepare_args=extra_prepare_args,
|
||||
extra_finalize_args=extra_finalize_args,
|
||||
)
|
||||
|
||||
|
||||
@@ -824,16 +882,6 @@ def run_cutlass_block_scaled_fused_experts(
|
||||
k = w1_q.size(1)
|
||||
n = w2_q.size(1)
|
||||
|
||||
expert_offsets = torch.empty((num_experts + 1, ),
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
problem_sizes1 = torch.empty((num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
problem_sizes2 = torch.empty((num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
a_q, a1_scale = _fp8_quantize(a,
|
||||
@@ -842,6 +890,16 @@ def run_cutlass_block_scaled_fused_experts(
|
||||
block_shape=[128, 128])
|
||||
device = a_q.device
|
||||
|
||||
expert_offsets = torch.empty((num_experts + 1, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes1 = torch.empty((num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes2 = torch.empty((num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user