[Kernel] [Helion] [6/N] Add num_tokens dimension to silu_mul autotuning and dispatching (#34185)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
@@ -54,8 +54,8 @@ def reset_config_manager_singleton():
|
||||
class TestSiluMulFp8ConfigPicker:
|
||||
def test_config_picker_exact_match(self):
|
||||
config_keys = [
|
||||
"intermediate_2048_batchsize_256",
|
||||
"intermediate_4096_batchsize_256",
|
||||
"intermediate_2048_numtokens_256",
|
||||
"intermediate_4096_numtokens_256",
|
||||
]
|
||||
|
||||
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
|
||||
@@ -63,12 +63,12 @@ class TestSiluMulFp8ConfigPicker:
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
assert selected_key == "intermediate_2048_batchsize_256"
|
||||
assert selected_key == "intermediate_2048_numtokens_256"
|
||||
|
||||
def test_config_picker_closest_match(self):
|
||||
config_keys = [
|
||||
"intermediate_2048_batchsize_256",
|
||||
"intermediate_4096_batchsize_256",
|
||||
"intermediate_2048_numtokens_256",
|
||||
"intermediate_4096_numtokens_256",
|
||||
]
|
||||
# Use 7000 (intermediate_size=3500) which is closer to 4096 than 2048
|
||||
input_tensor = torch.randn(32, 7000, dtype=torch.bfloat16, device="cuda")
|
||||
@@ -76,10 +76,10 @@ class TestSiluMulFp8ConfigPicker:
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
assert selected_key == "intermediate_4096_batchsize_256"
|
||||
assert selected_key == "intermediate_4096_numtokens_256"
|
||||
|
||||
def test_config_picker_fallback_to_default(self):
|
||||
config_keys = ["default", "some_other_key"]
|
||||
config_keys = ["default"]
|
||||
|
||||
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
@@ -101,9 +101,9 @@ class TestSiluMulFp8ConfigPicker:
|
||||
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 5120])
|
||||
def test_config_picker_different_sizes(self, intermediate_size):
|
||||
config_keys = [
|
||||
"intermediate_2048_batchsize_256",
|
||||
"intermediate_4096_batchsize_256",
|
||||
"intermediate_5120_batchsize_256",
|
||||
"intermediate_2048_numtokens_256",
|
||||
"intermediate_4096_numtokens_256",
|
||||
"intermediate_5120_numtokens_256",
|
||||
]
|
||||
|
||||
input_tensor = torch.randn(
|
||||
@@ -113,9 +113,73 @@ class TestSiluMulFp8ConfigPicker:
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
expected_key = f"intermediate_{intermediate_size}_batchsize_256"
|
||||
expected_key = f"intermediate_{intermediate_size}_numtokens_256"
|
||||
assert selected_key == expected_key
|
||||
|
||||
def test_config_picker_numtokens_ceiling(self):
    """A token count between two tuned sizes rounds up to the next one."""
    # 20 rows sit between the tuned 8 and 32 entries, so the smallest
    # tuned value that is >= 20 (i.e. 32) is the expected pick.
    activations = torch.randn(20, 8192, dtype=torch.bfloat16, device="cuda")
    quant_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
    picker_args = (activations, quant_scale)

    tuned_keys = [
        "intermediate_4096_numtokens_8",
        "intermediate_4096_numtokens_32",
        "intermediate_4096_numtokens_128",
        "intermediate_4096_numtokens_256",
    ]

    picked = pick_silu_mul_fp8_config(picker_args, tuned_keys)

    assert picked == "intermediate_4096_numtokens_32"
|
||||
|
||||
def test_config_picker_numtokens_exact(self):
    """A tuned num_tokens equal to the input's row count is selected."""
    # 32 rows, and 32 is one of the tuned values, so no rounding is needed.
    activations = torch.randn(32, 8192, dtype=torch.bfloat16, device="cuda")
    quant_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
    picker_args = (activations, quant_scale)

    tuned_keys = [
        "intermediate_4096_numtokens_8",
        "intermediate_4096_numtokens_32",
        "intermediate_4096_numtokens_128",
    ]

    picked = pick_silu_mul_fp8_config(picker_args, tuned_keys)

    assert picked == "intermediate_4096_numtokens_32"
|
||||
|
||||
def test_config_picker_numtokens_fallback_to_largest(self):
    """Inputs with more rows than any tuned size get the largest one."""
    # 512 rows is above every tuned value (8, 32, 128), so the picker is
    # expected to settle on the largest available entry, 128.
    activations = torch.randn(512, 8192, dtype=torch.bfloat16, device="cuda")
    quant_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
    picker_args = (activations, quant_scale)

    tuned_keys = [
        "intermediate_4096_numtokens_8",
        "intermediate_4096_numtokens_32",
        "intermediate_4096_numtokens_128",
    ]

    picked = pick_silu_mul_fp8_config(picker_args, tuned_keys)

    assert picked == "intermediate_4096_numtokens_128"
|
||||
|
||||
def test_config_picker_malformed_key_raises(self):
    """A key that does not follow the expected format is rejected."""
    activations = torch.randn(32, 8192, dtype=torch.bfloat16, device="cuda")
    quant_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
    picker_args = (activations, quant_scale)

    # "badformat" stands where "numtokens" is required.
    bad_keys = ["intermediate_4096_badformat_256"]

    with pytest.raises(ValueError, match="Malformed config key"):
        pick_silu_mul_fp8_config(picker_args, bad_keys)
|
||||
|
||||
def test_config_picker_default_ignored_when_valid_keys_exist(self):
    """A shape-specific key wins over "default" when one is available."""
    # 64 rows: the smallest tuned value >= 64 among (32, 128) is 128.
    activations = torch.randn(64, 8192, dtype=torch.bfloat16, device="cuda")
    quant_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
    picker_args = (activations, quant_scale)

    tuned_keys = [
        "default",
        "intermediate_4096_numtokens_32",
        "intermediate_4096_numtokens_128",
    ]

    picked = pick_silu_mul_fp8_config(picker_args, tuned_keys)

    assert picked == "intermediate_4096_numtokens_128"
|
||||
|
||||
|
||||
class TestSiluMulFp8Correctness:
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
@@ -53,44 +54,78 @@ def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
return out.view(output_shape)
|
||||
|
||||
|
||||
@silu_mul_fp8.register_input_generator  # type: ignore[misc]
def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]:
    """Build one ``(input_tensor, scale)`` pair per tuned shape.

    An entry is created for every combination of num_tokens and
    intermediate_size listed below.

    Returns:
        A dict mapping ``"intermediate_{size}_numtokens_{n}"`` to a tuple
        of a random bfloat16 CUDA tensor of shape ``(n, 2 * size)`` and a
        one-element float32 CUDA scale tensor holding 1.0.
    """
    # Token counts mirror vLLM's default cudagraph capture sizes.
    # See vllm/config/vllm.py _set_cudagraph_sizes() for the canonical formula.
    token_counts = [1, 2, 4, *range(8, 256, 8), *range(256, 513, 16)]
    widths = (2048, 2880, 4096, 8192, 11008, 14336)

    generated = {}
    for token_count in token_counts:
        for width in widths:
            # silu_mul splits the last dimension into two halves, so the
            # input is twice as wide as the intermediate size.
            activations = torch.randn(
                token_count,
                2 * width,
                device="cuda",
                dtype=torch.bfloat16,
            )
            unit_scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)
            generated[f"intermediate_{width}_numtokens_{token_count}"] = (
                activations,
                unit_scale,
            )

    return generated
|
||||
|
||||
|
||||
@silu_mul_fp8.register_config_picker  # type: ignore[misc]
def pick_silu_mul_fp8_config(
    args: tuple[Any, ...], config_keys: list[str]
) -> str | None:
    """Pick the best pre-tuned config for the given input shape.

    Selection strategy:
      1. Find the closest intermediate_size among available configs
         (exact match preferred; on a tie the size seen first in
         ``config_keys`` wins).
      2. Among the num_tokens values tuned for that intermediate_size, pick
         the smallest num_tokens >= the input's num_tokens. If the input is
         larger than all available num_tokens, fall back to the largest.

    Config keys must be "default" or follow the format
    "intermediate_{int}_numtokens_{int}".

    Args:
        args: ``(input_tensor, scale)``. Only the shape of ``input_tensor``
            is inspected; its last dimension is twice the intermediate size
            and all leading dimensions together count as tokens.
        config_keys: Keys of the available pre-tuned configs.

    Returns:
        One of the entries of ``config_keys``; ``"default"`` when that is
        the only usable key; ``None`` when ``config_keys`` is empty or
        offers nothing usable.

    Raises:
        ValueError: If a key is neither "default" nor in the expected format.
    """
    if not config_keys:
        return None

    input_tensor, _scale = args
    hidden_size = input_tensor.shape[-1]
    intermediate_size = hidden_size // 2
    # Count rows arithmetically instead of via .view(-1, hidden_size): the
    # view raises on non-contiguous tensors, and only the count is needed.
    num_tokens = input_tensor.numel() // hidden_size if hidden_size else 0

    # intermediate_size -> {num_tokens -> original key}. Keeping the original
    # key guarantees the returned value is literally present in config_keys
    # even if a key is written in a non-canonical way (e.g. leading zeros).
    configs: dict[int, dict[int, str]] = {}
    for key in config_keys:
        if key == "default":
            continue
        match = re.fullmatch(r"intermediate_(\d+)_numtokens_(\d+)", key)
        if not match:
            raise ValueError(
                f"Malformed config key '{key}', "
                f"expected format 'intermediate_{{int}}_numtokens_{{int}}'"
            )
        isize_str, ntokens_str = match.groups()
        configs.setdefault(int(isize_str), {}).setdefault(int(ntokens_str), key)

    if not configs:
        # Only "default" (or nothing usable) was offered.
        return "default" if "default" in config_keys else None

    best_isize = min(configs, key=lambda s: abs(s - intermediate_size))
    keys_by_ntokens = configs[best_isize]
    available_ntokens = sorted(keys_by_ntokens)
    # Smallest tuned num_tokens that covers the input, else the largest one.
    best_ntokens = next(
        (n for n in available_ntokens if n >= num_tokens), available_ntokens[-1]
    )

    return keys_by_ntokens[best_ntokens]
|
||||
|
||||
|
||||
def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user