[Kernel] [Helion] [6/N] Add num_tokens dimension to silu_mul autotuning and dispatching (#34185)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
Yanan Cao
2026-02-20 08:36:51 -08:00
committed by GitHub
parent 6ce80f7071
commit a6d0299c75
3 changed files with 55236 additions and 237 deletions

View File

@@ -54,8 +54,8 @@ def reset_config_manager_singleton():
class TestSiluMulFp8ConfigPicker:
def test_config_picker_exact_match(self):
config_keys = [
"intermediate_2048_batchsize_256",
"intermediate_4096_batchsize_256",
"intermediate_2048_numtokens_256",
"intermediate_4096_numtokens_256",
]
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
@@ -63,12 +63,12 @@ class TestSiluMulFp8ConfigPicker:
args = (input_tensor, scale)
selected_key = pick_silu_mul_fp8_config(args, config_keys)
assert selected_key == "intermediate_2048_batchsize_256"
assert selected_key == "intermediate_2048_numtokens_256"
def test_config_picker_closest_match(self):
config_keys = [
"intermediate_2048_batchsize_256",
"intermediate_4096_batchsize_256",
"intermediate_2048_numtokens_256",
"intermediate_4096_numtokens_256",
]
# Use 7000 (intermediate_size=3500) which is closer to 4096 than 2048
input_tensor = torch.randn(32, 7000, dtype=torch.bfloat16, device="cuda")
@@ -76,10 +76,10 @@ class TestSiluMulFp8ConfigPicker:
args = (input_tensor, scale)
selected_key = pick_silu_mul_fp8_config(args, config_keys)
assert selected_key == "intermediate_4096_batchsize_256"
assert selected_key == "intermediate_4096_numtokens_256"
def test_config_picker_fallback_to_default(self):
config_keys = ["default", "some_other_key"]
config_keys = ["default"]
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
@@ -101,9 +101,9 @@ class TestSiluMulFp8ConfigPicker:
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 5120])
def test_config_picker_different_sizes(self, intermediate_size):
config_keys = [
"intermediate_2048_batchsize_256",
"intermediate_4096_batchsize_256",
"intermediate_5120_batchsize_256",
"intermediate_2048_numtokens_256",
"intermediate_4096_numtokens_256",
"intermediate_5120_numtokens_256",
]
input_tensor = torch.randn(
@@ -113,9 +113,73 @@ class TestSiluMulFp8ConfigPicker:
args = (input_tensor, scale)
selected_key = pick_silu_mul_fp8_config(args, config_keys)
expected_key = f"intermediate_{intermediate_size}_batchsize_256"
expected_key = f"intermediate_{intermediate_size}_numtokens_256"
assert selected_key == expected_key
def test_config_picker_numtokens_ceiling(self):
"""Pick the smallest numtokens >= input num_tokens."""
config_keys = [
"intermediate_4096_numtokens_8",
"intermediate_4096_numtokens_32",
"intermediate_4096_numtokens_128",
"intermediate_4096_numtokens_256",
]
# 20 tokens -> should pick numtokens_32 (smallest >= 20)
input_tensor = torch.randn(20, 8192, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
selected_key = pick_silu_mul_fp8_config((input_tensor, scale), config_keys)
assert selected_key == "intermediate_4096_numtokens_32"
def test_config_picker_numtokens_exact(self):
"""Exact num_tokens match is preferred over ceiling."""
config_keys = [
"intermediate_4096_numtokens_8",
"intermediate_4096_numtokens_32",
"intermediate_4096_numtokens_128",
]
input_tensor = torch.randn(32, 8192, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
selected_key = pick_silu_mul_fp8_config((input_tensor, scale), config_keys)
assert selected_key == "intermediate_4096_numtokens_32"
def test_config_picker_numtokens_fallback_to_largest(self):
"""Fall back to the largest numtokens when input exceeds all."""
config_keys = [
"intermediate_4096_numtokens_8",
"intermediate_4096_numtokens_32",
"intermediate_4096_numtokens_128",
]
# 512 tokens -> exceeds all available, should pick largest (128)
input_tensor = torch.randn(512, 8192, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
selected_key = pick_silu_mul_fp8_config((input_tensor, scale), config_keys)
assert selected_key == "intermediate_4096_numtokens_128"
def test_config_picker_malformed_key_raises(self):
"""Malformed config keys should raise ValueError."""
config_keys = ["intermediate_4096_badformat_256"]
input_tensor = torch.randn(32, 8192, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
with pytest.raises(ValueError, match="Malformed config key"):
pick_silu_mul_fp8_config((input_tensor, scale), config_keys)
def test_config_picker_default_ignored_when_valid_keys_exist(self):
"""'default' is skipped in favor of a real match."""
config_keys = [
"default",
"intermediate_4096_numtokens_32",
"intermediate_4096_numtokens_128",
]
input_tensor = torch.randn(64, 8192, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
selected_key = pick_silu_mul_fp8_config((input_tensor, scale), config_keys)
assert selected_key == "intermediate_4096_numtokens_128"
class TestSiluMulFp8Correctness:
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@
from typing import Any
import regex as re
import torch
from vllm.logger import init_logger
@@ -53,44 +54,78 @@ def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return out.view(output_shape)
@silu_mul_fp8.register_input_generator # type: ignore[misc]
def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]:
intermediate_sizes = [2048, 2880, 4096, 8192, 11008, 14336]
# Use the same num_tokens values as vLLM's default cudagraph capture sizes.
# See vllm/config/vllm.py _set_cudagraph_sizes() for the canonical formula.
num_tokens_list = [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, 513, 16))
inputs = {}
for num_tokens in num_tokens_list:
for intermediate_size in intermediate_sizes:
# Input tensor has shape (num_tokens, 2 * intermediate_size)
# because silu_mul splits it into two halves
input_tensor = torch.randn(
num_tokens,
2 * intermediate_size,
device="cuda",
dtype=torch.bfloat16,
)
scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)
config_key = f"intermediate_{intermediate_size}_numtokens_{num_tokens}"
inputs[config_key] = (input_tensor, scale)
return inputs
@silu_mul_fp8.register_config_picker # type: ignore[misc]
def pick_silu_mul_fp8_config(
args: tuple[Any, ...], config_keys: list[str]
) -> str | None:
"""Pick the best pre-tuned config for the given input shape.
Selection strategy:
1. Find the closest intermediate_size among available configs
(exact match preferred).
2. Among the num_tokens values tuned for that intermediate_size, pick
the smallest num_tokens >= the input's num_tokens. If the input is
larger than all available num_tokens, fall back to the largest.
Config keys must be "default" or follow the format
"intermediate_{int}_numtokens_{int}".
"""
if not config_keys:
return None
input_tensor, scale = args
input_tensor, _scale = args
intermediate_size = input_tensor.shape[-1] // 2
# TODO(gmagosfm): Rerun autotuning to capture config for
# other batch sizes.
target_key = f"intermediate_{intermediate_size}_batchsize_256"
if target_key in config_keys:
return target_key
intermediate_sizes = []
num_tokens = input_tensor.view(-1, input_tensor.shape[-1]).shape[0]
configs: dict[int, list[int]] = {}
for key in config_keys:
if key.startswith("intermediate_") and "_batchsize_256" in key:
try:
size_str = key.split("_")[1]
size = int(size_str)
intermediate_sizes.append((abs(size - intermediate_size), key))
except (ValueError, IndexError):
continue
if key == "default":
continue
match = re.fullmatch(r"intermediate_(\d+)_numtokens_(\d+)", key)
if not match:
raise ValueError(
f"Malformed config key '{key}', "
f"expected format 'intermediate_{{int}}_numtokens_{{int}}'"
)
isize_str, ntokens_str = match.groups()
configs.setdefault(int(isize_str), []).append(int(ntokens_str))
if intermediate_sizes:
_, best_key = min(intermediate_sizes)
logger.debug(
"No exact config for intermediate_size=%d, using closest match: %s",
intermediate_size,
best_key,
)
return best_key
if "default" in config_keys:
return "default"
if not configs:
return "default" if "default" in config_keys else None
return None
best_isize = min(configs, key=lambda s: abs(s - intermediate_size))
available_ntokens = sorted(configs[best_isize])
best_ntokens = next(
(n for n in available_ntokens if n >= num_tokens), available_ntokens[-1]
)
return f"intermediate_{best_isize}_numtokens_{best_ntokens}"
def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: