Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -11,27 +11,39 @@ import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
|
||||
QuarkLinearMethod, QuarkW4A4MXFP4)
|
||||
QuarkLinearMethod,
|
||||
QuarkW4A4MXFP4,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
|
||||
QuarkW4A4MXFp4MoEMethod)
|
||||
QuarkW4A4MXFp4MoEMethod,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
|
||||
importlib.metadata.version("amd-quark")
|
||||
) >= version.parse("0.8.99")
|
||||
|
||||
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
|
||||
) and current_platform.is_device_capability(100)
|
||||
TRTLLM_GEN_MXFP4_AVAILABLE = (
|
||||
current_platform.is_cuda() and current_platform.is_device_capability(100)
|
||||
)
|
||||
|
||||
HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)
|
||||
and has_flashinfer())
|
||||
HOPPER_MXFP4_BF16_AVAILABLE = (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)
|
||||
and has_flashinfer()
|
||||
)
|
||||
|
||||
if TRTLLM_GEN_MXFP4_AVAILABLE:
|
||||
from flashinfer import (fp4_quantize, mxfp8_quantize,
|
||||
next_positive_power_of_2,
|
||||
reorder_rows_for_gated_act_gemm, shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
|
||||
from flashinfer import (
|
||||
fp4_quantize,
|
||||
mxfp8_quantize,
|
||||
next_positive_power_of_2,
|
||||
reorder_rows_for_gated_act_gemm,
|
||||
shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a,
|
||||
trtllm_fp4_block_scale_moe,
|
||||
)
|
||||
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
|
||||
from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
|
||||
|
||||
@@ -48,21 +60,25 @@ def enable_pickle(monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_case', [
|
||||
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
|
||||
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
|
||||
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1)
|
||||
])
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
|
||||
reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.parametrize(
|
||||
"model_case",
|
||||
[
|
||||
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
|
||||
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
|
||||
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
|
||||
if torch.cuda.device_count() < model_case.tp:
|
||||
pytest.skip(f"This test requires >={model_case.tp} gpus, got only "
|
||||
f"{torch.cuda.device_count()}")
|
||||
pytest.skip(
|
||||
f"This test requires >={model_case.tp} gpus, got only "
|
||||
f"{torch.cuda.device_count()}"
|
||||
)
|
||||
|
||||
with vllm_runner(model_case.model_id,
|
||||
tensor_parallel_size=model_case.tp,
|
||||
load_format="dummy") as llm:
|
||||
with vllm_runner(
|
||||
model_case.model_id, tensor_parallel_size=model_case.tp, load_format="dummy"
|
||||
) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
@@ -72,21 +88,16 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
|
||||
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
|
||||
|
||||
assert isinstance(layer.mlp.experts.quant_method,
|
||||
QuarkW4A4MXFp4MoEMethod)
|
||||
assert isinstance(layer.mlp.experts.quant_method, QuarkW4A4MXFp4MoEMethod)
|
||||
|
||||
if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Today I am in the French Alps and",
|
||||
max_tokens=20)
|
||||
output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20)
|
||||
assert output
|
||||
|
||||
|
||||
def swiglu(x,
|
||||
alpha: float = 1.702,
|
||||
beta: float = 1.0,
|
||||
limit: Optional[float] = None):
|
||||
def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: Optional[float] = None):
|
||||
# Note we add an extra bias of 1 to the linear layer
|
||||
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
|
||||
if limit is not None:
|
||||
@@ -96,24 +107,19 @@ def swiglu(x,
|
||||
return out_glu * (x_linear + beta)
|
||||
|
||||
|
||||
fp4_lookup_table = [
|
||||
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6
|
||||
]
|
||||
fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]
|
||||
|
||||
|
||||
def mxfp4_dequantize(x, scale):
|
||||
assert x.dtype == torch.uint8
|
||||
x = x.view(torch.uint8).to(torch.int32)
|
||||
x_unpacked = torch.zeros(*x.shape[:-1],
|
||||
x.shape[-1] * 2,
|
||||
dtype=torch.int32,
|
||||
device=x.device)
|
||||
x_unpacked = torch.zeros(
|
||||
*x.shape[:-1], x.shape[-1] * 2, dtype=torch.int32, device=x.device
|
||||
)
|
||||
x_unpacked[..., 0::2].copy_(x & 0xF)
|
||||
x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)
|
||||
|
||||
x_float = torch.zeros(x_unpacked.shape,
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
x_float = torch.zeros(x_unpacked.shape, dtype=torch.float32, device=x.device)
|
||||
for i, val in enumerate(fp4_lookup_table):
|
||||
x_float[x_unpacked == i] = val
|
||||
|
||||
@@ -162,9 +168,10 @@ def reference_moe(
|
||||
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
|
||||
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
|
||||
|
||||
if act_type == 'mxfp8':
|
||||
t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16),
|
||||
is_sf_swizzled_layout=False)
|
||||
if act_type == "mxfp8":
|
||||
t_quantized, t_scale = mxfp8_quantize(
|
||||
t.to(torch.bfloat16), is_sf_swizzled_layout=False
|
||||
)
|
||||
t = mxfp8_dequantize(t_quantized, t_scale)
|
||||
# MLP #2
|
||||
mlp2_weight = w2[expert_indices, ...]
|
||||
@@ -221,37 +228,53 @@ def tg_mxfp4_moe(
|
||||
transpose_optimized: bool = False,
|
||||
) -> torch.Tensor:
|
||||
sf_block_size = 32
|
||||
assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
|
||||
and w13_weight.shape[1] == intermediate_size * 2
|
||||
and w13_weight.shape[2] == hidden_size // 2)
|
||||
assert (w13_weight_scale.dim() == 3
|
||||
and w13_weight_scale.shape[0] == num_experts
|
||||
and w13_weight_scale.shape[1] == intermediate_size * 2
|
||||
and w13_weight_scale.shape[2] == hidden_size // sf_block_size)
|
||||
assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts
|
||||
and w2_weight.shape[1] == hidden_size
|
||||
and w2_weight.shape[2] == intermediate_size // 2)
|
||||
assert (w2_weight_scale.dim() == 3
|
||||
and w2_weight_scale.shape[1] == hidden_size
|
||||
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size)
|
||||
assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts
|
||||
and w13_bias.shape[1] == intermediate_size * 2)
|
||||
assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts
|
||||
and w2_bias.shape[1] == hidden_size)
|
||||
assert (
|
||||
w13_weight.dim() == 3
|
||||
and w13_weight.shape[0] == num_experts
|
||||
and w13_weight.shape[1] == intermediate_size * 2
|
||||
and w13_weight.shape[2] == hidden_size // 2
|
||||
)
|
||||
assert (
|
||||
w13_weight_scale.dim() == 3
|
||||
and w13_weight_scale.shape[0] == num_experts
|
||||
and w13_weight_scale.shape[1] == intermediate_size * 2
|
||||
and w13_weight_scale.shape[2] == hidden_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
w2_weight.dim() == 3
|
||||
and w2_weight.shape[0] == num_experts
|
||||
and w2_weight.shape[1] == hidden_size
|
||||
and w2_weight.shape[2] == intermediate_size // 2
|
||||
)
|
||||
assert (
|
||||
w2_weight_scale.dim() == 3
|
||||
and w2_weight_scale.shape[1] == hidden_size
|
||||
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
w13_bias.dim() == 2
|
||||
and w13_bias.shape[0] == num_experts
|
||||
and w13_bias.shape[1] == intermediate_size * 2
|
||||
)
|
||||
assert (
|
||||
w2_bias.dim() == 2
|
||||
and w2_bias.shape[0] == num_experts
|
||||
and w2_bias.shape[1] == hidden_size
|
||||
)
|
||||
|
||||
# Swap w1 and w3 as the definition of
|
||||
# swiglu is different in the trtllm-gen
|
||||
w13_weight_scale_ = w13_weight_scale.clone()
|
||||
w13_weight_ = w13_weight.clone()
|
||||
w13_bias_ = w13_bias.clone()
|
||||
w13_weight[:, :intermediate_size, :].copy_(
|
||||
w13_weight_[:, intermediate_size:, :])
|
||||
w13_weight[:, intermediate_size:, :].copy_(
|
||||
w13_weight_[:, :intermediate_size, :])
|
||||
w13_weight[:, :intermediate_size, :].copy_(w13_weight_[:, intermediate_size:, :])
|
||||
w13_weight[:, intermediate_size:, :].copy_(w13_weight_[:, :intermediate_size, :])
|
||||
w13_weight_scale[:, :intermediate_size, :].copy_(
|
||||
w13_weight_scale_[:, intermediate_size:, :])
|
||||
w13_weight_scale_[:, intermediate_size:, :]
|
||||
)
|
||||
w13_weight_scale[:, intermediate_size:, :].copy_(
|
||||
w13_weight_scale_[:, :intermediate_size, :])
|
||||
w13_weight_scale_[:, :intermediate_size, :]
|
||||
)
|
||||
w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
|
||||
w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])
|
||||
|
||||
@@ -261,18 +284,23 @@ def tg_mxfp4_moe(
|
||||
w13_bias_interleaved = []
|
||||
for i in range(num_experts):
|
||||
w13_weight_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_weight[i].clone()))
|
||||
reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
|
||||
)
|
||||
w13_weight_scale_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone()))
|
||||
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
|
||||
)
|
||||
w13_bias_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1,
|
||||
1)))
|
||||
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
|
||||
)
|
||||
w13_weight = torch.stack(w13_weight_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size // 2)
|
||||
num_experts, 2 * intermediate_size, hidden_size // 2
|
||||
)
|
||||
w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size // 32)
|
||||
num_experts, 2 * intermediate_size, hidden_size // 32
|
||||
)
|
||||
w13_bias = torch.stack(w13_bias_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size)
|
||||
num_experts, 2 * intermediate_size
|
||||
)
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_shuffled = []
|
||||
@@ -291,9 +319,11 @@ def tg_mxfp4_moe(
|
||||
w13_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm1_weights_shuffled.append(w13_weight[i].view(
|
||||
torch.uint8)[permute_indices.to(
|
||||
w13_weight.device)].contiguous())
|
||||
gemm1_weights_shuffled.append(
|
||||
w13_weight[i]
|
||||
.view(torch.uint8)[permute_indices.to(w13_weight.device)]
|
||||
.contiguous()
|
||||
)
|
||||
# w13 scale shuffling
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
_cache_permute_indices,
|
||||
@@ -302,26 +332,35 @@ def tg_mxfp4_moe(
|
||||
num_elts_per_sf=16,
|
||||
)
|
||||
gemm1_scales_shuffled.append(
|
||||
nvfp4_block_scale_interleave(w13_weight_scale[i].view(
|
||||
torch.uint8)[permute_sf_indices.to(
|
||||
w13_weight_scale.device)].contiguous()))
|
||||
nvfp4_block_scale_interleave(
|
||||
w13_weight_scale[i]
|
||||
.view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
|
||||
.contiguous()
|
||||
)
|
||||
)
|
||||
# w13 bias shuffling
|
||||
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
|
||||
_cache_permute_indices,
|
||||
w13_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
|
||||
-1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous())
|
||||
gemm1_bias_shuffled.append(
|
||||
w13_bias[i]
|
||||
.clone()
|
||||
.reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
|
||||
.contiguous()
|
||||
)
|
||||
# w2 weight shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
_cache_permute_indices,
|
||||
w2_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm2_weights_shuffled.append(w2_weight[i].view(
|
||||
torch.uint8)[permute_indices.to(
|
||||
w2_weight.device)].contiguous())
|
||||
gemm2_weights_shuffled.append(
|
||||
w2_weight[i]
|
||||
.view(torch.uint8)[permute_indices.to(w2_weight.device)]
|
||||
.contiguous()
|
||||
)
|
||||
# w2 scale shuffling
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
_cache_permute_indices,
|
||||
@@ -330,48 +369,65 @@ def tg_mxfp4_moe(
|
||||
num_elts_per_sf=16,
|
||||
)
|
||||
gemm2_scales_shuffled.append(
|
||||
nvfp4_block_scale_interleave(w2_weight_scale[i].view(
|
||||
torch.uint8)[permute_sf_indices.to(
|
||||
w2_weight_scale.device)].contiguous()))
|
||||
nvfp4_block_scale_interleave(
|
||||
w2_weight_scale[i]
|
||||
.view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
|
||||
.contiguous()
|
||||
)
|
||||
)
|
||||
# w2 bias shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
_cache_permute_indices,
|
||||
w2_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
|
||||
-1, 1)[permute_indices.to(w2_bias.device)].contiguous())
|
||||
gemm2_bias_shuffled.append(
|
||||
w2_bias[i]
|
||||
.clone()
|
||||
.reshape(-1, 1)[permute_indices.to(w2_bias.device)]
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
else:
|
||||
for i in range(num_experts):
|
||||
gemm1_weights_shuffled.append(
|
||||
shuffle_matrix_a(w13_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
|
||||
)
|
||||
gemm1_scales_shuffled.append(
|
||||
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
shuffle_matrix_sf_a(
|
||||
w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
)
|
||||
|
||||
gemm2_weights_shuffled.append(
|
||||
shuffle_matrix_a(w2_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
|
||||
)
|
||||
gemm2_scales_shuffled.append(
|
||||
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
shuffle_matrix_sf_a(
|
||||
w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
)
|
||||
gemm1_bias_shuffled.append(
|
||||
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
|
||||
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)
|
||||
)
|
||||
gemm2_bias_shuffled.append(
|
||||
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
|
||||
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)
|
||||
)
|
||||
|
||||
w13_weight = torch.stack(gemm1_weights_shuffled)
|
||||
w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
|
||||
num_experts, 2 * intermediate_size,
|
||||
hidden_size // sf_block_size).view(torch.float8_e4m3fn)
|
||||
w13_weight_scale = (
|
||||
torch.stack(gemm1_scales_shuffled)
|
||||
.reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
|
||||
|
||||
w2_weight = torch.stack(gemm2_weights_shuffled)
|
||||
w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape(
|
||||
num_experts, hidden_size,
|
||||
intermediate_size // sf_block_size).view(torch.float8_e4m3fn)
|
||||
w2_weight_scale = (
|
||||
torch.stack(gemm2_scales_shuffled)
|
||||
.reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
|
||||
|
||||
tg_result = trtllm_fp4_block_scale_moe(
|
||||
@@ -401,7 +457,8 @@ def tg_mxfp4_moe(
|
||||
routed_scaling_factor=None,
|
||||
tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
|
||||
routing_method_type=1, # renormalize
|
||||
do_finalize=True)[0]
|
||||
do_finalize=True,
|
||||
)[0]
|
||||
return tg_result
|
||||
|
||||
|
||||
@@ -424,20 +481,21 @@ def check_accuracy(a, b, atol, rtol, percent):
|
||||
if mismatch_percent > 1 - percent:
|
||||
raise Exception(
|
||||
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
|
||||
f"(threshold: {1-percent:.4f})")
|
||||
f"(threshold: {1 - percent:.4f})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("topk", [1, 4])
|
||||
@pytest.mark.parametrize("num_experts", [32, 128])
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
|
||||
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
|
||||
(1.702, 1.0, 7.0)])
|
||||
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
|
||||
@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"])
|
||||
@pytest.mark.parametrize("transpose_optimized", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not TRTLLM_GEN_MXFP4_AVAILABLE,
|
||||
reason="nvidia gpu and compute capability sm100 is required for this test")
|
||||
reason="nvidia gpu and compute capability sm100 is required for this test",
|
||||
)
|
||||
def test_trtllm_gen_mxfp4_fused_moe(
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
@@ -452,45 +510,52 @@ def test_trtllm_gen_mxfp4_fused_moe(
|
||||
):
|
||||
seed = 42
|
||||
torch.manual_seed(seed)
|
||||
hidden_states = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16)
|
||||
w13 = (torch.randn(num_experts,
|
||||
intermediate_size * 2,
|
||||
hidden_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16))
|
||||
w2 = (torch.randn(num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16))
|
||||
bias13 = torch.randn(num_experts, intermediate_size * 2,
|
||||
device="cuda:0") * 10
|
||||
hidden_states = torch.randn(
|
||||
num_tokens, hidden_size, device="cuda:0", dtype=torch.bfloat16
|
||||
)
|
||||
w13 = torch.randn(
|
||||
num_experts,
|
||||
intermediate_size * 2,
|
||||
hidden_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
bias13 = torch.randn(num_experts, intermediate_size * 2, device="cuda:0") * 10
|
||||
bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
|
||||
router_logits = torch.rand(num_tokens, num_experts,
|
||||
dtype=torch.float32).cuda()
|
||||
router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda()
|
||||
|
||||
w13, w13_scale = fp4_quantize(w13,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False)
|
||||
w13, w13_scale = fp4_quantize(
|
||||
w13,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
|
||||
num_experts, intermediate_size * 2, hidden_size // 32)
|
||||
w2, w2_scale = fp4_quantize(w2,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False)
|
||||
num_experts, intermediate_size * 2, hidden_size // 32
|
||||
)
|
||||
w2, w2_scale = fp4_quantize(
|
||||
w2,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
|
||||
num_experts, hidden_size, intermediate_size // 32)
|
||||
if act_type == 'mxfp8':
|
||||
num_experts, hidden_size, intermediate_size // 32
|
||||
)
|
||||
if act_type == "mxfp8":
|
||||
hidden_states, hidden_states_scale = mxfp8_quantize(
|
||||
hidden_states, is_sf_swizzled_layout=False)
|
||||
hidden_states_scale = hidden_states_scale.view(
|
||||
torch.float8_e4m3fn).reshape(-1)
|
||||
hidden_states, is_sf_swizzled_layout=False
|
||||
)
|
||||
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||
else:
|
||||
hidden_states_scale = None
|
||||
|
||||
@@ -500,9 +565,10 @@ def test_trtllm_gen_mxfp4_fused_moe(
|
||||
w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
|
||||
bias13_ref = bias13
|
||||
bias2_ref = bias2
|
||||
if act_type == 'mxfp8':
|
||||
hidden_states_ref = mxfp8_dequantize(
|
||||
hidden_states, hidden_states_scale).to(torch.float32)
|
||||
if act_type == "mxfp8":
|
||||
hidden_states_ref = mxfp8_dequantize(hidden_states, hidden_states_scale).to(
|
||||
torch.float32
|
||||
)
|
||||
else:
|
||||
hidden_states_ref = hidden_states.to(torch.float32)
|
||||
# Process tokens in chunks of 32 to reduce memory usage
|
||||
@@ -529,29 +595,31 @@ def test_trtllm_gen_mxfp4_fused_moe(
|
||||
|
||||
# trtllm-gen result
|
||||
if alpha is not None:
|
||||
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
|
||||
alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
|
||||
if limit is not None:
|
||||
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
|
||||
limit = torch.full((num_experts,), limit, device=hidden_states.device)
|
||||
if beta is not None:
|
||||
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
|
||||
tg_result = tg_mxfp4_moe(router_logits,
|
||||
topk,
|
||||
num_experts,
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
w13,
|
||||
w13_scale,
|
||||
bias13,
|
||||
w2,
|
||||
w2_scale,
|
||||
bias2,
|
||||
act_type,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
limit=limit,
|
||||
transpose_optimized=transpose_optimized)
|
||||
beta = torch.full((num_experts,), beta, device=hidden_states.device)
|
||||
tg_result = tg_mxfp4_moe(
|
||||
router_logits,
|
||||
topk,
|
||||
num_experts,
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
w13,
|
||||
w13_scale,
|
||||
bias13,
|
||||
w2,
|
||||
w2_scale,
|
||||
bias2,
|
||||
act_type,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
limit=limit,
|
||||
transpose_optimized=transpose_optimized,
|
||||
)
|
||||
# relatively loose check since the mxfp4 quantization is less accurate
|
||||
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
|
||||
|
||||
@@ -573,8 +641,7 @@ def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128])
|
||||
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
|
||||
(1.702, 1.0, 7.0)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
|
||||
@pytest.mark.skipif(
|
||||
not HOPPER_MXFP4_BF16_AVAILABLE,
|
||||
reason="nvidia gpu sm90 and flashinfer are required for this test",
|
||||
@@ -593,52 +660,73 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
device = "cuda:0"
|
||||
|
||||
# Inputs
|
||||
hidden_states = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16)
|
||||
hidden_states = torch.randn(
|
||||
num_tokens, hidden_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
# Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
|
||||
w13_q = torch.randint(
|
||||
0,
|
||||
256, (num_experts, 2 * intermediate_size, hidden_size // 2),
|
||||
256,
|
||||
(num_experts, 2 * intermediate_size, hidden_size // 2),
|
||||
device=device,
|
||||
dtype=torch.uint8)
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
w13_scale = torch.randint(
|
||||
118,
|
||||
123, (num_experts, 2 * intermediate_size, hidden_size // 32),
|
||||
123,
|
||||
(num_experts, 2 * intermediate_size, hidden_size // 32),
|
||||
device=device,
|
||||
dtype=torch.uint8)
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
|
||||
w2_q = torch.randint(0,
|
||||
256,
|
||||
(num_experts, hidden_size, intermediate_size // 2),
|
||||
device=device,
|
||||
dtype=torch.uint8)
|
||||
w2_q = torch.randint(
|
||||
0,
|
||||
256,
|
||||
(num_experts, hidden_size, intermediate_size // 2),
|
||||
device=device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
w2_scale = torch.randint(
|
||||
118,
|
||||
123, (num_experts, hidden_size, intermediate_size // 32),
|
||||
123,
|
||||
(num_experts, hidden_size, intermediate_size // 32),
|
||||
device=device,
|
||||
dtype=torch.uint8)
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
# Bias contiguous [b1; b3]
|
||||
bias13 = (torch.randn(num_experts,
|
||||
2 * intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16) * 10)
|
||||
bias2 = (torch.randn(
|
||||
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
|
||||
router_logits = torch.rand(num_tokens,
|
||||
num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
bias13 = (
|
||||
torch.randn(
|
||||
num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
* 10
|
||||
)
|
||||
bias2 = (
|
||||
torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
|
||||
)
|
||||
router_logits = torch.rand(
|
||||
num_tokens, num_experts, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size)
|
||||
num_experts, 2 * intermediate_size, hidden_size
|
||||
)
|
||||
w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
|
||||
num_experts, hidden_size, intermediate_size)
|
||||
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
|
||||
hidden_states.to(torch.float32), w13_ref,
|
||||
bias13.to(torch.float32), w2_ref,
|
||||
bias2.to(torch.float32), alpha, beta, limit, 'bf16')
|
||||
num_experts, hidden_size, intermediate_size
|
||||
)
|
||||
ref = reference_moe(
|
||||
router_logits.to(torch.float32),
|
||||
topk,
|
||||
num_experts,
|
||||
hidden_states.to(torch.float32),
|
||||
w13_ref,
|
||||
bias13.to(torch.float32),
|
||||
w2_ref,
|
||||
bias2.to(torch.float32),
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
"bf16",
|
||||
)
|
||||
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
|
||||
@@ -654,23 +742,24 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
|
||||
w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)
|
||||
|
||||
routing_weights = torch.nn.functional.softmax(router_logits,
|
||||
dim=1,
|
||||
dtype=torch.float32)
|
||||
token_final_scales, token_selected_experts = torch.topk(routing_weights,
|
||||
topk,
|
||||
dim=-1)
|
||||
token_final_scales = (token_final_scales /
|
||||
token_final_scales.sum(dim=-1, keepdim=True))
|
||||
routing_weights = torch.nn.functional.softmax(
|
||||
router_logits, dim=1, dtype=torch.float32
|
||||
)
|
||||
token_final_scales, token_selected_experts = torch.topk(
|
||||
routing_weights, topk, dim=-1
|
||||
)
|
||||
token_final_scales = token_final_scales / token_final_scales.sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
|
||||
|
||||
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
|
||||
if alpha is not None:
|
||||
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
|
||||
alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
|
||||
if beta is not None:
|
||||
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
|
||||
beta = torch.full((num_experts,), beta, device=hidden_states.device)
|
||||
if limit is not None:
|
||||
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
|
||||
limit = torch.full((num_experts,), limit, device=hidden_states.device)
|
||||
|
||||
_ = flashinfer_cutlass_fused_moe(
|
||||
input=hidden_states,
|
||||
@@ -680,8 +769,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
fc2_expert_weights=w2_q,
|
||||
output_dtype=torch.bfloat16,
|
||||
output=out,
|
||||
quant_scales=[w13_s_inter.to(torch.uint8),
|
||||
w2_s_inter.to(torch.uint8)],
|
||||
quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)],
|
||||
fc1_expert_biases=w13_b,
|
||||
fc2_expert_biases=bias2.to(torch.bfloat16),
|
||||
swiglu_alpha=alpha,
|
||||
@@ -702,11 +790,13 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128])
|
||||
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
|
||||
(1.702, 1.0, 7.0)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100) and has_flashinfer()),
|
||||
not (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100)
|
||||
and has_flashinfer()
|
||||
),
|
||||
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
|
||||
)
|
||||
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
||||
@@ -723,32 +813,43 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
||||
device = "cuda:0"
|
||||
|
||||
# Inputs
|
||||
hidden_states = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16)
|
||||
hidden_states = torch.randn(
|
||||
num_tokens, hidden_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
# Float weights in w13 format [w1; w3]
|
||||
w13 = (torch.randn(num_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16) / 10)
|
||||
w2 = (torch.randn(num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16) / 10)
|
||||
w13 = (
|
||||
torch.randn(
|
||||
num_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
w2 = (
|
||||
torch.randn(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
# Bias contiguous [b1; b3]
|
||||
bias13 = (torch.randn(num_experts,
|
||||
2 * intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16) * 10)
|
||||
bias2 = (torch.randn(
|
||||
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
|
||||
router_logits = torch.rand(num_tokens,
|
||||
num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
bias13 = (
|
||||
torch.randn(
|
||||
num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
* 10
|
||||
)
|
||||
bias2 = (
|
||||
torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
|
||||
)
|
||||
router_logits = torch.rand(
|
||||
num_tokens, num_experts, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
# Quantize weights to MXFP4 per expert (SM100 path)
|
||||
from flashinfer import mxfp4_quantize
|
||||
@@ -761,36 +862,56 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
||||
sfs.append(sf)
|
||||
return torch.stack(qs), torch.stack(sfs)
|
||||
|
||||
def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
|
||||
scale_tensor: torch.Tensor):
|
||||
def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
|
||||
num_batches = mat_fp4.size(0)
|
||||
scale_tensor = scale_tensor.view(num_batches, -1)
|
||||
from flashinfer import mxfp4_dequantize
|
||||
return torch.stack([
|
||||
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
|
||||
for b in range(num_batches)
|
||||
])
|
||||
|
||||
return torch.stack(
|
||||
[
|
||||
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
|
||||
for b in range(num_batches)
|
||||
]
|
||||
)
|
||||
|
||||
w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
|
||||
w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)
|
||||
|
||||
# Reference result using dequantized tensors and reference_moe
|
||||
w13_ref = dequant_mxfp4_batches(
|
||||
w13_q.view(torch.uint8),
|
||||
w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size).to(device)
|
||||
w2_ref = dequant_mxfp4_batches(
|
||||
w2_q.view(torch.uint8),
|
||||
w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
|
||||
num_experts, hidden_size, intermediate_size).to(device)
|
||||
w13_ref = (
|
||||
dequant_mxfp4_batches(
|
||||
w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1)
|
||||
)
|
||||
.to(torch.float32)
|
||||
.reshape(num_experts, 2 * intermediate_size, hidden_size)
|
||||
.to(device)
|
||||
)
|
||||
w2_ref = (
|
||||
dequant_mxfp4_batches(
|
||||
w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1)
|
||||
)
|
||||
.to(torch.float32)
|
||||
.reshape(num_experts, hidden_size, intermediate_size)
|
||||
.to(device)
|
||||
)
|
||||
|
||||
# Quantize activations for SM100 path and dequantize for reference
|
||||
hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
|
||||
# Reference uses BF16 input but quantizes intermediate activation to MXFP8
|
||||
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
|
||||
hidden_states.to(torch.float32), w13_ref,
|
||||
bias13.to(torch.float32), w2_ref,
|
||||
bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')
|
||||
ref = reference_moe(
|
||||
router_logits.to(torch.float32),
|
||||
topk,
|
||||
num_experts,
|
||||
hidden_states.to(torch.float32),
|
||||
w13_ref,
|
||||
bias13.to(torch.float32),
|
||||
w2_ref,
|
||||
bias2.to(torch.float32),
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
"mxfp8",
|
||||
)
|
||||
|
||||
# Prepare inputs for FlashInfer CUTLASS fused MoE
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
@@ -807,31 +928,28 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
||||
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
|
||||
|
||||
# Build routing for kernel
|
||||
routing_weights = torch.nn.functional.softmax(router_logits,
|
||||
dim=1,
|
||||
dtype=torch.float32)
|
||||
token_final_scales, token_selected_experts = torch.topk(routing_weights,
|
||||
topk,
|
||||
dim=-1)
|
||||
token_final_scales = (token_final_scales /
|
||||
token_final_scales.sum(dim=-1, keepdim=True))
|
||||
routing_weights = torch.nn.functional.softmax(
|
||||
router_logits, dim=1, dtype=torch.float32
|
||||
)
|
||||
token_final_scales, token_selected_experts = torch.topk(
|
||||
routing_weights, topk, dim=-1
|
||||
)
|
||||
token_final_scales = token_final_scales / token_final_scales.sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
|
||||
|
||||
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
|
||||
if alpha is not None:
|
||||
alpha_t = torch.full((num_experts, ),
|
||||
alpha,
|
||||
device=hidden_states.device)
|
||||
alpha_t = torch.full((num_experts,), alpha, device=hidden_states.device)
|
||||
else:
|
||||
alpha_t = None
|
||||
if beta is not None:
|
||||
beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
|
||||
beta_t = torch.full((num_experts,), beta, device=hidden_states.device)
|
||||
else:
|
||||
beta_t = None
|
||||
if limit is not None:
|
||||
limit_t = torch.full((num_experts, ),
|
||||
limit,
|
||||
device=hidden_states.device)
|
||||
limit_t = torch.full((num_experts,), limit, device=hidden_states.device)
|
||||
else:
|
||||
limit_t = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user