Compare commits

..

7 Commits

Author SHA1 Message Date
khluu
262ddd0d81 [cherry-pick][Bugfix] Fix EP weight filter breaking EPLB and NVFP4 accuracy #37322
Signed-off-by: khluu <khluu000@gmail.com>
2026-03-18 01:48:32 -07:00
Li, Jiang
e60c1674b3 [Bugfix] Avoid OpenMP thread reallocation in CPU torch compile (#37391)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
(cherry picked from commit 261801242f)
2026-03-18 01:41:42 -07:00
Roy Wang
faa80947f5 [Performance] Add --enable-ep-weight-filter CLI option (#37351)
Signed-off-by: esmeetu <jasonailu87@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
(cherry picked from commit 761e0aa7a0)
2026-03-18 01:41:25 -07:00
Terry Gao
eeabf740bb [Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)
Signed-off-by: tianrengao <terrygao87@gmail.com>
(cherry picked from commit 3e6a1e1686)
2026-03-18 01:41:09 -07:00
Elvir Crnčević
cdcffafef8 Fix eplb nvfp4 experts hook (#37217)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Elvir Crncevic <elvir@anthropic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
(cherry picked from commit fd4d96302a)
2026-03-18 01:40:57 -07:00
Walter Beller-Morales
4d22667c32 [Feature][Frontend] add support for Cohere Embed v2 API (#37074)
Signed-off-by: walterbm <walter.beller.morales@gmail.com>
(cherry picked from commit 061980c36a)
2026-03-16 22:05:47 -07:00
Andreas Karatzas
1fe3932c8b [ROCm] Fix AttributeError for torch.compiler.skip_all_guards_unsafe on older PyTorch (#37219)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
(cherry picked from commit 54a62a79f7)
2026-03-16 21:03:49 -07:00
87 changed files with 396 additions and 3062 deletions

View File

@@ -333,15 +333,15 @@ apply_rocm_test_overrides() {
# --- Entrypoint ignores ---
if [[ $cmds == *" entrypoints/openai "* ]]; then
cmds=${cmds//" entrypoints/openai "/" entrypoints/openai \
--ignore=entrypoints/openai/chat_completion/test_audio.py \
--ignore=entrypoints/openai/completion/test_shutdown.py \
--ignore=entrypoints/openai/test_audio.py \
--ignore=entrypoints/openai/test_shutdown.py \
--ignore=entrypoints/openai/test_completion.py \
--ignore=entrypoints/openai/test_models.py \
--ignore=entrypoints/openai/test_lora_adapters.py \
--ignore=entrypoints/openai/test_return_tokens_as_ids.py \
--ignore=entrypoints/openai/chat_completion/test_root_path.py \
--ignore=entrypoints/openai/test_root_path.py \
--ignore=entrypoints/openai/test_tokenization.py \
--ignore=entrypoints/openai/completion/test_prompt_validation.py "}
--ignore=entrypoints/openai/test_prompt_validation.py "}
fi
if [[ $cmds == *" entrypoints/llm "* ]]; then

View File

@@ -162,7 +162,7 @@ steps:
- tests/entrypoints/test_chat_utils
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
- pytest -v -s entrypoints/test_chat_utils.py
- label: Entrypoints Integration Test (API Server 2)
@@ -674,12 +674,12 @@ steps:
- vllm/config/model.py
- vllm/model_executor
- tests/model_executor
- tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
commands:
- apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s model_executor
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
- label: Benchmarks # 11min
timeout_in_minutes: 20
@@ -1143,7 +1143,7 @@ steps:
- pytest -v -s plugins_tests/test_scheduler_plugins.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py
- pytest -v -s entrypoints/openai/test_oot_registration.py
- pytest -v -s models/test_oot_registration.py
- pytest -v -s plugins/lora_resolvers
@@ -1502,7 +1502,7 @@ steps:
- tests/entrypoints/test_chat_utils
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
- pytest -v -s entrypoints/test_chat_utils.py
- label: Entrypoints Integration Test (API Server 2)
@@ -2133,12 +2133,12 @@ steps:
- vllm/config/model.py
- vllm/model_executor
- tests/model_executor
- tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
commands:
- apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s model_executor
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
- label: Benchmarks # 11min
timeout_in_minutes: 20
@@ -2735,7 +2735,7 @@ steps:
- pytest -v -s plugins_tests/test_scheduler_plugins.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py # it needs a clean process
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
@@ -3257,7 +3257,7 @@ steps:
- tests/entrypoints/test_chat_utils
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
- pytest -v -s entrypoints/test_chat_utils.py
- label: Entrypoints Integration Test (API Server 2)
@@ -3872,12 +3872,12 @@ steps:
- vllm/config/model.py
- vllm/model_executor
- tests/model_executor
- tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
commands:
- apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s model_executor
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
- label: Benchmarks # 11min
timeout_in_minutes: 20
@@ -4508,7 +4508,7 @@ steps:
- pytest -v -s plugins_tests/test_scheduler_plugins.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py # it needs a clean process
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins

View File

@@ -34,7 +34,7 @@ steps:
- tests/entrypoints/test_chat_utils
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
- pytest -v -s entrypoints/test_chat_utils.py
mirror:
amd:

View File

@@ -9,9 +9,9 @@ steps:
- vllm/config/model.py
- vllm/model_executor
- tests/model_executor
- tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
commands:
- apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s model_executor
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py

View File

@@ -36,6 +36,6 @@ steps:
- pytest -v -s plugins_tests/test_scheduler_plugins.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py # it needs a clean process
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins

.github/mergify.yml
View File

@@ -381,7 +381,7 @@ pull_request_rules:
- or:
- files~=^vllm/model_executor/model_loader/tensorizer.py
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
- files~=^tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
- files~=^tests/model_executor/model_loader/tensorizer_loader/
actions:
assign:

View File

@@ -47,8 +47,6 @@ from common import (
is_mla_backend,
)
from vllm.v1.worker.workspace import init_workspace_manager
def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
"""Run standard attention benchmark (Flash/Triton/FlashInfer)."""
@@ -464,7 +462,7 @@ def main():
parser.add_argument(
"--batch-specs",
nargs="+",
default=None,
default=["q2k", "8q1s1k"],
help="Batch specifications using extended grammar",
)
@@ -480,21 +478,6 @@ def main():
parser.add_argument("--repeats", type=int, default=1, help="Repetitions")
parser.add_argument("--warmup-iters", type=int, default=3, help="Warmup iterations")
parser.add_argument("--profile-memory", action="store_true", help="Profile memory")
parser.add_argument(
"--kv-cache-dtype",
default="auto",
choices=["auto", "fp8"],
help="KV cache dtype: auto or fp8",
)
parser.add_argument(
"--cuda-graphs",
action=argparse.BooleanOptionalAction,
default=True,
help=(
"Launch kernels with CUDA graphs to eliminate CPU overhead"
"in measurements (default: True)"
),
)
# Parameter sweep (use YAML config for advanced sweeps)
parser.add_argument(
@@ -553,24 +536,21 @@ def main():
# Batch specs and sizes
# Support both explicit batch_specs and generated batch_spec_ranges
# CLI --batch-specs takes precedence over YAML when provided.
cli_batch_specs_provided = args.batch_specs is not None
if not cli_batch_specs_provided:
if "batch_spec_ranges" in yaml_config:
# Generate batch specs from ranges
generated_specs = generate_batch_specs_from_ranges(
yaml_config["batch_spec_ranges"]
)
# Combine with any explicit batch_specs
if "batch_specs" in yaml_config:
args.batch_specs = yaml_config["batch_specs"] + generated_specs
else:
args.batch_specs = generated_specs
console.print(
f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]"
)
elif "batch_specs" in yaml_config:
args.batch_specs = yaml_config["batch_specs"]
if "batch_spec_ranges" in yaml_config:
# Generate batch specs from ranges
generated_specs = generate_batch_specs_from_ranges(
yaml_config["batch_spec_ranges"]
)
# Combine with any explicit batch_specs
if "batch_specs" in yaml_config:
args.batch_specs = yaml_config["batch_specs"] + generated_specs
else:
args.batch_specs = generated_specs
console.print(
f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]"
)
elif "batch_specs" in yaml_config:
args.batch_specs = yaml_config["batch_specs"]
if "batch_sizes" in yaml_config:
args.batch_sizes = yaml_config["batch_sizes"]
@@ -595,10 +575,6 @@ def main():
args.warmup_iters = yaml_config["warmup_iters"]
if "profile_memory" in yaml_config:
args.profile_memory = yaml_config["profile_memory"]
if "kv_cache_dtype" in yaml_config:
args.kv_cache_dtype = yaml_config["kv_cache_dtype"]
if "cuda_graphs" in yaml_config:
args.cuda_graphs = yaml_config["cuda_graphs"]
# Parameter sweep configuration
if "parameter_sweep" in yaml_config:
@@ -653,18 +629,12 @@ def main():
# Determine backends
backends = args.backends or ([args.backend] if args.backend else ["flash"])
prefill_backends = getattr(args, "prefill_backends", None)
if not args.batch_specs:
args.batch_specs = ["q2k", "8q1s1k"]
console.print(f"Backends: {', '.join(backends)}")
if prefill_backends:
console.print(f"Prefill backends: {', '.join(prefill_backends)}")
console.print(f"Batch specs: {', '.join(args.batch_specs)}")
console.print(f"KV cache dtype: {args.kv_cache_dtype}")
console.print(f"CUDA graphs: {args.cuda_graphs}")
console.print()
init_workspace_manager(args.device)
# Run benchmarks
all_results = []
@@ -717,8 +687,6 @@ def main():
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
kv_cache_dtype=args.kv_cache_dtype,
use_cuda_graphs=args.cuda_graphs,
)
# Add decode pipeline config
@@ -871,8 +839,6 @@ def main():
"repeats": args.repeats,
"warmup_iters": args.warmup_iters,
"profile_memory": args.profile_memory,
"kv_cache_dtype": args.kv_cache_dtype,
"use_cuda_graphs": args.cuda_graphs,
}
all_results = run_model_parameter_sweep(
backends,
@@ -895,8 +861,6 @@ def main():
"repeats": args.repeats,
"warmup_iters": args.warmup_iters,
"profile_memory": args.profile_memory,
"kv_cache_dtype": args.kv_cache_dtype,
"use_cuda_graphs": args.cuda_graphs,
}
all_results = run_parameter_sweep(
backends, args.batch_specs, base_config_args, args.parameter_sweep, console
@@ -927,8 +891,6 @@ def main():
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
kv_cache_dtype=args.kv_cache_dtype,
use_cuda_graphs=args.cuda_graphs,
)
result = run_benchmark(config)

View File

@@ -213,9 +213,6 @@ class BenchmarkConfig:
profile_memory: bool = False
use_cuda_graphs: bool = False
# "auto" or "fp8"
kv_cache_dtype: str = "auto"
# MLA-specific
prefill_backend: str | None = None
kv_lora_rank: int | None = None
@@ -372,7 +369,6 @@ class ResultsFormatter:
"backend",
"batch_spec",
"num_layers",
"kv_cache_dtype",
"mean_time",
"std_time",
"throughput",
@@ -386,7 +382,6 @@ class ResultsFormatter:
"backend": r.config.backend,
"batch_spec": r.config.batch_spec,
"num_layers": r.config.num_layers,
"kv_cache_dtype": r.config.kv_cache_dtype,
"mean_time": r.mean_time,
"std_time": r.std_time,
"throughput": r.throughput_tokens_per_sec or 0,

View File

@@ -30,9 +30,9 @@ batch_specs:
- "2q16k_32q1s4k" # 2 very large prefill + 32 decode
# Context extension + decode
- "2q1ks2k_16q1s1k" # 2 extend + 16 decode
- "4q2ks4k_32q1s2k" # 4 extend + 32 decode
- "2q1ks8k_32q1s2k" # 2 large extend + 32 decode
- "2q1kkv2k_16q1s1k" # 2 extend + 16 decode
- "4q2kkv4k_32q1s2k" # 4 extend + 32 decode
- "2q1kkv8k_32q1s2k" # 2 large extend + 32 decode
# Explicitly chunked prefill
- "q8k" # 8k prefill with chunking hint

View File

@@ -1,58 +0,0 @@
# MLA decode-only benchmark configuration
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128 # Base value, can be swept for TP simulation
num_kv_heads: 1 # MLA uses single latent KV
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
model_parameter_sweep:
param_name: "num_q_heads"
values: [128, 64, 32, 16]
label_format: "{backend}_{value}h"
batch_specs:
# Small batches, varying sequence lengths
- "16q1s512" # 16 requests, 512 KV cache
- "16q1s1k" # 16 requests, 1k KV cache
- "16q1s2k" # 16 requests, 2k KV cache
- "16q1s4k" # 16 requests, 4k KV cache
# Medium batches
- "32q1s1k" # 32 requests, 1k KV cache
- "32q1s2k" # 32 requests, 2k KV cache
- "32q1s4k" # 32 requests, 4k KV cache
- "32q1s8k" # 32 requests, 8k KV cache
# Large batches
- "64q1s1k" # 64 requests, 1k KV cache
- "64q1s2k" # 64 requests, 2k KV cache
- "64q1s4k" # 64 requests, 4k KV cache
- "64q1s8k" # 64 requests, 8k KV cache
# Very large batches
- "128q1s1k" # 128 requests, 1k KV cache
- "128q1s2k" # 128 requests, 2k KV cache
- "128q1s4k" # 128 requests, 4k KV cache
- "128q1s8k" # 128 requests, 8k KV cache
# Long context
- "32q1s16k" # 32 requests, 16k KV cache
- "32q1s32k" # 32 requests, 32k KV cache
backends:
- FLASHMLA_SPARSE
- FLASHINFER_MLA_SPARSE
device: "cuda:0"
repeats: 100
warmup_iters: 10
profile_memory: true

View File

@@ -60,11 +60,9 @@ def create_minimal_vllm_config(
model_name: str = "deepseek-v3",
block_size: int = 128,
max_num_seqs: int = 256,
max_num_batched_tokens: int = 8192,
mla_dims: dict | None = None,
index_topk: int | None = None,
prefill_backend: str | None = None,
kv_cache_dtype: str = "auto",
) -> VllmConfig:
"""
Create minimal VllmConfig for MLA benchmarks.
@@ -151,13 +149,13 @@ def create_minimal_vllm_config(
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
cache_dtype=kv_cache_dtype,
cache_dtype="auto",
enable_prefix_caching=False,
)
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max(max_num_batched_tokens, max_num_seqs),
max_num_batched_tokens=8192,
max_model_len=32768,
is_encoder_decoder=False,
enable_chunked_prefill=True,
@@ -537,7 +535,6 @@ def _create_backend_impl(
device: torch.device,
max_num_tokens: int = 8192,
index_topk: int | None = None,
kv_cache_dtype: str = "auto",
):
"""
Create backend implementation instance.
@@ -586,7 +583,7 @@ def _create_backend_impl(
"num_kv_heads": mla_dims["num_kv_heads"],
"alibi_slopes": None,
"sliding_window": None,
"kv_cache_dtype": kv_cache_dtype,
"kv_cache_dtype": "auto",
"logits_soft_cap": None,
"attn_type": "decoder",
"kv_sharing_target_layer_name": None,
@@ -704,7 +701,6 @@ def _run_single_benchmark(
mla_dims: dict,
device: torch.device,
indexer=None,
kv_cache_dtype: str | None = None,
) -> BenchmarkResult:
"""
Run a single benchmark iteration.
@@ -738,124 +734,49 @@ def _run_single_benchmark(
)
# Create KV cache
if kv_cache_dtype is None:
kv_cache_dtype = getattr(config, "kv_cache_dtype", "auto")
head_size = mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"]
if kv_cache_dtype == "fp8_ds_mla":
# FlashMLA sparse custom format: 656 bytes per token, stored as uint8.
# Layout: kv_lora_rank fp8 bytes + 4 float32 tile scales
# + 2*rope_dim bf16 bytes
# = 512 + 16 + 128 = 656 bytes for DeepSeek dims.
kv_cache = torch.zeros(
num_blocks,
block_size,
656,
device=device,
dtype=torch.uint8,
)
elif kv_cache_dtype == "fp8":
from vllm.platforms import current_platform
kv_cache = torch.zeros(
num_blocks,
block_size,
mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"],
device=device,
dtype=torch.bfloat16,
)
kv_cache = torch.zeros(
num_blocks,
block_size,
head_size,
device=device,
dtype=torch.uint8,
).view(current_platform.fp8_dtype())
else:
kv_cache = torch.zeros(
num_blocks,
block_size,
head_size,
device=device,
dtype=torch.bfloat16,
)
# Create input tensors for both decode and prefill modes
decode_inputs, prefill_inputs = _create_input_tensors(
total_q,
mla_dims,
backend_cfg["query_format"],
device,
torch.bfloat16,
)
# Fill indexer with random indices for sparse backends
is_sparse = backend_cfg.get("is_sparse", False)
if is_sparse and indexer is not None:
indexer.fill_random_indices(total_q, max_kv_len)
# Determine which forward methods to use based on metadata.
# Sparse MLA backends always use forward_mqa
has_decode = is_sparse or getattr(metadata, "decode", None) is not None
has_prefill = not is_sparse and getattr(metadata, "prefill", None) is not None
if not has_decode and not has_prefill:
# Determine which forward method to use based on metadata
if metadata.decode is not None:
forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)
elif metadata.prefill is not None:
forward_fn = lambda: impl.forward_mha(
prefill_inputs["q"],
prefill_inputs["k_c_normed"],
prefill_inputs["k_pe"],
kv_cache,
metadata,
prefill_inputs["k_scale"],
prefill_inputs["output"],
)
else:
raise RuntimeError("Metadata has neither decode nor prefill metadata")
num_decode = (
metadata.num_decode_tokens
if (has_decode and has_prefill)
else total_q
if has_decode
else 0
)
num_prefill = total_q - num_decode
# Some backends require fp8 queries when using fp8 KV cache.
is_fp8_kvcache = kv_cache_dtype.startswith("fp8")
quantize_query = is_fp8_kvcache and getattr(
impl, "supports_quant_query_input", False
)
# quantize_query forces concat format
query_fmt = "concat" if quantize_query else backend_cfg["query_format"]
# Create decode query tensors
if has_decode:
decode_inputs, _ = _create_input_tensors(
num_decode, mla_dims, query_fmt, device, torch.bfloat16
)
# Cast decode query to fp8 if the backend supports it
if quantize_query:
from vllm.platforms import current_platform
if isinstance(decode_inputs, tuple):
decode_inputs = torch.cat(list(decode_inputs), dim=-1)
decode_inputs = decode_inputs.to(current_platform.fp8_dtype())
# Create prefill input tensors
if has_prefill:
_, prefill_inputs = _create_input_tensors(
num_prefill, mla_dims, query_fmt, device, torch.bfloat16
)
# Build forward function
def forward_fn():
results = []
if has_decode:
results.append(impl.forward_mqa(decode_inputs, kv_cache, metadata, layer))
if has_prefill:
results.append(
impl.forward_mha(
prefill_inputs["q"],
prefill_inputs["k_c_normed"],
prefill_inputs["k_pe"],
kv_cache,
metadata,
prefill_inputs["k_scale"],
prefill_inputs["output"],
)
)
return results[0] if len(results) == 1 else tuple(results)
# Warmup
for _ in range(config.warmup_iters):
forward_fn()
torch.accelerator.synchronize()
# Optionally capture a CUDA graph after warmup.
# Graph replay eliminates CPU launch overhead so timings reflect pure
# kernel time.
if config.use_cuda_graphs:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
forward_fn()
benchmark_fn = graph.replay
else:
benchmark_fn = forward_fn
# Benchmark
times = []
for _ in range(config.repeats):
@@ -864,7 +785,7 @@ def _run_single_benchmark(
start.record()
for _ in range(config.num_layers):
benchmark_fn()
forward_fn()
end.record()
torch.accelerator.synchronize()
@@ -931,30 +852,13 @@ def _run_mla_benchmark_batched(
# Determine if this is a sparse backend
is_sparse = backend_cfg.get("is_sparse", False)
# Extract kv_cache_dtype from the first config
kv_cache_dtype = getattr(first_config, "kv_cache_dtype", "auto")
# FlashMLA sparse only supports "fp8_ds_mla" internally (not generic "fp8").
# Remap here so the user can pass --kv-cache-dtype fp8 regardless of backend.
if backend.upper() == "FLASHMLA_SPARSE" and kv_cache_dtype == "fp8":
kv_cache_dtype = "fp8_ds_mla"
# Compute max total_q across all configs so the metadata builder buffer
# and scheduler config are large enough for all batch specs.
max_total_q = max(
sum(r.q_len for r in parse_batch_spec(cfg.batch_spec))
for cfg, *_ in configs_with_params
)
# Create and set vLLM config for MLA (reused across all benchmarks)
vllm_config = create_minimal_vllm_config(
model_name="deepseek-v3", # Used only for model path
block_size=block_size,
max_num_batched_tokens=max_total_q,
mla_dims=mla_dims, # Use custom dims from config or default
index_topk=index_topk if is_sparse else None,
prefill_backend=prefill_backend,
kv_cache_dtype=kv_cache_dtype,
)
results = []
@@ -979,9 +883,7 @@ def _run_mla_benchmark_batched(
mla_dims,
vllm_config,
device,
max_num_tokens=max_total_q,
index_topk=index_topk if is_sparse else None,
kv_cache_dtype=kv_cache_dtype,
)
# Verify the actual prefill backend matches what was requested
@@ -1040,7 +942,6 @@ def _run_mla_benchmark_batched(
mla_dims,
device,
indexer=indexer,
kv_cache_dtype=kv_cache_dtype,
)
results.append(result)
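
The use_cuda_graphs path removed above follows the standard capture-then-replay timing pattern: warm up eagerly, capture the forward calls into a torch.cuda.CUDAGraph, then time graph.replay() so measurements reflect kernel time rather than Python launch overhead. A minimal self-contained sketch of that pattern, with a placeholder work() standing in for the per-layer forward calls (sizes and repeat counts are illustrative, not vLLM defaults):

import torch

def work(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    # Placeholder for the per-layer attention forward being benchmarked.
    torch.matmul(a, b, out=out)

a = torch.randn(2048, 2048, device="cuda")
b = torch.randn(2048, 2048, device="cuda")
out = torch.empty(2048, 2048, device="cuda")

# Warm up eagerly so lazy initialization is not captured into the graph.
for _ in range(3):
    work(a, b, out)
torch.cuda.synchronize()

# Capture once; replay re-launches the recorded kernels with no Python overhead.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    work(a, b, out)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    graph.replay()
end.record()
torch.cuda.synchronize()
print(f"mean time per replay: {start.elapsed_time(end) / 100:.3f} ms")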

View File

@@ -140,7 +140,7 @@ def _create_vllm_config(
cache_config = CacheConfig(
block_size=config.block_size,
cache_dtype=config.kv_cache_dtype,
cache_dtype="auto",
)
cache_config.num_gpu_blocks = max_num_blocks
cache_config.num_cpu_blocks = 0
@@ -215,7 +215,7 @@ def _create_backend_impl(
num_kv_heads=config.num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=config.kv_cache_dtype,
kv_cache_dtype="auto",
)
kv_cache_spec = FullAttentionSpec(
@@ -288,22 +288,12 @@ def _create_input_tensors(
total_q: int,
device: torch.device,
dtype: torch.dtype,
quantize_query: bool = False,
) -> tuple:
"""Create Q, K, V input tensors for all layers.
When quantize_query is True, queries are cast to fp8 to match backends
that require query/key/value dtype consistency.
"""
q_dtype = dtype
if quantize_query:
from vllm.platforms import current_platform
q_dtype = current_platform.fp8_dtype()
"""Create Q, K, V input tensors for all layers."""
q_list = [
torch.randn(
total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype
).to(q_dtype)
)
for _ in range(config.num_layers)
]
k_list = [
@@ -354,17 +344,10 @@ def _create_kv_cache(
# Compute inverse permutation to get back to logical view
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
# Use fp8 dtype for cache when requested.
cache_dtype = dtype
if config.kv_cache_dtype == "fp8":
from vllm.platforms import current_platform
cache_dtype = current_platform.fp8_dtype()
cache_list = []
for _ in range(config.num_layers):
# Allocate in physical layout order (contiguous in memory)
cache = torch.zeros(*physical_shape, device=device, dtype=cache_dtype)
cache = torch.zeros(*physical_shape, device=device, dtype=dtype)
# Permute to logical view
cache = cache.permute(*inv_order)
cache_list.append(cache)
@@ -409,37 +392,6 @@ def _run_single_benchmark(
)
torch.accelerator.synchronize()
# Optionally capture a CUDA graph after warmup.
# Graph replay eliminates CPU launch overhead so timings reflect pure
# kernel time.
if config.use_cuda_graphs:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
benchmark_fn = graph.replay
else:
def benchmark_fn():
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
# Benchmark
times = []
for _ in range(config.repeats):
@@ -447,7 +399,16 @@ def _run_single_benchmark(
end = torch.cuda.Event(enable_timing=True)
start.record()
benchmark_fn()
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
end.record()
torch.accelerator.synchronize()
@@ -541,12 +502,8 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
common_attn_metadata=common_metadata,
)
# Only quantize queries when the impl supports it
quantize_query = config.kv_cache_dtype.startswith("fp8") and getattr(
impl, "supports_quant_query_input", False
)
q_list, k_list, v_list = _create_input_tensors(
config, total_q, device, dtype, quantize_query=quantize_query
config, total_q, device, dtype
)
cache_list = _create_kv_cache(

View File

@@ -286,15 +286,6 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
"Outer scale stride must be 1 when scales are not transposed");
}
int64_t hidden_size = input.size(-1);
TORCH_CHECK(hidden_size > 0 && hidden_size % group_size == 0,
"hidden_size must be a positive multiple of group_size");
int64_t num_tokens = input.numel() / hidden_size;
int64_t num_groups = hidden_size / group_size;
TORCH_CHECK(scales.numel() >= num_tokens * num_groups,
"scales buffer too small: need ", num_tokens * num_groups,
" elements, got ", scales.numel());
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
var_epsilon, scale_ub, residual,
is_scale_transposed);

View File

@@ -107,27 +107,6 @@ vLLM supports the `tool_choice='none'` option in the chat completion API. When t
!!! note
When tools are specified in the request, vLLM includes tool definitions in the prompt by default, regardless of the `tool_choice` setting. To exclude tool definitions when `tool_choice='none'`, use the `--exclude-tools-when-tool-choice-none` option.
## Constrained Decoding Behavior
Whether vLLM enforces the tool parameter schema during generation depends on the `tool_choice` mode:
| `tool_choice` value | Schema-constrained decoding | Behavior |
| --- | --- | --- |
| Named function | Yes (via structured outputs backend) | Arguments are guaranteed to be valid JSON conforming to the function's parameter schema. |
| `"required"` | Yes (via structured outputs backend) | Same as named function. The model must produce at least one tool call. |
| `"auto"` | No | The model generates freely. A tool-call parser extracts tool calls from the raw text. Arguments may be malformed or not match the schema. |
| `"none"` | N/A | No tool calls are produced. |
When schema conformance matters, prefer `tool_choice="required"` or named function calling over `"auto"`.
### Strict Mode (`strict` parameter)
The [OpenAI API](https://platform.openai.com/docs/guides/function-calling#strict-mode) supports a `strict` field on function definitions. When set to `true`, OpenAI uses constrained decoding to guarantee that tool-call arguments match the function schema, even in `tool_choice="auto"` mode.
vLLM **does not implement** `strict` mode today. The `strict` field is accepted in requests (to avoid breaking clients that set it), but it has no effect on decoding behavior. In auto mode, argument validity depends entirely on the model's output quality and the parser's extraction logic.
Tracking issues: [#15526](https://github.com/vllm-project/vllm/issues/15526), [#16313](https://github.com/vllm-project/vllm/issues/16313).
## Automatic Function Calling
To enable this feature, you should set the following flags:
@@ -145,9 +124,6 @@ from HuggingFace; and you can find an example of this in a `tokenizer_config.jso
If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!
!!! note
With `tool_choice="auto"`, tool-call arguments are extracted from the model's raw text output by the selected parser. No schema-level constraint is applied during decoding, so arguments may occasionally be malformed or violate the function's parameter schema. See [Constrained Decoding Behavior](#constrained-decoding-behavior) for details.
### Hermes Models (`hermes`)
All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported.
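
To make the constrained-decoding table above concrete, here is a minimal sketch of the two request styles against a vLLM OpenAI-compatible server; the base URL, model name, and tool schema are placeholders for illustration. With tool_choice="required" or a named function, arguments are schema-constrained; with tool_choice="auto", they are extracted by the parser and may not conform.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
messages = [{"role": "user", "content": "What is the weather in Tokyo?"}]

# Schema-constrained: the structured-outputs backend guarantees valid JSON arguments.
constrained = client.chat.completions.create(
    model="Qwen/Qwen2-1.5B-Instruct",  # placeholder model
    messages=messages,
    tools=tools,
    tool_choice="required",
)

# Parser-extracted: the model generates freely; arguments may be malformed.
auto = client.chat.completions.create(
    model="Qwen/Qwen2-1.5B-Instruct",
    messages=messages,
    tools=tools,
    tool_choice="auto",
)
print(constrained.choices[0].message.tool_calls)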

View File

@@ -50,7 +50,7 @@ av==16.1.0
blobfile==3.0.0
# Multi-Modal Models Test
decord==0.6.0
# video processing, required by entrypoints/openai/chat_completion/test_video.py
# video processing, required by entrypoints/openai/test_video.py
rapidfuzz==3.12.1
# OpenAI compatibility and testing

View File

@@ -1,9 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from tests.entrypoints.openai.chat_completion.test_oot_registration import (
run_and_test_dummy_opt_api_server,
)
from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server
def test_distributed_oot(dummy_opt_path: str):

View File

@@ -4,11 +4,12 @@ import weakref
import pytest
from tests.entrypoints.openai.chat_completion.test_vision import TEST_IMAGE_ASSETS
from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.sampling_params import SamplingParams
from ..openai.test_vision import TEST_IMAGE_ASSETS
@pytest.fixture(scope="function")
def text_llm():

View File

@@ -6,12 +6,13 @@ import logging
import pytest
import regex as re
from tests.entrypoints.openai.chat_completion.test_vision import TEST_IMAGE_ASSETS
from vllm import LLM
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.v1.metrics import loggers as stat_loggers
from vllm.v1.metrics.reader import Counter, Metric
from ..openai.test_vision import TEST_IMAGE_ASSETS
def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]:
return [

View File

@@ -7,10 +7,11 @@ import openai
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.assets.audio import AudioAsset
from vllm.multimodal.utils import encode_audio_base64, encode_audio_url, fetch_audio
from ...utils import RemoteOpenAIServer
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
TEST_AUDIO_URLS = [
AudioAsset("winning_call").url,

View File

@@ -8,8 +8,8 @@ import openai
import pytest
import pytest_asyncio
from tests.conftest import VideoTestAssets
from tests.utils import RemoteOpenAIServer
from ...conftest import VideoTestAssets
from ...utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen2.5-Omni-3B"

View File

@@ -14,7 +14,7 @@ import torch
from openai import BadRequestError
from transformers import AutoConfig
from tests.utils import RemoteOpenAIServer
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"

View File

@@ -8,8 +8,8 @@ import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download
from tests.conftest import AudioTestAssets
from tests.utils import RemoteOpenAIServer
from ...conftest import AudioTestAssets
from ...utils import RemoteOpenAIServer
# NOTE - the tests in this module are currently analogous to test_chat, but are
# separated to avoid OOM killing due to module-scoped servers, since we

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from tests.utils import VLLM_PATH, RemoteOpenAIServer
from ...utils import VLLM_PATH, RemoteOpenAIServer
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()

View File

@@ -11,10 +11,11 @@ import pytest
import regex as re
import torch
from tests.utils import RemoteOpenAIServer
from vllm.config import ModelConfig
from vllm.renderers.embed_utils import safe_load_prompt_embeds
from ...utils import RemoteOpenAIServer
@pytest.mark.asyncio
async def test_empty_prompt():

View File

@@ -8,7 +8,7 @@ from typing import Any, NamedTuple
import openai # use the official client for correctness check
import pytest
from tests.utils import RemoteOpenAIServer
from ...utils import RemoteOpenAIServer
# # any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"

View File

@@ -9,7 +9,6 @@ import pytest
import pytest_asyncio
import torch.cuda
from tests.utils import RemoteOpenAIServer
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig,
@@ -18,6 +17,8 @@ from vllm.model_executor.model_loader.tensorizer import (
)
from vllm.platforms import current_platform
from ...utils import RemoteOpenAIServer
MODEL_NAME = "unsloth/llama-3.2-1b-Instruct"
LORA_PATH = "davzoku/finqa_adapter_1b"

View File

@@ -6,10 +6,11 @@ import tempfile
import pytest
from tests.utils import RemoteOpenAIServer
from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf
from vllm.tokenizers import get_tokenizer
from ...utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen3-0.6B"
MODEL_PATH = os.path.join(tempfile.gettempdir(), "qwen3_06b")

View File

@@ -7,10 +7,11 @@ import openai
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.multimodal.utils import encode_video_url, fetch_video
from vllm.platforms import current_platform
from ...utils import RemoteOpenAIServer
MODEL_NAME = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
MAXIMUM_VIDEOS = 3

View File

@@ -8,11 +8,12 @@ import pytest
import pytest_asyncio
from transformers import AutoProcessor
from tests.utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
from vllm.multimodal.media import MediaWithBytes
from vllm.multimodal.utils import encode_image_url, fetch_image
from vllm.platforms import current_platform
from ...utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
MODEL_NAME = "microsoft/Phi-3.5-vision-instruct"
MAXIMUM_IMAGES = 2

View File

@@ -8,9 +8,10 @@ import pytest
import requests
import torch
from tests.utils import RemoteOpenAIServer
from vllm.utils.serial_utils import tensor2base64
from ...utils import RemoteOpenAIServer
@pytest.mark.parametrize(
"model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]

View File

@@ -5,7 +5,7 @@ import json
import pytest
from tests.tool_parsers.utils import (
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)

View File

@@ -7,7 +7,7 @@ from unittest.mock import MagicMock
import pytest
from tests.tool_parsers.utils import (
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest
from tests.tool_parsers.utils import (
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest
from tests.tool_parsers.utils import (
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest
from tests.tool_parsers.utils import (
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)

View File

@@ -1,434 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Standalone unit tests for trtllm_prefill_attn_kvfp8_dequant.
Tests both contiguous and non-contiguous (cross-layer unified) KV cache
layouts against a pure-PyTorch reference implementation.
"""
import pytest
import torch
from vllm.platforms import current_platform
FP8_DTYPE = current_platform.fp8_dtype()
NUM_BLOCKS = 128
def to_float8(x, dtype=None):
if dtype is None:
dtype = FP8_DTYPE
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax * 0.1
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
def make_contiguous_kv_cache(num_blocks, num_kv_heads, block_size, head_size):
"""Create a standard contiguous fp8 KV cache (HND layout)."""
raw = torch.randn(
num_blocks,
2,
num_kv_heads,
block_size,
head_size,
dtype=torch.bfloat16,
device="cuda",
)
kv_cache, scale = to_float8(raw)
return kv_cache, scale
def make_cross_layer_kv_cache(
num_blocks,
num_kv_heads,
block_size,
head_size,
num_layers=4,
):
"""
Create a non-contiguous per-layer view mimicking cross-layer allocation.
Physical layout: (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
Returned view: (num_blocks, 2, num_kv_heads, block_size, head_size)
with non-contiguous strides on dims 0, 1, 2 (they skip over num_layers).
"""
raw = torch.randn(
num_blocks,
2,
num_kv_heads,
num_layers,
block_size,
head_size,
dtype=torch.bfloat16,
device="cuda",
)
fp8_full, scale = to_float8(raw)
layer_view = fp8_full[:, :, :, 0, :, :]
assert not layer_view.is_contiguous(), (
f"Expected non-contiguous view, got strides {layer_view.stride()}"
)
return layer_view, scale
def ref_dequant(kv_cache, block_tables, k_scale, v_scale, dequant_dtype):
"""Pure PyTorch reference: gather pages and dequantize fp8 -> dequant_dtype."""
batch_size, num_pages_per_seq = block_tables.shape
s = kv_cache.shape
out = torch.zeros(
batch_size * num_pages_per_seq + 1,
s[1],
s[2],
s[3],
s[4],
dtype=dequant_dtype,
device=kv_cache.device,
)
for b in range(batch_size):
for p in range(num_pages_per_seq):
page_idx = block_tables[b, p].item()
if page_idx <= 0:
continue
mock_idx = b * num_pages_per_seq + p + 1
out[mock_idx, 0] = (kv_cache[page_idx, 0].float() * k_scale.item()).to(
dequant_dtype
)
out[mock_idx, 1] = (kv_cache[page_idx, 1].float() * v_scale.item()).to(
dequant_dtype
)
return out
@pytest.mark.parametrize("num_kv_heads", [1, 8])
@pytest.mark.parametrize("head_size", [64, 128])
@pytest.mark.parametrize("block_size", [16, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("num_pages_per_seq", [3, 8])
@pytest.mark.parametrize("contiguous", [True, False])
@torch.inference_mode()
def test_trtllm_kvfp8_dequant(
num_kv_heads: int,
head_size: int,
block_size: int,
batch_size: int,
num_pages_per_seq: int,
contiguous: bool,
):
from vllm.v1.attention.backends.flashinfer import (
trtllm_prefill_attn_kvfp8_dequant,
)
torch.set_default_device("cuda")
if contiguous:
kv_cache, scale = make_contiguous_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
head_size,
)
else:
kv_cache, scale = make_cross_layer_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
head_size,
)
k_scale = scale.clone()
v_scale = scale.clone()
block_tables = torch.randint(
1,
NUM_BLOCKS,
(batch_size, num_pages_per_seq),
dtype=torch.int32,
)
mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
kv_cache,
block_tables,
k_scale,
v_scale,
torch.bfloat16,
)
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
expected_bt = torch.arange(
1,
batch_size * num_pages_per_seq + 1,
dtype=torch.int32,
device="cuda",
).reshape(batch_size, num_pages_per_seq)
torch.testing.assert_close(mock_block_table, expected_bt)
# Page 0 is padding (never written), compare only pages 1+
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
@torch.inference_mode()
def test_block_tables_with_zero_pages():
"""Pages with index <= 0 must be skipped (early return in kernel)."""
from vllm.v1.attention.backends.flashinfer import (
trtllm_prefill_attn_kvfp8_dequant,
)
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 8, 16, 64
kv_cache, scale = make_contiguous_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
head_size,
)
k_scale = v_scale = scale.clone()
# Mix of valid pages and zeros (padding)
block_tables = torch.tensor(
[[5, 0, 10], [0, 0, 0], [3, 7, 0]],
dtype=torch.int32,
device="cuda",
)
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
kv_cache,
block_tables,
k_scale,
v_scale,
torch.bfloat16,
)
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
# Only compare pages that were actually written (non-zero page indices)
for b in range(block_tables.shape[0]):
for p in range(block_tables.shape[1]):
if block_tables[b, p].item() > 0:
idx = b * block_tables.shape[1] + p + 1
torch.testing.assert_close(
mock_kv_cache[idx],
ref[idx],
atol=1e-3,
rtol=1e-3,
)
@torch.inference_mode()
def test_all_zero_block_tables():
"""All-zero block_tables: kernel should write nothing."""
from vllm.v1.attention.backends.flashinfer import (
trtllm_prefill_attn_kvfp8_dequant,
)
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 4, 16, 64
kv_cache, scale = make_contiguous_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
head_size,
)
k_scale = v_scale = scale.clone()
block_tables = torch.zeros(2, 4, dtype=torch.int32, device="cuda")
# Should not crash even though no pages are valid
mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
kv_cache,
block_tables,
k_scale,
v_scale,
torch.bfloat16,
)
assert mock_kv_cache.shape[0] == 2 * 4 + 1
assert mock_block_table.shape == (2, 4)
@torch.inference_mode()
def test_different_k_v_scales():
"""Verify K and V are dequantized with independent scales."""
from vllm.v1.attention.backends.flashinfer import (
trtllm_prefill_attn_kvfp8_dequant,
)
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 8, 16, 64
kv_cache, _ = make_contiguous_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
head_size,
)
k_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
v_scale = torch.tensor([2.0], dtype=torch.float32, device="cuda")
block_tables = torch.tensor([[1, 2]], dtype=torch.int32, device="cuda")
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
kv_cache,
block_tables,
k_scale,
v_scale,
torch.bfloat16,
)
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
@torch.inference_mode()
def test_single_page_per_seq():
"""Minimum grid dim 1 = 1 page per sequence."""
from vllm.v1.attention.backends.flashinfer import (
trtllm_prefill_attn_kvfp8_dequant,
)
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 8, 16, 128
kv_cache, scale = make_contiguous_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
head_size,
)
k_scale = v_scale = scale.clone()
block_tables = torch.tensor([[5], [10], [20]], dtype=torch.int32, device="cuda")
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
kv_cache,
block_tables,
k_scale,
v_scale,
torch.bfloat16,
)
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
@torch.inference_mode()
def test_large_page_indices():
"""Page indices near the top of the buffer stress offset arithmetic."""
from vllm.v1.attention.backends.flashinfer import (
trtllm_prefill_attn_kvfp8_dequant,
)
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 8, 16, 128
large_num_blocks = 32768
kv_cache, scale = make_contiguous_kv_cache(
large_num_blocks,
num_kv_heads,
block_size,
head_size,
)
k_scale = v_scale = scale.clone()
# Use page indices near the top of the buffer
block_tables = torch.tensor(
[[large_num_blocks - 1, large_num_blocks - 2, 1]],
dtype=torch.int32,
device="cuda",
)
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
kv_cache,
block_tables,
k_scale,
v_scale,
torch.bfloat16,
)
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
@torch.inference_mode()
def test_large_block_size():
"""block_size=64 -> HEAD_STRIDE=8192, large tl.arange per thread block."""
from vllm.v1.attention.backends.flashinfer import (
trtllm_prefill_attn_kvfp8_dequant,
)
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 4, 64, 128
kv_cache, scale = make_contiguous_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
head_size,
)
k_scale = v_scale = scale.clone()
block_tables = torch.randint(
1,
NUM_BLOCKS,
(2, 4),
dtype=torch.int32,
device="cuda",
)
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
kv_cache,
block_tables,
k_scale,
v_scale,
torch.bfloat16,
)
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
@torch.inference_mode()
def test_cross_layer_many_layers():
"""
Non-contiguous with 36 layers -- matches real gpt-oss-120b.
Strides are far from contiguous (factor of 36 in the gaps).
"""
from vllm.v1.attention.backends.flashinfer import (
trtllm_prefill_attn_kvfp8_dequant,
)
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 8, 16, 64
num_layers = 36
kv_cache, scale = make_cross_layer_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
head_size,
num_layers=num_layers,
)
k_scale = v_scale = scale.clone()
block_tables = torch.randint(
1,
NUM_BLOCKS,
(4, 6),
dtype=torch.int32,
device="cuda",
)
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
kv_cache,
block_tables,
k_scale,
v_scale,
torch.bfloat16,
)
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)

View File

@@ -280,22 +280,21 @@ def test_rms_norm(
assert torch.allclose(ref_residual, ops_residual)
output = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
if group_size is None:
scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
opcheck(
torch.ops._C.rms_norm_dynamic_per_token_quant,
(output, x, layer.weight, scales, 1e-5, scale_ub, residual),
)
else:
assert hidden_size % group_size[1] == 0
num_groups = hidden_size // group_size[1]
scales = torch.empty(
(num_groups, num_tokens),
device=x.device,
dtype=torch.float32,
).transpose(0, 1)
# TODO(luka/eliza) opcheck is broken?
# Somehow the cloned args are getting mutated in-place,
# which causes the opcheck to fail.
# https://github.com/vllm-project/vllm/issues/36688
return
opcheck(
torch.ops._C.rms_norm_per_block_quant,
(

View File

@@ -1,104 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""E2E tests for online MXFP8 quantization.
Loads a BF16 model with ``--quantization mxfp8`` (online quantization) and
compares log-probabilities against the same model served in BF16 without
quantization. This exercises the full pipeline: config parsing,
``Mxfp8OnlineLinearMethod``, ``Mxfp8OnlineMoEMethod``, weight loading,
online quantization / shuffling, and inference through ``apply_monolithic``.
Layer skipping (``modules_to_not_convert``) is configured in the model's
``config.json`` under ``quantization_config`` and is not tested here.
``example_prompts`` is a pytest fixture (from conftest.py) that loads 8
diverse prompts from ``tests/prompts/example.txt``.
"""
import pytest
from tests.quantization.utils import is_quant_method_supported
from ..utils import check_logprobs_close
# A small MoE model that fits on a single GPU and has both linear + MoE layers.
MOE_MODEL = "Qwen/Qwen3-30B-A3B"
# A small dense model (no MoE) to validate the linear-only path.
DENSE_MODEL = "Qwen/Qwen3-0.6B"
MAX_MODEL_LEN = 1024
MAX_TOKENS = 4
NUM_LOG_PROBS = 8
@pytest.mark.skipif(
not is_quant_method_supported("mxfp8"),
reason="mxfp8 is not supported on this GPU type (requires sm_100+).",
)
@pytest.mark.quant_model
@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"])
def test_mxfp8_logprobs(
vllm_runner,
example_prompts,
model: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Compare BF16 baseline logprobs against online MXFP8-quantized model.
Runs the same model twice -- once in BF16 (baseline) and once with
online MXFP8 quantization -- then checks that the top log-probabilities
are close. Only 4 tokens are generated to keep the test fast while
still catching numerical divergence.
"""
with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", "true")
with vllm_runner(
model,
max_model_len=MAX_MODEL_LEN,
enforce_eager=True,
) as vllm_model:
baseline_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, MAX_TOKENS, NUM_LOG_PROBS
)
with vllm_runner(
model,
max_model_len=MAX_MODEL_LEN,
enforce_eager=True,
quantization="mxfp8",
) as vllm_model:
test_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, MAX_TOKENS, NUM_LOG_PROBS
)
check_logprobs_close(
outputs_0_lst=baseline_outputs,
outputs_1_lst=test_outputs,
name_0="bf16",
name_1="mxfp8",
)
@pytest.mark.skipif(
not is_quant_method_supported("mxfp8"),
reason="mxfp8 is not supported on this GPU type (requires sm_100+).",
)
@pytest.mark.quant_model
@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"])
def test_mxfp8_generation(vllm_runner, model: str) -> None:
"""Smoke test: verify online MXFP8 model generates coherent text."""
prompt = "1 2 3 4 5"
with vllm_runner(
model,
enforce_eager=True,
quantization="mxfp8",
max_model_len=MAX_MODEL_LEN,
) as vllm_model:
output = vllm_model.generate_greedy([prompt], max_tokens=5)
generated = output[0][1]
assert len(generated) > len(prompt), (
f"MXFP8 model produced no new tokens. Output: {generated!r}"
)

View File

@@ -1,378 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from dataclasses import dataclass, field
from types import NoneType
from typing import Any
import pytest
from tests.tool_parsers.utils import run_tool_extraction
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParserManager
@dataclass
class ToolParserTestConfig:
"""Configuration for a tool parser's common tests.
This dataclass contains all the test data and expected results needed
to run the common test suite for a parser. Each parser test file
creates one instance of this config with parser-specific values.
Attributes:
parser_name: Name used with ToolParserManager (e.g., "mistral")
Test data (model outputs):
no_tool_calls_output: Plain text without any tool syntax
single_tool_call_output: One tool call with simple arguments
parallel_tool_calls_output: Multiple tool calls in one response
various_data_types_output: Tool with various data types
empty_arguments_output: Tool call with no parameters
surrounding_text_output: Tool call mixed with regular text
escaped_strings_output: Tool call with escaped chars
malformed_input_outputs: List of invalid inputs
Expected results:
single_tool_call_expected_name: Expected function name
single_tool_call_expected_args: Expected arguments dict
parallel_tool_calls_count: Number of tools in parallel test
parallel_tool_calls_names: Function names in order
single_tool_call_expected_content: Content field when tool called
parallel_tool_calls_expected_content: Content for parallel test
xfail markers:
xfail_streaming: Mapping test name to xfail reason (streaming only)
xfail_nonstreaming: Mapping test name to xfail reason (non-streaming)
Special flags:
allow_empty_or_json_empty_args: True if "" or "{}" both valid for empty args
supports_typed_arguments: True if the parser supports typed function arguments
"""
# Parser identification
parser_name: str
# Test data - model outputs for each common test
no_tool_calls_output: str
single_tool_call_output: str
parallel_tool_calls_output: str
various_data_types_output: str
empty_arguments_output: str
surrounding_text_output: str
escaped_strings_output: str
malformed_input_outputs: list[str]
# Expected results for specific tests (optional overrides)
single_tool_call_expected_name: str = "get_weather"
single_tool_call_expected_args: dict[str, Any] = field(
default_factory=lambda: {"city": "Tokyo"}
)
parallel_tool_calls_count: int = 2
parallel_tool_calls_names: list[str] = field(
default_factory=lambda: ["get_weather", "get_time"]
)
# xfail configuration - maps test name to xfail reason
xfail_streaming: dict[str, str] = field(default_factory=dict)
xfail_nonstreaming: dict[str, str] = field(default_factory=dict)
# Content expectations (some parsers strip content, others don't)
single_tool_call_expected_content: str | None = None
parallel_tool_calls_expected_content: str | None = None
# Special assertions for edge cases
allow_empty_or_json_empty_args: bool = True # "{}" or "" for empty args
supports_typed_arguments: bool = True
class ToolParserTests:
"""Mixin class providing common test suite for tool parsers.
To use this mixin in a parser test file:
1. Create a test_config fixture that returns a ToolParserTestConfig instance
2. Inherit from this class
3. Add parser-specific tests as additional methods
Example:
class TestMistralToolParser(ToolParserTests):
@pytest.fixture
def test_config(self) -> ToolParserTestConfig:
return ToolParserTestConfig(
parser_name="mistral",
no_tool_calls_output="Plain text...",
# ... other config ...
)
# Parser-specific tests
def test_mistral_specific_feature(self, tool_parser):
# Custom test logic
pass
"""
@pytest.fixture
def test_config(self) -> ToolParserTestConfig:
"""Override this to provide parser-specific configuration."""
raise NotImplementedError(
"Subclass must provide test_config fixture returning ToolParserTestConfig"
)
@pytest.fixture
def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
"""Override this to provide parser-specific tokenizer."""
return default_tokenizer
@pytest.fixture
def tool_parser(self, test_config: ToolParserTestConfig, tokenizer: TokenizerLike):
return ToolParserManager.get_tool_parser(test_config.parser_name)(tokenizer)
@pytest.fixture(params=[True, False])
def streaming(self, request: pytest.FixtureRequest) -> bool:
return request.param
def test_no_tool_calls(
self,
request: pytest.FixtureRequest,
tool_parser: Any,
test_config: ToolParserTestConfig,
streaming: bool,
):
"""Verify parser handles plain text without tool syntax."""
# Apply xfail markers if configured
test_name = "test_no_tool_calls"
self.apply_xfail_mark(request, test_config, test_name, streaming)
content, tool_calls = run_tool_extraction(
tool_parser, test_config.no_tool_calls_output, streaming=streaming
)
assert content == test_config.no_tool_calls_output, (
f"Expected content to match input, got {content}"
)
assert len(tool_calls) == 0, f"Expected no tool calls, got {len(tool_calls)}"
def test_single_tool_call_simple_args(
self,
request: pytest.FixtureRequest,
tool_parser: Any,
test_config: ToolParserTestConfig,
streaming: bool,
):
"""Verify parser extracts one tool with simple arguments."""
# Apply xfail markers if configured
test_name = "test_single_tool_call_simple_args"
self.apply_xfail_mark(request, test_config, test_name, streaming)
content, tool_calls = run_tool_extraction(
tool_parser, test_config.single_tool_call_output, streaming=streaming
)
# Content check (some parsers strip it)
if test_config.single_tool_call_expected_content is not None:
assert content == test_config.single_tool_call_expected_content
assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
assert tool_calls[0].type == "function"
assert tool_calls[0].function.name == test_config.single_tool_call_expected_name
args = json.loads(tool_calls[0].function.arguments)
for key, value in test_config.single_tool_call_expected_args.items():
assert args.get(key) == value, (
f"Expected {key}={value}, got {args.get(key)}"
)
def test_parallel_tool_calls(
self,
request: pytest.FixtureRequest,
tool_parser: Any,
test_config: ToolParserTestConfig,
streaming: bool,
):
"""Verify parser handles multiple tools in one response."""
# Apply xfail markers if configured
test_name = "test_parallel_tool_calls"
self.apply_xfail_mark(request, test_config, test_name, streaming)
content, tool_calls = run_tool_extraction(
tool_parser,
test_config.parallel_tool_calls_output,
streaming=streaming,
)
assert len(tool_calls) == test_config.parallel_tool_calls_count, (
f"Expected {test_config.parallel_tool_calls_count} "
f"tool calls, got {len(tool_calls)}"
)
# Verify tool names match expected
for i, expected_name in enumerate(test_config.parallel_tool_calls_names):
assert tool_calls[i].type == "function"
assert tool_calls[i].function.name == expected_name
# Verify unique IDs
ids = [tc.id for tc in tool_calls]
assert len(ids) == len(set(ids)), "Tool call IDs should be unique"
def test_various_data_types(
self,
request: pytest.FixtureRequest,
tool_parser: Any,
test_config: ToolParserTestConfig,
streaming: bool,
):
"""Verify parser handles all JSON types in arguments."""
# Apply xfail markers if configured
test_name = "test_various_data_types"
self.apply_xfail_mark(request, test_config, test_name, streaming)
content, tool_calls = run_tool_extraction(
tool_parser,
test_config.various_data_types_output,
streaming=streaming,
)
assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
args = json.loads(tool_calls[0].function.arguments)
# Verify all expected fields present
required_fields_types = {
"string_field": str,
"int_field": int,
"float_field": float,
"bool_field": bool,
"null_field": NoneType,
"array_field": list,
"object_field": dict,
}
for required_field, expected_type in required_fields_types.items():
assert required_field in args, (
f"Expected field '{required_field}' in arguments"
)
if test_config.supports_typed_arguments:
found_type = type(args[required_field])
assert found_type is expected_type, (
f"Expected field '{required_field}' to have type {expected_type}, "
f"got {found_type}"
)
def test_empty_arguments(
self,
request: pytest.FixtureRequest,
tool_parser: Any,
test_config: ToolParserTestConfig,
streaming: bool,
):
"""Verify parser handles parameterless tool calls."""
# Apply xfail markers if configured
test_name = "test_empty_arguments"
self.apply_xfail_mark(request, test_config, test_name, streaming)
content, tool_calls = run_tool_extraction(
tool_parser, test_config.empty_arguments_output, streaming=streaming
)
assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
args = tool_calls[0].function.arguments
if test_config.allow_empty_or_json_empty_args:
assert args in ["{}", ""], f"Expected empty args, got {args}"
else:
assert args == "{}", f"Expected {{}}, got {args}"
def test_surrounding_text(
self,
request: pytest.FixtureRequest,
tool_parser: Any,
test_config: ToolParserTestConfig,
streaming: bool,
):
"""Verify parser extracts tools from mixed content."""
# Apply xfail markers if configured
test_name = "test_surrounding_text"
self.apply_xfail_mark(request, test_config, test_name, streaming)
content, tool_calls = run_tool_extraction(
tool_parser, test_config.surrounding_text_output, streaming=streaming
)
assert len(tool_calls) >= 1, (
f"Expected at least 1 tool call, got {len(tool_calls)}"
)
def test_escaped_strings(
self,
request: pytest.FixtureRequest,
tool_parser: Any,
test_config: ToolParserTestConfig,
streaming: bool,
):
"""Verify parser handles escaped characters in arguments."""
# Apply xfail markers if configured
test_name = "test_escaped_strings"
self.apply_xfail_mark(request, test_config, test_name, streaming)
content, tool_calls = run_tool_extraction(
tool_parser, test_config.escaped_strings_output, streaming=streaming
)
assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
args = json.loads(tool_calls[0].function.arguments)
# At minimum, verify we can parse and have expected fields
# Exact escaping behavior varies by parser
assert len(args) > 0, "Expected some arguments with escaped strings"
def test_malformed_input(
self,
request: pytest.FixtureRequest,
tool_parser: Any,
test_config: ToolParserTestConfig,
streaming: bool,
):
"""Verify parser gracefully handles invalid syntax."""
# Apply xfail markers if configured
test_name = "test_malformed_input"
self.apply_xfail_mark(request, test_config, test_name, streaming)
for malformed_input in test_config.malformed_input_outputs:
# Should not raise exception
content, tool_calls = run_tool_extraction(
tool_parser, malformed_input, streaming=streaming
)
# Parser should handle gracefully (exact behavior varies)
def test_streaming_reconstruction(
self,
request: pytest.FixtureRequest,
tool_parser: Any,
test_config: ToolParserTestConfig,
):
"""Verify streaming produces same result as non-streaming."""
test_name = "test_streaming_reconstruction"
self.apply_xfail_mark(request, test_config, test_name, True)
test_output = test_config.single_tool_call_output
# Non-streaming result
content_non, tools_non = run_tool_extraction(
tool_parser, test_output, streaming=False
)
# Streaming result
content_stream, tools_stream = run_tool_extraction(
tool_parser, test_output, streaming=True
)
# Compare results
assert content_non == content_stream, "Content should match between modes"
assert len(tools_non) == len(tools_stream), "Tool count should match"
if len(tools_non) > 0:
assert tools_non[0].function.name == tools_stream[0].function.name
assert tools_non[0].function.arguments == tools_stream[0].function.arguments
def apply_xfail_mark(
self,
request: pytest.FixtureRequest,
test_config: ToolParserTestConfig,
test_name: str,
streaming: bool,
) -> None:
"""Attach an xfail marker when the config lists this test for this mode."""
reason = None
if streaming and test_name in test_config.xfail_streaming:
reason = test_config.xfail_streaming[test_name]
elif not streaming and test_name in test_config.xfail_nonstreaming:
reason = test_config.xfail_nonstreaming[test_name]
if reason is not None:
mark = pytest.mark.xfail(reason=reason, strict=True)
request.node.add_marker(mark)

View File

@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from vllm.tokenizers import TokenizerLike
@pytest.fixture(scope="module")
def default_tokenizer() -> TokenizerLike:
return AutoTokenizer.from_pretrained("gpt2")

View File

@@ -1,92 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
class TestDeepSeekV3ToolParser(ToolParserTests):
@pytest.fixture(scope="class")
def tokenizer(self) -> TokenizerLike:
return get_tokenizer("deepseek-ai/DeepSeek-V3")
@pytest.fixture
def test_config(self) -> ToolParserTestConfig:
return ToolParserTestConfig(
parser_name="deepseek_v3",
# Test data
no_tool_calls_output=(
"How can I help you today? I can check weather for you."
),
single_tool_call_output="""<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
```json
{"city": "Tokyo", "unit": "celsius"}
```<tool▁call▁end><tool▁calls▁end>""",
parallel_tool_calls_output="""<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
```json
{"city": "Tokyo", "unit": "celsius"}
```<tool▁call▁end><tool▁call▁begin>function<tool▁sep>search_hotels
```json
{"location": "Tokyo", "check_in": "2025-01-15"}
```<tool▁call▁end><tool▁calls▁end>""",
various_data_types_output=(
"""<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>test_function
```json
"""
"""{"string_field": "hello", "int_field": 42, "float_field": 3.14, """
""""bool_field": true, "null_field": null, """
""""array_field": ["a", "b", "c"], """
""""object_field": {"nested": "value"}, """
""""empty_array": [], "empty_object": {}}
```<tool▁call▁end><tool▁calls▁end>"""
),
empty_arguments_output="""<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_current_time
```json
{}
```<tool▁call▁end><tool▁calls▁end>""",
surrounding_text_output=(
"""Let me check the weather for you."""
"""<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
```json
{"city": "Paris"}
```<tool▁call▁end><tool▁calls▁end>"""
),
escaped_strings_output=(
"""<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>send_message
```json
"""
"""{"text": "He said \\"hello\\"", "path": "C:\\\\Users\\\\file", """
""""newline": "line1\\nline2"}
```<tool▁call▁end><tool▁calls▁end>"""
),
malformed_input_outputs=[
"""<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
```json
{"city": "Tokyo"
```<tool▁call▁end><tool▁calls▁end>""",
"""<tool▁calls▁begin>function<tool▁sep>get_weather
```json
{"city": "Tokyo"}
```<tool▁calls▁end>""",
],
# Expected results
single_tool_call_expected_name="get_weather",
single_tool_call_expected_args={"city": "Tokyo", "unit": "celsius"},
single_tool_call_expected_content=None,
parallel_tool_calls_count=2,
parallel_tool_calls_names=["get_weather", "search_hotels"],
# xfail markers
xfail_streaming={},
xfail_nonstreaming={
"test_malformed_input": (
"Parser sets tools_called=True even when tool_calls is "
"empty (detects start token but fails to parse)"
),
},
)

View File

@@ -1,76 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
class TestGranite20bFcToolParser(ToolParserTests):
@pytest.fixture
def test_config(self) -> ToolParserTestConfig:
return ToolParserTestConfig(
parser_name="granite-20b-fc",
# Test data
no_tool_calls_output="This is a regular response without any tool calls.",
single_tool_call_output=(
'<function_call> {"name": "get_weather", '
'"arguments": {"city": "Tokyo"}}'
),
parallel_tool_calls_output=(
'<function_call> {"name": "get_weather", '
'"arguments": {"city": "Tokyo"}}\n'
'<function_call> {"name": "get_time", '
'"arguments": {"timezone": "Asia/Tokyo"}}'
),
various_data_types_output="""<function_call> {
"name": "test_function",
"arguments": {
"string_field": "hello",
"int_field": 42,
"float_field": 3.14,
"bool_field": true,
"null_field": null,
"array_field": ["a", "b", "c"],
"object_field": {"nested": "value"},
"empty_array": [],
"empty_object": {}
}
}""",
empty_arguments_output=(
'<function_call> {"name": "refresh", "arguments": {}}'
),
surrounding_text_output="""Let me check the weather for you.
<function_call> {"name": "get_weather", "arguments": {"city": "Tokyo"}}""",
escaped_strings_output="""<function_call> {
"name": "test_function",
"arguments": {
"quoted": "He said \\"hello\\"",
"path": "C:\\\\Users\\\\file.txt",
"newline": "line1\\nline2",
"unicode": "emoji: 🎉"
}
}""",
malformed_input_outputs=[
'<function_call> {"name": "func", "arguments": {',
'<function_call> [{"name": "func", "arguments": {}}]',
'{"name": "func", "arguments": {}}',
'<function_call> {"name": 123}',
],
# Expected results
single_tool_call_expected_name="get_weather",
single_tool_call_expected_args={"city": "Tokyo"},
single_tool_call_expected_content=None,
parallel_tool_calls_count=2,
parallel_tool_calls_names=["get_weather", "get_time"],
# xfail markers
xfail_streaming={
"test_surrounding_text": (
"Granite 20B FC streaming requires <function_call> at start"
),
},
xfail_nonstreaming={},
)

View File

@@ -1,118 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from tests.tool_parsers.utils import run_tool_extraction
class TestGraniteToolParser(ToolParserTests):
@pytest.fixture
def test_config(self) -> ToolParserTestConfig:
return ToolParserTestConfig(
parser_name="granite",
# Test data
no_tool_calls_output="This is a regular response without any tool calls.",
single_tool_call_output=(
'<|tool_call|> [{"name": "get_weather", '
'"arguments": {"city": "Tokyo"}}]'
),
parallel_tool_calls_output="""<|tool_call|> [
{"name": "get_weather", "arguments": {"city": "Tokyo"}},
{"name": "get_time", "arguments": {"timezone": "Asia/Tokyo"}}
]""",
various_data_types_output="""<tool_call> [{
"name": "test_function",
"arguments": {
"string_field": "hello",
"int_field": 42,
"float_field": 3.14,
"bool_field": true,
"null_field": null,
"array_field": ["a", "b", "c"],
"object_field": {"nested": "value"},
"empty_array": [],
"empty_object": {}
}
}]""",
empty_arguments_output=(
'<|tool_call|> [{"name": "refresh", "arguments": {}}]'
),
surrounding_text_output="""Let me check the weather for you.
<|tool_call|> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]
I'll get that information.""",
escaped_strings_output="""<tool_call> [{
"name": "test_function",
"arguments": {
"quoted": "He said \\"hello\\"",
"path": "C:\\\\Users\\\\file.txt",
"newline": "line1\\nline2",
"unicode": "emoji: 🎉"
}
}]""",
malformed_input_outputs=[
'<|tool_call|> [{"name": "func", "arguments": {',
'<|tool_call|> {"name": "func", "arguments": {}}', # Not an array
'[{"name": "func", "arguments": "not a dict"}]',
'Some text [{"name": "func"}]', # JSON but not tool call format
],
# Expected results
single_tool_call_expected_name="get_weather",
single_tool_call_expected_args={"city": "Tokyo"},
# Granite strips content when tool calls present
single_tool_call_expected_content=None,
parallel_tool_calls_count=2,
parallel_tool_calls_names=["get_weather", "get_time"],
# xfail markers
xfail_streaming={
"test_malformed_input": (
"Streaming mode incorrectly creates tool call from malformed JSON"
),
"test_surrounding_text": (
"Parser doesn't handle surrounding text correctly in streaming"
),
"test_streaming_reconstruction": (
"Streaming mode doesn't strip <|tool_call|> marker from content"
),
},
xfail_nonstreaming={
"test_surrounding_text": (
"Parser doesn't handle surrounding text correctly in non-streaming"
),
},
)
# Granite-Specific Tests
@pytest.mark.parametrize("streaming", [True, False])
def test_granite_token_prefix_format(self, tool_parser, streaming):
"""Verify parser handles Granite 3.0 <|tool_call|> token format."""
single_tool_call_token = (
'<|tool_call|> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
)
content, tool_calls = run_tool_extraction(
tool_parser, single_tool_call_token, streaming=streaming
)
assert len(tool_calls) == 1, (
f"Expected 1 tool call from token format, got {len(tool_calls)}"
)
assert tool_calls[0].function.name == "get_weather"
@pytest.mark.parametrize("streaming", [True, False])
def test_granite_string_prefix_format(self, tool_parser, streaming):
"""Verify parser handles Granite 3.1 <tool_call> string format."""
single_tool_call_string = (
'<tool_call> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
)
content, tool_calls = run_tool_extraction(
tool_parser, single_tool_call_string, streaming=streaming
)
assert len(tool_calls) == 1, (
f"Expected 1 tool call from string format, got {len(tool_calls)}"
)
assert tool_calls[0].function.name == "get_weather"

View File

@@ -1,122 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from vllm.tokenizers import TokenizerLike
class TestInternLM2ToolParser(ToolParserTests):
@pytest.fixture
def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
"""Add some internlm2 specific tokens to the default vocab."""
tokenizer_vocab = default_tokenizer.get_vocab()
default_tokenizer.get_vocab = MagicMock()
tokenizer_vocab.update(
{
"<|action_start|>": 92540,
"<|plugin|>": 92541,
"<|action_end|>": 92542,
}
)
default_tokenizer.get_vocab.return_value = tokenizer_vocab
return default_tokenizer
@pytest.fixture
def test_config(self) -> ToolParserTestConfig:
return ToolParserTestConfig(
parser_name="internlm",
# Test data
no_tool_calls_output="This is a regular response without any tool calls.",
single_tool_call_output=(
'<|action_start|><|plugin|>{"name": "get_weather", '
'"parameters": {"city": "Tokyo"}}<|action_end|>'
),
# InternLM2 doesn't support parallel calls
parallel_tool_calls_output=(
'<|action_start|><|plugin|>{"name": "get_weather", '
'"parameters": {"city": "Tokyo"}}<|action_end|>'
),
various_data_types_output="""<|action_start|><|plugin|>{
"name": "test_function",
"parameters": {
"string_field": "hello",
"int_field": 42,
"float_field": 3.14,
"bool_field": true,
"null_field": null,
"array_field": ["a", "b", "c"],
"object_field": {"nested": "value"},
"empty_array": [],
"empty_object": {}
}
}<|action_end|>""",
empty_arguments_output=(
'<|action_start|><|plugin|>{"name": "refresh", '
'"parameters": {}}<|action_end|>'
),
surrounding_text_output=(
"Let me check the weather for you. "
'<|action_start|><|plugin|>{"name": "get_weather", '
'"parameters": {"city": "Tokyo"}}<|action_end|>'
),
escaped_strings_output="""<|action_start|><|plugin|>{
"name": "test_function",
"parameters": {
"quoted": "He said \\"hello\\"",
"path": "C:\\\\Users\\\\file.txt",
"newline": "line1\\nline2",
"unicode": "emoji: 🎉"
}
}<|action_end|>""",
malformed_input_outputs=[
'<|action_start|><|plugin|>{"name": "func", "parameters": {',
(
'<|action_start|><|plugin|>{"name": "func", '
'"parameters": "not a dict"}<|action_end|>'
),
"<|action_start|><|plugin|>not json<|action_end|>",
"<|action_start|><|plugin|>",
'<|action_start|>{"name": "func"}',
],
# Expected results
single_tool_call_expected_name="get_weather",
single_tool_call_expected_args={"city": "Tokyo"},
single_tool_call_expected_content=None,
parallel_tool_calls_count=1, # InternLM2 only supports single tool calls
parallel_tool_calls_names=["get_weather"],
# Parser-specific settings
allow_empty_or_json_empty_args=True,
# xfail markers
xfail_streaming={
"test_single_tool_call_simple_args": (
"InternLM2 streaming not fully implemented"
),
"test_parallel_tool_calls": (
"InternLM2 streaming not fully implemented"
),
"test_various_data_types": (
"InternLM2 streaming not fully implemented"
),
"test_empty_arguments": ("InternLM2 streaming not fully implemented"),
"test_surrounding_text": ("InternLM2 streaming not fully implemented"),
"test_escaped_strings": ("InternLM2 streaming not fully implemented"),
"test_streaming_reconstruction": (
"InternLM2 streaming parser returns '<|action_start|' as "
"content instead of None - streaming/non-streaming inconsistency"
),
},
xfail_nonstreaming={
"test_malformed_input": (
"InternLM2 parser raises JSONDecodeError on malformed JSON "
"instead of gracefully handling it"
),
},
)

View File

@@ -1,101 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from vllm.tokenizers import TokenizerLike
class TestLongCatToolParser(ToolParserTests):
@pytest.fixture
def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
"""Add some longcat specific tokens to the default vocab."""
tokenizer = default_tokenizer
tokenizer_vocab = tokenizer.get_vocab()
tokenizer.get_vocab = MagicMock()
tokenizer_vocab.update(
{
"<longcat_tool_call>": 32000,
"</longcat_tool_call>": 32001,
}
)
tokenizer.get_vocab.return_value = tokenizer_vocab
return tokenizer
@pytest.fixture
def test_config(self) -> ToolParserTestConfig:
return ToolParserTestConfig(
parser_name="longcat",
# Test data
no_tool_calls_output="This is a regular response without any tool calls.",
single_tool_call_output=(
'<longcat_tool_call>{"name": "get_weather", '
'"arguments": {"city": "Tokyo"}}</longcat_tool_call>'
),
parallel_tool_calls_output=(
'<longcat_tool_call>{"name": "get_weather", '
'"arguments": {"city": "Tokyo"}}</longcat_tool_call>\n'
'<longcat_tool_call>{"name": "get_time", '
'"arguments": {"timezone": "Asia/Tokyo"}}</longcat_tool_call>'
),
various_data_types_output="""<longcat_tool_call>{
"name": "test_function",
"arguments": {
"string_field": "hello",
"int_field": 42,
"float_field": 3.14,
"bool_field": true,
"null_field": null,
"array_field": ["a", "b", "c"],
"object_field": {"nested": "value"},
"empty_array": [],
"empty_object": {}
}
}</longcat_tool_call>""",
empty_arguments_output=(
'<longcat_tool_call>{"name": "refresh", "arguments": {}}'
"</longcat_tool_call>"
),
surrounding_text_output=(
"Let me check the weather for you.\n"
'<longcat_tool_call>{"name": "get_weather", '
'"arguments": {"city": "Tokyo"}}</longcat_tool_call>\n'
"Here is the result."
),
escaped_strings_output="""<longcat_tool_call>{
"name": "test_function",
"arguments": {
"quoted": "He said \\"hello\\"",
"path": "C:\\\\Users\\\\file.txt",
"newline": "line1\\nline2",
"unicode": "emoji: 🎉"
}
}</longcat_tool_call>""",
malformed_input_outputs=[
'<longcat_tool_call>{"name": "func", "arguments": {',
(
'<longcat_tool_call>{"name": "func", '
'"arguments": "not a dict"}</longcat_tool_call>'
),
"Some text with <longcat_tool_call>invalid json",
],
# Expected results
single_tool_call_expected_name="get_weather",
single_tool_call_expected_args={"city": "Tokyo"},
single_tool_call_expected_content=None,
parallel_tool_calls_count=2,
parallel_tool_calls_names=["get_weather", "get_time"],
# xfail markers
xfail_streaming={
"test_malformed_input": "Streaming has complex buffering behavior",
},
xfail_nonstreaming={},
# Configuration
allow_empty_or_json_empty_args=True,
)

View File

@@ -1,110 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from vllm.tokenizers import TokenizerLike
class TestPhi4MiniToolParser(ToolParserTests):
@pytest.fixture
def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
"""Add some phi4mini specific tokens to the default vocab."""
tokenizer = default_tokenizer
tokenizer_vocab = tokenizer.get_vocab()
tokenizer.get_vocab = MagicMock()
tokenizer_vocab.update(
{
"functools": 32000,
}
)
tokenizer.get_vocab.return_value = tokenizer_vocab
return tokenizer
@pytest.fixture
def test_config(self) -> ToolParserTestConfig:
return ToolParserTestConfig(
parser_name="phi4_mini_json",
# Test data
no_tool_calls_output="This is a regular response without any tool calls.",
single_tool_call_output=(
'functools[{"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
),
parallel_tool_calls_output="""functools[
{"name": "get_weather", "arguments": {"city": "Tokyo"}},
{"name": "get_time", "arguments": {"timezone": "Asia/Tokyo"}}
]""",
various_data_types_output="""functools[{
"name": "test_function",
"arguments": {
"string_field": "hello",
"int_field": 42,
"float_field": 3.14,
"bool_field": true,
"null_field": null,
"array_field": ["a", "b", "c"],
"object_field": {"nested": "value"},
"empty_array": [],
"empty_object": {}
}
}]""",
empty_arguments_output='functools[{"name": "refresh", "arguments": {}}]',
surrounding_text_output="""Let me check the weather for you.
functools[{"name": "get_weather", "arguments": {"city": "Tokyo"}}]
Would you like to know more?""",
escaped_strings_output="""functools[{
"name": "test_function",
"arguments": {
"quoted": "He said \\"hello\\"",
"path": "C:\\\\Users\\\\file.txt",
"newline": "line1\\nline2",
"unicode": "emoji: 🎉"
}
}]""",
malformed_input_outputs=[
'functools[{"name": "func", "arguments": {',
'functools[{"name": "func", "arguments": "not a dict"}]',
'functools{"name": "func"}', # Missing brackets
'functools[{"name": "func"}]', # Missing arguments/parameters
"functools[] This is just text", # Empty functools
"functools[ This is just text ]", # functools with invalid JSON
],
# Expected results
single_tool_call_expected_name="get_weather",
single_tool_call_expected_args={"city": "Tokyo"},
# Phi-4 Mini strips content when tool calls present
single_tool_call_expected_content=None,
parallel_tool_calls_count=2,
parallel_tool_calls_names=["get_weather", "get_time"],
parallel_tool_calls_expected_content=None,
# xfail markers
xfail_streaming={
"test_no_tool_calls": "Phi4 Mini streaming not implemented",
"test_single_tool_call_simple_args": (
"Phi4 Mini streaming not implemented"
),
"test_parallel_tool_calls": "Phi4 Mini streaming not implemented",
"test_various_data_types": "Phi4 Mini streaming not implemented",
"test_empty_arguments": "Phi4 Mini streaming not implemented",
"test_surrounding_text": "Phi4 Mini streaming not implemented",
"test_escaped_strings": "Phi4 Mini streaming not implemented",
"test_streaming_reconstruction": "Phi4 Mini streaming not implemented",
},
xfail_nonstreaming={
"test_various_data_types": (
"Phi4MiniJsonToolParser regex has nesting limitations "
"with nested objects"
),
"test_malformed_input": (
"Phi4MiniJsonToolParser incorrectly sets "
"tools_called=True on empty array"
),
},
)

View File

@@ -1,75 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
class TestQwen3xmlToolParser(ToolParserTests):
@pytest.fixture
def test_config(self) -> ToolParserTestConfig:
return ToolParserTestConfig(
parser_name="qwen3_xml",
# Test data
no_tool_calls_output="This is a regular response without any tool calls.",
single_tool_call_output="<tool_call>\n<function=get_weather>\n<parameter=city>Tokyo</parameter>\n</function>\n</tool_call>",
parallel_tool_calls_output="<tool_call>\n<function=get_weather>\n<parameter=city>Tokyo</parameter>\n</function>\n</tool_call><tool_call>\n<function=get_time>\n<parameter=timezone>Asia/Tokyo</parameter>\n</function>\n</tool_call>",
various_data_types_output=(
"<tool_call>\n<function=test_function>\n"
"<parameter=string_field>hello</parameter>\n"
"<parameter=int_field>42</parameter>\n"
"<parameter=float_field>3.14</parameter>\n"
"<parameter=bool_field>true</parameter>\n"
"<parameter=null_field>null</parameter>\n"
'<parameter=array_field>["a", "b", "c"]</parameter>\n'
'<parameter=object_field>{"nested": "value"}</parameter>\n'
"</function>\n</tool_call>"
),
empty_arguments_output="<tool_call>\n<function=refresh>\n</function>\n</tool_call>",
surrounding_text_output=(
"Let me check the weather for you.\n\n"
"<tool_call>\n<function=get_weather>\n"
"<parameter=city>Tokyo</parameter>\n"
"</function>\n</tool_call>\n\n"
"I will get that information."
),
escaped_strings_output=(
"<tool_call>\n<function=test_function>\n"
'<parameter=quoted>He said "hello"</parameter>\n'
"<parameter=path>C:\\Users\\file.txt</parameter>\n"
"<parameter=newline>line1\nline2</parameter>\n"
"</function>\n</tool_call>"
),
malformed_input_outputs=[
"<tool_call><function=func>",
"<tool_call><function=></function></tool_call>",
],
# Expected results
single_tool_call_expected_name="get_weather",
single_tool_call_expected_args={"city": "Tokyo"},
parallel_tool_calls_count=2,
parallel_tool_calls_names=["get_weather", "get_time"],
# xfail markers - Qwen3XML has systematic streaming issues
xfail_streaming={
"test_single_tool_call_simple_args": (
"Qwen3XML streaming has systematic issues"
),
"test_parallel_tool_calls": "Qwen3XML streaming has systematic issues",
"test_various_data_types": "Qwen3XML streaming has systematic issues",
"test_empty_arguments": "Qwen3XML streaming has systematic issues",
"test_surrounding_text": "Qwen3XML streaming has systematic issues",
"test_escaped_strings": "Qwen3XML streaming has systematic issues",
"test_malformed_input": (
"Qwen3XML parser is lenient with malformed input"
),
"test_streaming_reconstruction": (
"Qwen3XML streaming reconstruction has known issues"
),
},
supports_typed_arguments=False,
)

View File

@@ -1,112 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
class TestStep3ToolParser(ToolParserTests):
@pytest.fixture(scope="class")
def tokenizer(self) -> TokenizerLike:
return get_tokenizer("stepfun-ai/step3")
@pytest.fixture
def test_config(self) -> ToolParserTestConfig:
return ToolParserTestConfig(
parser_name="step3",
# Test data
no_tool_calls_output="This is a regular response without any tool calls.",
single_tool_call_output=(
"<tool_calls_begin><tool_call_begin>"
'<steptml:invoke name="get_weather">'
'<steptml:parameter name="city">Tokyo</steptml:parameter>'
"</steptml:invoke><tool_call_end><tool_calls_end>"
),
parallel_tool_calls_output=(
"<tool_calls_begin><tool_call_begin>"
'<steptml:invoke name="get_weather">'
'<steptml:parameter name="city">Tokyo</steptml:parameter>'
"</steptml:invoke><tool_call_end><tool_sep>"
'<tool_call_begin><steptml:invoke name="get_time">'
'<steptml:parameter name="timezone">Asia/Tokyo</steptml:parameter>'
"</steptml:invoke><tool_call_end><tool_calls_end>"
),
various_data_types_output=(
"<tool_calls_begin><tool_call_begin>"
'<steptml:invoke name="test_function">'
'<steptml:parameter name="string_field">hello</steptml:parameter>'
'<steptml:parameter name="int_field">42</steptml:parameter>'
'<steptml:parameter name="float_field">3.14</steptml:parameter>'
'<steptml:parameter name="bool_field">true</steptml:parameter>'
'<steptml:parameter name="null_field">null</steptml:parameter>'
'<steptml:parameter name="array_field">'
'["a", "b", "c"]</steptml:parameter>'
'<steptml:parameter name="object_field">'
'{"nested": "value"}</steptml:parameter>'
"</steptml:invoke><tool_call_end><tool_calls_end>"
),
empty_arguments_output=(
"<tool_calls_begin><tool_call_begin>"
'<steptml:invoke name="refresh"></steptml:invoke>'
"<tool_call_end><tool_calls_end>"
),
surrounding_text_output=(
"Let me check the weather for you.\n\n"
"<tool_calls_begin><tool_call_begin>"
'<steptml:invoke name="get_weather">'
'<steptml:parameter name="city">Tokyo</steptml:parameter>'
"</steptml:invoke><tool_call_end><tool_calls_end>\n\n"
"I'll get that information."
),
escaped_strings_output=(
"<tool_calls_begin><tool_call_begin>"
'<steptml:invoke name="test_function">'
'<steptml:parameter name="quoted">He said "hello"</steptml:parameter>'
'<steptml:parameter name="path">C:\\Users\\file.txt</steptml:parameter>'
'<steptml:parameter name="newline">line1\nline2</steptml:parameter>'
"</steptml:invoke><tool_call_end><tool_calls_end>"
),
malformed_input_outputs=[
(
"<tool_calls_begin><tool_call_begin>"
'<steptml:invoke name="func">'
),
(
'<tool_call_begin><steptml:invoke name="func">'
"</steptml:invoke><tool_call_end>"
),
],
# Expected results
single_tool_call_expected_name="get_weather",
single_tool_call_expected_args={"city": "Tokyo"},
parallel_tool_calls_count=2,
parallel_tool_calls_names=["get_weather", "get_time"],
# xfail markers
xfail_nonstreaming={
"test_single_tool_call_simple_args": (
"Step3 parser non-streaming has bugs"
),
"test_parallel_tool_calls": ("Step3 parser non-streaming has bugs"),
"test_various_data_types": "Step3 parser non-streaming has bugs",
"test_empty_arguments": "Step3 parser non-streaming has bugs",
"test_surrounding_text": "Step3 parser non-streaming has bugs",
"test_escaped_strings": "Step3 parser non-streaming has bugs",
},
xfail_streaming={
"test_parallel_tool_calls": (
"Step3 parser has significant bugs in both streaming "
"and non-streaming"
),
"test_streaming_reconstruction": (
"Step3 parser non-streaming has bugs, so streaming "
"doesn't match non-streaming"
),
},
supports_typed_arguments=False,
)

View File

@@ -6,7 +6,6 @@ import pytest
from .utils import (
MESSAGES_WITHOUT_TOOLS,
SEED,
WEATHER_TOOL,
ServerConfig,
ensure_system_prompt,
@@ -28,7 +27,6 @@ async def test_chat_completion_without_tools(
max_completion_tokens=150,
model=model_name,
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
@@ -49,7 +47,6 @@ async def test_chat_completion_without_tools(
max_completion_tokens=150,
model=model_name,
logprobs=False,
seed=SEED,
stream=True,
)
chunks: list[str] = []
@@ -100,7 +97,6 @@ async def test_chat_completion_with_tools(
model=model_name,
tools=[WEATHER_TOOL],
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
@@ -122,7 +118,6 @@ async def test_chat_completion_with_tools(
model=model_name,
logprobs=False,
tools=[WEATHER_TOOL],
seed=SEED,
stream=True,
)

View File

@@ -10,7 +10,6 @@ from .utils import (
MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
SEARCH_TOOL,
SEED,
WEATHER_TOOL,
ServerConfig,
)
@@ -40,7 +39,6 @@ async def test_parallel_tool_calls(
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
@@ -78,7 +76,6 @@ async def test_parallel_tool_calls(
max_completion_tokens=200,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
stream=True,
)
@@ -169,7 +166,6 @@ async def test_parallel_tool_calls_with_results(
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
@@ -188,7 +184,6 @@ async def test_parallel_tool_calls_with_results(
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
stream=True,
)
@@ -234,7 +229,6 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
parallel_tool_calls=False,
)
@@ -253,7 +247,6 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
max_completion_tokens=200,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
parallel_tool_calls=False,
stream=True,
)

View File

@@ -10,7 +10,6 @@ from .utils import (
MESSAGES_ASKING_FOR_TOOLS,
MESSAGES_WITH_TOOL_RESPONSE,
SEARCH_TOOL,
SEED,
WEATHER_TOOL,
)
@@ -28,7 +27,6 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
@@ -73,7 +71,6 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
max_completion_tokens=100,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
stream=True,
)
@@ -157,7 +154,6 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
@@ -175,7 +171,6 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
stream=True,
)

View File

@@ -42,8 +42,6 @@ def ensure_system_prompt(
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
SEED = 42
ARGS: list[str] = [
"--enable-auto-tool-choice",
"--max-model-len",

View File

@@ -43,7 +43,6 @@ from vllm.v1.kv_cache_interface import (
KVCacheGroupSpec,
KVCacheSpec,
KVCacheTensor,
MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
@@ -158,24 +157,6 @@ def new_chunked_local_attention_spec(
)
def new_mamba_spec(
block_size=16,
shapes=((2, 512), (3, 32, 32)),
dtypes=(torch.float32, torch.float32),
num_speculative_blocks=2,
mamba_cache_mode="none",
page_size_padded=None,
):
return MambaSpec(
block_size=block_size,
shapes=shapes,
dtypes=dtypes,
page_size_padded=page_size_padded,
mamba_cache_mode=mamba_cache_mode,
num_speculative_blocks=num_speculative_blocks,
)
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils
@@ -2029,28 +2010,6 @@ def test_auto_fit_max_model_len():
assert vllm_config.model_config.max_model_len > 0
def test_auto_fit_max_model_len_with_hybrid():
"""Test that auto-fit works with hybrid KV cache specs."""
# Create config with original_max_model_len=-1 to trigger auto-fit
model_config = ModelConfig(max_model_len=8192)
# Simulate the user passing -1 by setting original_max_model_len
model_config.original_max_model_len = -1
vllm_config = VllmConfig(model_config=model_config)
mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 # 16KB per block per layer
gamma = 2
kv_cache_specs = {
"layer_1": new_mamba_spec(num_speculative_blocks=gamma),
"layer_2": new_kv_cache_spec(),
}
available_memory = mem_per_block_per_layer * (1024 // 16 + 1 + gamma)
_kv_cache_configs = get_kv_cache_configs(
vllm_config, [kv_cache_specs], [available_memory]
)
assert vllm_config.model_config.max_model_len == 1024
def test_auto_fit_max_model_len_not_triggered():
"""Test that auto-fit is not triggered when original_max_model_len is not -1."""
model_config = ModelConfig(max_model_len=16)

View File

@@ -12,7 +12,7 @@ import pytest
import pytest_asyncio
import requests
from tests.utils import ROCM_ENV_OVERRIDES, RemoteOpenAIServer
from tests.utils import RemoteOpenAIServer
from tests.v1.utils import check_request_balancing
from vllm.platforms import current_platform
@@ -27,84 +27,6 @@ TP_SIZE = int(os.getenv("TP_SIZE", "1"))
NUM_NODES = 2
async def _make_completion_request(
client: openai.AsyncOpenAI,
model_name: str,
) -> openai.types.Completion:
"""Make a single completion request and validate the response.
Uses temperature=1.0 to ensure diverse outputs across concurrent
requests for realistic load balancer testing.
"""
completion = await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=1.0,
)
assert completion.id is not None, (
f"Expected non-None completion id. usage={completion.usage!r}"
)
assert completion.choices is not None and len(completion.choices) == 1, (
f"Expected 1 choice, got "
f"{len(completion.choices) if completion.choices else 'None'}"
)
choice = completion.choices[0]
# With temperature=1.0, the model may emit a stop token immediately,
# producing empty text with finish_reason='stop'. This is valid
# model behavior - the test's purpose is load balancing, not output
# quality.
assert choice.finish_reason in ("length", "stop"), (
f"Expected finish_reason 'length' or 'stop', "
f"got {choice.finish_reason!r}. text={choice.text!r}"
)
if choice.finish_reason == "length":
assert len(choice.text) >= 1, (
f"Expected non-empty text with finish_reason='length', got {choice.text!r}"
)
assert completion.usage.prompt_tokens > 0, (
f"Expected positive prompt_tokens, got {completion.usage.prompt_tokens}"
)
assert completion.usage.total_tokens > 0, (
f"Expected positive total_tokens, got {completion.usage.total_tokens}"
)
return completion
async def _run_request_bursts(
client: openai.AsyncOpenAI,
model_name: str,
num_requests: int = 200,
num_bursts: int = 2,
):
"""Send multiple bursts of completion requests and validate all succeed."""
for burst in range(num_bursts):
all_tasks = []
for _ in range(num_requests):
all_tasks.append(
asyncio.create_task(_make_completion_request(client, model_name))
)
await asyncio.sleep(0.01)
results = await asyncio.gather(*all_tasks, return_exceptions=True)
assert len(results) == num_requests, (
f"Burst {burst}: expected {num_requests} results, got {len(results)}"
)
for result in results:
if isinstance(result, BaseException):
raise result
assert all(completion is not None for completion in results), (
f"Burst {burst}: some completions were None"
)
await asyncio.sleep(0.5)
class MultinodeInternalLBServerManager:
"""Manages multi-node data parallel vLLM server instances for internal
load balancer testing using --headless mode."""
@@ -186,7 +108,6 @@ class MultinodeInternalLBServerManager:
auto_port=False,
env_dict={
"VLLM_SERVER_DEV_MODE": "1",
**ROCM_ENV_OVERRIDES,
current_platform.device_control_env_var: ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(r, r + gpus_per_node)
@@ -308,7 +229,6 @@ class APIOnlyServerManager:
auto_port=False,
env_dict={
"VLLM_SERVER_DEV_MODE": "1",
**ROCM_ENV_OVERRIDES,
# No GPUs needed for API-only server
},
)
@@ -329,11 +249,10 @@ class APIOnlyServerManager:
engines_server_args,
auto_port=False,
env_dict={
**ROCM_ENV_OVERRIDES,
current_platform.device_control_env_var: ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(self.dp_size * self.tp_size)
),
)
},
)
server.__enter__()
@@ -476,15 +395,58 @@ async def test_multinode_dp_completion(
servers: list[tuple[RemoteOpenAIServer, list[str]]],
model_name: str,
) -> None:
async def make_request():
completion = await client.completions.create(
model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0
)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")
# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion
# Test single request
result = await _make_completion_request(client, model_name)
result = await make_request()
assert result is not None
print("Multi-node internal LB handled single completion request successfully")
await asyncio.sleep(0.5)
# Send multiple bursts - internal LB should distribute across DP ranks
await _run_request_bursts(client, model_name)
# Send multiple requests - internal LB should distribute across DP ranks
num_requests = 200
all_tasks = []
for _ in range(num_requests):
all_tasks.append(asyncio.create_task(make_request()))
await asyncio.sleep(0.01)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)
await asyncio.sleep(0.5)
# Second burst of requests
all_tasks = []
for _ in range(num_requests):
all_tasks.append(asyncio.create_task(make_request()))
await asyncio.sleep(0.01)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)
_, server_args = servers[0]
api_server_count = (
@@ -608,16 +570,59 @@ async def test_api_only_multinode_dp_completion(
) -> None:
"""Test API-only server with all engines on separate headless server."""
async def make_request():
completion = await api_only_client.completions.create(
model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0
)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes
# early or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")
# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion
# Test single request
result = await _make_completion_request(api_only_client, model_name)
result = await make_request()
assert result is not None
print("API-only server handled single completion request successfully")
await asyncio.sleep(0.5)
# Send multiple bursts - should be distributed across engines on
# Send multiple requests - should be distributed across engines on
# headless server
await _run_request_bursts(api_only_client, model_name)
num_requests = 200
all_tasks = []
for _ in range(num_requests):
all_tasks.append(asyncio.create_task(make_request()))
await asyncio.sleep(0.01)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)
await asyncio.sleep(0.5)
# Second burst of requests
all_tasks = []
for _ in range(num_requests):
all_tasks.append(asyncio.create_task(make_request()))
await asyncio.sleep(0.01)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)
api_server, api_server_args = api_only_servers[0]
api_server_count = (

View File

@@ -307,6 +307,13 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
num_submods = len(submod_names)
num_artifacts = standalone_compile_artifacts.num_artifacts()
logger.info(
"reconstructing serializable fn from standalone compile "
"artifacts. num_artifacts=%d num_submods=%d",
num_artifacts,
num_submods,
)
with functorch_ctx:
fn = reconstruct_serializable_fn_from_mega_artifact(
state=state,
@@ -317,10 +324,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
)
logger.info(
"reconstructed serializable fn from standalone compile "
"artifacts. num_artifacts=%d num_submods=%d",
num_artifacts,
num_submods,
"reconstructed serializable fn from standalone compile artifacts"
)
return fn

View File

@@ -138,6 +138,13 @@ class ParallelConfig:
"""Whether the deployed model is MoE (if known)."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_ep_weight_filter: bool = False
"""Skip non-local expert weights during model loading when expert
parallelism is active. Each rank only reads its own expert shard from
disk, which can drastically reduce storage I/O for MoE models with
per-expert weight tensors (e.g. DeepSeek, Mixtral, Kimi-K2.5). Has no
effect on 3D fused-expert checkpoints (e.g. GPT-OSS) or non-MoE
models."""
enable_eplb: bool = False
"""Enable expert parallelism load balancing for MoE layers."""
eplb_config: EPLBConfig = Field(default_factory=EPLBConfig)

View File

@@ -419,6 +419,7 @@ class EngineArgs:
data_parallel_external_lb: bool = False
data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
enable_ep_weight_filter: bool = ParallelConfig.enable_ep_weight_filter
moe_backend: MoEBackend = KernelConfig.moe_backend
all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
@@ -901,6 +902,10 @@ class EngineArgs:
"-ep",
**parallel_kwargs["enable_expert_parallel"],
)
parallel_group.add_argument(
"--enable-ep-weight-filter",
**parallel_kwargs["enable_ep_weight_filter"],
)
parallel_group.add_argument(
"--all2all-backend", **parallel_kwargs["all2all_backend"]
)
@@ -1727,6 +1732,7 @@ class EngineArgs:
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
is_moe_model=model_config.is_moe,
enable_expert_parallel=self.enable_expert_parallel,
enable_ep_weight_filter=self.enable_ep_weight_filter,
all2all_backend=self.all2all_backend,
enable_elastic_ep=self.enable_elastic_ep,
enable_dbo=self.enable_dbo,

View File

@@ -310,14 +310,11 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers,
)
else:
if not request.include_reasoning:
reasoning_ended = True
elif reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end(
prompt_token_ids or []
)
else:
reasoning_ended = None
reasoning_ended = (
reasoning_parser.is_reasoning_end(prompt_token_ids or [])
if reasoning_parser
else None
)
generator = self.engine_client.generate(
engine_prompt,

View File

@@ -296,16 +296,6 @@ def use_aot_compile() -> bool:
)
def use_mega_aot_artifact():
from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = (
"1" if is_torch_equal_or_newer("2.12.0.dev") and use_aot_compile() else "0"
)
return os.environ.get("VLLM_USE_MEGA_AOT_ARTIFACT", default_value) == "1"
def env_with_choices(
env_name: str,
default: str | None,
@@ -626,7 +616,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Enable loading compiled models directly from cached standalone compile artifacts
# without re-splitting graph modules. This reduces overhead during model
# loading by using reconstruct_serializable_fn_from_mega_artifact.
"VLLM_USE_MEGA_AOT_ARTIFACT": use_mega_aot_artifact,
"VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get(
"VLLM_USE_MEGA_AOT_ARTIFACT", "0"
)
== "1",
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),

View File

@@ -9,7 +9,6 @@ from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.utils import divide
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -156,9 +155,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
if type(source_layer) is ColumnParallelLinear:
return True
if type(source_layer) is maybe_get_oot_by_class(MergedColumnParallelLinear):
if type(source_layer) is MergedColumnParallelLinear:
if len(packed_modules_list) != 1:
return False
# Exclude layers with 3+ output sizes - those are handled by
@@ -607,7 +606,7 @@ class MergedColumnParallelLinearVariableSliceWithLoRA(
) -> bool:
# Support MergedColumnParallelLinear with 3 or more slices
# (2 slices are handled by MergedColumnParallelLinearWithLoRA)
if type(source_layer) is not maybe_get_oot_by_class(MergedColumnParallelLinear):
if type(source_layer) is not MergedColumnParallelLinear:
return False
# If packed_modules_list has 3+ items, use this class

View File

@@ -7,7 +7,6 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.linear import ReplicatedLinear
from .base_linear import BaseLinearLayerWithLoRA
@@ -56,7 +55,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is maybe_get_oot_by_class(ReplicatedLinear)
return type(source_layer) is ReplicatedLinear
def slice_lora_a(
self, lora_a: torch.Tensor | list[torch.Tensor | None]

View File

@@ -11,7 +11,6 @@ from vllm.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform
@@ -90,7 +89,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is maybe_get_oot_by_class(RowParallelLinear)
return type(source_layer) is RowParallelLinear
# The following layer is based on the tensor parallelism strategy given in

View File

@@ -7,7 +7,6 @@ import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
@@ -133,7 +132,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is maybe_get_oot_by_class(VocabParallelEmbedding)
return type(source_layer) is VocabParallelEmbedding
@property
def weight(self):

View File

@@ -22,11 +22,10 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
def maybe_get_oot_by_class(class_type: type) -> type:
class_name = class_type.__name__
def get_oot_class_by_name(class_name: str) -> type | None:
if class_name in op_registry_oot:
return op_registry_oot[class_name]
return class_type
return None
class PluggableLayer(nn.Module):

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp, maybe_get_oot_by_class
from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -125,7 +125,7 @@ class MMEncoderAttention(CustomOp):
cu_seqlens: np.ndarray,
device: torch.device,
) -> torch.Tensor | None:
if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined]
if attn_backend != AttentionBackendEnum.FLASHINFER:
@@ -149,7 +149,7 @@ class MMEncoderAttention(CustomOp):
tp_size: int,
device: torch.device,
) -> torch.Tensor:
if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined]
attn_backend, cu_seqlens, hidden_size, tp_size, device
)

View File

@@ -23,8 +23,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
kMxfp8Dynamic,
kMxfp8Static,
)
from vllm.platforms import current_platform
@@ -69,54 +67,11 @@ class TrtLlmFp8ExpertsBase:
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
return True
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor, Fp8 block, and MXFP8."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kMxfp8Static, kMxfp8Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports only SiLU and RELU^2 non-gated activation."""
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
# NOTE(dbari): Default is not implemented and should not be enabled until it is
if (weight_key, activation_key) in [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kMxfp8Static, kMxfp8Dynamic),
]:
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
else:
raise ValueError("Unsupported quantization scheme.")
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Monolithic kernel so only use with naive DP/EP and TP."""
@@ -158,10 +113,9 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 block and MXFP8."""
"""Supports Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kMxfp8Static, kMxfp8Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@@ -205,7 +159,6 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
apply_router_weight_on_input: bool,
):
import flashinfer
from flashinfer.fused_moe import Fp8QuantizationType
# Pack topk_ids and topk_weights into single tensor
# Format: (expert_id << 16) | (weight_bf16.view(int16))
@@ -222,16 +175,6 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
assert a1q_scale is not None
is_mxfp8 = self.quant_config.block_shape == [1, 32]
if is_mxfp8:
fp8_quant_type = Fp8QuantizationType.MxFp8
use_shuffled_weight = True
hidden_states_scale = a1q_scale
else:
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
use_shuffled_weight = False
hidden_states_scale = a1q_scale.t().contiguous()
# `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
# output tensor in-place so we need to manually copy the result to the
# output tensor
@@ -240,7 +183,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
topk_ids=packed_topk_ids,
routing_bias=None,
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
hidden_states_scale=a1q_scale.t().contiguous(), # type: ignore[union-attr]
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale,
gemm2_weights=w2,
@@ -254,9 +197,8 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1,
use_shuffled_weight=use_shuffled_weight,
use_shuffled_weight=False,
weight_layout=0,
fp8_quantization_type=fp8_quant_type,
# output=output,
)
output.copy_(result)
@@ -298,11 +240,10 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor, Fp8 block, and MXFP8."""
"""Supports Fp8 per-tensor and Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kMxfp8Static, kMxfp8Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@@ -315,10 +256,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
"""Monolithic kernels need to express router support."""
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
# NOTE(dbari): Default is not implemented and should not be enabled until it is
if (weight_key, activation_key) in [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kMxfp8Static, kMxfp8Dynamic),
]:
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
@@ -336,7 +274,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
else:
raise ValueError("Unsupported quantization scheme.")
def _apply_block_scale(
def _apply_per_block(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -353,38 +291,32 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
# Delay import for non-CUDA.
import flashinfer
from flashinfer.fused_moe import Fp8QuantizationType
assert not apply_router_weight_on_input
assert activation == MoEActivation.SILU
assert self.topk <= global_num_experts
assert self.topk <= 10
assert global_num_experts % 4 == 0
assert self.quant_config.block_shape in [[128, 128], [1, 32]]
# Kernel expects #experts <= #threads 512
assert global_num_experts <= 512
# TODO: fuse into the quant kernel.
assert a1q_scale is not None
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)
is_mxfp8 = self.quant_config.block_shape == [1, 32]
if is_mxfp8:
fp8_quant_type = Fp8QuantizationType.MxFp8
use_shuffled_weight = True
hidden_states_scale = a1q_scale
else:
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
use_shuffled_weight = False
hidden_states_scale = a1q_scale.t().contiguous()
assert self.topk <= global_num_experts
assert self.topk <= 10
assert global_num_experts % 4 == 0
assert self.quant_config.block_shape == [128, 128]
# Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 512
# Kernel requires transposed hidden state scales
# TODO: fuse into the quant kernel.
assert a1q_scale is not None
a1q_scale_t = a1q_scale.t().contiguous()
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
hidden_states_scale=a1q_scale_t,
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale,
gemm2_weights=w2,
@@ -398,8 +330,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type,
use_shuffled_weight=use_shuffled_weight,
fp8_quantization_type=fp8_quant_type,
use_shuffled_weight=False,
)
def _apply_per_tensor(
@@ -478,7 +409,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
topk_group: int | None = None,
) -> torch.Tensor:
if self.quant_config.block_shape is not None:
return self._apply_block_scale(
return self._apply_per_block(
hidden_states,
w1,
w2,
@@ -510,6 +441,6 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
)
else:
raise NotImplementedError(
"Only per-block, per-tensor, and MXFP8 quantization are "
f"supported in {self.__class__.__name__}."
"Only per-block and per-tensor quantization are supported in "
f"{self.__class__.__name__}."
)
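
Editor's note: in the block-scale path kept above, the activation scales still have to be handed to the TRT-LLM kernel transposed and contiguous. A tiny illustration with made-up shapes (the real tensors come out of the quantization step, not constructed like this):

import torch

num_tokens, hidden_dim, block = 6, 512, 128                 # illustrative sizes only
a1q_scale = torch.rand(num_tokens, hidden_dim // block)     # e.g. [6, 4]
a1q_scale_t = a1q_scale.t().contiguous()                    # [4, 6], the layout the kernel expects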

View File

@@ -444,7 +444,7 @@ def convert_to_fp8_moe_kernel_format(
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_fi(
w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
layer=layer,
w13=w13,
w2=w2,
@@ -512,21 +512,6 @@ def make_fp8_moe_quant_config(
g1_alphas=(w1_scale * a1_scale).squeeze(),
g2_alphas=(w2_scale * a2_scale).squeeze(),
)
# MXFP8 uses "mxfp8" quant_dtype so the prepare step dispatches to
# _mxfp8_e4m3_quantize rather than standard FP8 block quantization.
# Non-swizzled layout is required since the TRTLLM kernel expects
# scales in (num_tokens, hidden_dim // 32) format.
if block_shape == [1, 32]:
return FusedMoEQuantConfig.make(
"mxfp8",
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
is_nvfp4_scale_swizzled=False,
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,

View File

@@ -1,87 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
backend_to_kernel_cls,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kMxfp8Dynamic,
kMxfp8Static,
)
logger = init_logger(__name__)
_SUPPORTED_BACKENDS: frozenset[Fp8MoeBackend] = frozenset(
{
Fp8MoeBackend.FLASHINFER_TRTLLM,
}
)
_BACKEND_NAME_MAP: dict[str, Fp8MoeBackend] = {
"flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
}
def _select_kernel_cls(
backend: Fp8MoeBackend,
config: FusedMoEConfig,
) -> type[mk.FusedMoEExperts]:
"""Select the first supported expert class for the MXFP8 config."""
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard
)
last_reason: str | None = None
for cls in backend_to_kernel_cls(backend):
supported, reason = cls.is_supported_config(
cls,
config,
kMxfp8Static,
kMxfp8Dynamic,
activation_format,
)
if supported:
return cls
last_reason = reason
raise ValueError(
f"No supported MXFP8 expert class for {backend.value}: {last_reason}"
)
class MxFp8MoeBackend(Enum):
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
def select_mxfp8_moe_backend(
config: FusedMoEConfig,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
"""Select the MXFP8 MoE backend and the best expert class.
Returns:
A tuple of (fp8_backend, experts_cls).
"""
) -> MxFp8MoeBackend:
if config.is_lora_enabled:
raise NotImplementedError("LoRA is not supported for MXFP8 MoE.")
AVAILABLE_BACKENDS = [
MxFp8MoeBackend.FLASHINFER_TRTLLM,
]
runner_backend = config.moe_backend
if runner_backend != "auto":
backend = _BACKEND_NAME_MAP.get(runner_backend)
if backend is None:
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for "
f"MXFP8 MoE. Expected one of "
f"{list(_BACKEND_NAME_MAP.keys())}."
mapping = {
"flashinfer_trtllm": MxFp8MoeBackend.FLASHINFER_TRTLLM,
}
if backend := mapping.get(runner_backend):
logger.info_once(
"Using '%s' MxFp8 MoE backend (user-requested).",
backend.value,
)
logger.info_once(
"Using '%s' MxFp8 MoE backend (user-requested).",
backend.value,
return backend
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for MXFP8 MoE. "
f"Expected one of {list(mapping.keys())}."
)
return backend, _select_kernel_cls(backend, config)
# Auto-select: pick the first supported backend.
for backend in _SUPPORTED_BACKENDS:
logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
return backend, _select_kernel_cls(backend, config)
raise ValueError("No MXFP8 MoE backends available.")
# Auto-select: only one backend available for now.
backend = AVAILABLE_BACKENDS[0]
logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
return backend
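
Editor's note: after this rewrite the oracle hands back only an MxFp8MoeBackend value, and callers (the ModelOpt MoE method below) compare against that enum instead of Fp8MoeBackend. A self-contained sketch of the selection logic, taking the backend name directly instead of a FusedMoEConfig for brevity; not the real function:

from enum import Enum

class MxFp8MoeBackend(Enum):
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"

def select_backend(moe_backend: str = "auto") -> MxFp8MoeBackend:
    mapping = {"flashinfer_trtllm": MxFp8MoeBackend.FLASHINFER_TRTLLM}
    if moe_backend != "auto":
        if (backend := mapping.get(moe_backend)) is not None:
            return backend
        raise ValueError(
            f"moe_backend={moe_backend!r} is not supported for MXFP8 MoE. "
            f"Expected one of {list(mapping)}."
        )
    # Auto-select: only one backend available for now.
    return MxFp8MoeBackend.FLASHINFER_TRTLLM

is_monolithic = select_backend() is MxFp8MoeBackend.FLASHINFER_TRTLLM  # True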

View File

@@ -199,7 +199,7 @@ def _mxfp8_e4m3_quantize(
) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None
assert not per_act_token_quant
assert block_shape is None or block_shape == [1, 32]
assert block_shape is None
return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout)

View File

@@ -31,7 +31,6 @@ QuantizationMethods = Literal[
"torchao",
"inc",
"mxfp4",
"mxfp8",
"petit_nvfp4",
"cpu_awq",
]
@@ -130,7 +129,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
)
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .mxfp8 import Mxfp8Config
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config
from .torchao import TorchAOConfig
@@ -158,7 +156,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"auto-round": INCConfig,
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"mxfp8": Mxfp8Config,
"petit_nvfp4": PetitNvFp4Config,
"cpu_awq": CPUAWQConfig,
}

View File

@@ -25,13 +25,13 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
MxFp8MoeBackend,
select_mxfp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
@@ -1712,7 +1712,8 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
self.quant_config = quant_config
assert self.quant_config.is_checkpoint_mxfp8_serialized
self.mxfp8_backend, _ = select_mxfp8_moe_backend(self.moe)
# Select MXFP8 MoE backend
self.mxfp8_backend = select_mxfp8_moe_backend(self.moe)
def create_weights(
self,
@@ -1942,7 +1943,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
@property
def is_monolithic(self) -> bool:
return self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
return self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
@@ -1955,7 +1956,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
Fp8QuantizationType,
)
assert self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
assert self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM
if layer.enable_eplb:
raise NotImplementedError(

View File

@@ -1,354 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Online MXFP8 (microscaling FP8, block-32) quantization config and methods."""
from typing import Any
import torch
from torch.nn import Module
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
select_mxfp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
Fp8KVCacheMethod,
Fp8OnlineLinearMethod,
Fp8OnlineMoEMethod,
_copy_missing_attrs,
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
Mxfp8LinearBackend,
Mxfp8LinearOp,
mxfp8_e4m3_quantize,
swizzle_mxfp8_scale,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
)
from vllm.model_executor.model_loader.weight_utils import (
initialize_single_dummy_weight,
)
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
logger = init_logger(__name__)
class Mxfp8Config(Fp8Config):
"""Config class for online MXFP8 MoE quantization."""
def __init__(
self,
activation_scheme: str = "dynamic",
ignored_layers: list[str] | None = None,
) -> None:
if activation_scheme != "dynamic":
raise ValueError("mxfp8 only supports dynamic activation scheme.")
super().__init__(
is_checkpoint_fp8_serialized=False,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=None,
)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "mxfp8"
@classmethod
def get_min_capability(cls) -> int:
return 100
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config":
activation_scheme = cls.get_from_keys_or(
config, ["activation_scheme"], "dynamic"
)
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
if not ignored_layers:
ignored_layers = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
return cls(
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
skip_with_substr=True,
):
return UnquantizedLinearMethod()
return Mxfp8OnlineLinearMethod(self)
elif isinstance(layer, FusedMoE):
if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
skip_with_substr=True,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
return Mxfp8OnlineMoEMethod(self, layer)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
"""Online MXFP8 linear method.
Loads bf16/fp16 checkpoints and quantizes weights to MXFP8 (microscaling
FP8 with block-32 scales) during weight loading.
Args:
quant_config: The MXFP8 quantization config.
"""
uses_meta_device: bool = True
def __init__(self, quant_config: "Mxfp8Config"):
self.quant_config = quant_config
self.out_dtype = torch.get_default_dtype()
self.mxfp8_linear = Mxfp8LinearOp(self._select_backend())
logger.info_once(
"Using %s backend for MXFP8 GEMM", self.mxfp8_linear.backend.value
)
@staticmethod
def _select_backend() -> Mxfp8LinearBackend:
try:
from vllm.utils import flashinfer as fi
_ = fi.mm_mxfp8
return Mxfp8LinearBackend.FLASHINFER_CUTLASS
except Exception:
logger.warning(
"FlashInfer mm_mxfp8 not available, "
"falling back to MXFP8 emulation backend."
)
return Mxfp8LinearBackend.EMULATION
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if input_size_per_partition % MXFP8_BLOCK_SIZE != 0:
raise ValueError(
f"MXFP8 requires input_size_per_partition "
f"({input_size_per_partition}) to be divisible by "
f"{MXFP8_BLOCK_SIZE}."
)
super().create_weights(
layer,
input_size_per_partition,
output_partition_sizes,
input_size,
output_size,
params_dtype,
**extra_weight_attrs,
)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
if layer.weight.device == torch.device("meta"):
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=layer.weight.weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
initialize_single_dummy_weight(layer.weight)
weight_fp8, weight_scale = mxfp8_e4m3_quantize(layer.weight.contiguous())
if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
N, K = layer.weight.shape[0], layer.weight.shape[1]
weight_scale = swizzle_mxfp8_scale(weight_scale, N, K)
layer.input_scale = None
replace_parameter(layer, "weight", weight_fp8.data)
replace_parameter(layer, "weight_scale", weight_scale.data)
layer._already_called_process_weights_after_loading = True
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.mxfp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
bias=bias,
)
class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
"""MoE method for online MXFP8 (block) quantization."""
uses_meta_device: bool = True
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
FusedMoEMethodBase.__init__(self, layer.moe_config)
self.quant_config = quant_config
assert not quant_config.is_checkpoint_fp8_serialized
assert quant_config.activation_scheme == "dynamic"
self.weight_block_size = [1, MXFP8_BLOCK_SIZE]
self.block_quant = True
self.weight_scale_name = "weight_scale"
self.fp8_backend, self.experts_cls = select_mxfp8_moe_backend(config=self.moe)
def create_weights(
self,
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if (
hidden_size % MXFP8_BLOCK_SIZE != 0
or intermediate_size_per_partition % MXFP8_BLOCK_SIZE != 0
):
raise ValueError(
"Online MXFP8 MoE requires hidden/intermediate sizes divisible "
f"by {MXFP8_BLOCK_SIZE}."
)
super().create_weights(
layer=layer,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size_per_partition=intermediate_size_per_partition,
params_dtype=params_dtype,
**extra_weight_attrs,
)
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // MXFP8_BLOCK_SIZE,
dtype=torch.uint8,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // MXFP8_BLOCK_SIZE,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
layer.weight_block_size = [1, MXFP8_BLOCK_SIZE]
def _quantize_mxfp8_moe_weight(
self, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Batch quantization: bf16/fp16 weights -> MXFP8 (fp8 + uint8 scales)."""
num_batches = weight.size(0)
w_quant = []
w_scales = []
for i in range(num_batches):
mx_fp8_quant, mx_fp8_scale = mxfp8_e4m3_quantize(
weight[i], is_sf_swizzled_layout=False
)
w_quant.append(mx_fp8_quant)
w_scales.append(mx_fp8_scale)
return torch.stack(w_quant), torch.stack(w_scales)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
if layer.w13_weight.device == torch.device("meta"):
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
initialize_single_dummy_weight(layer.w13_weight)
if layer.w2_weight.device == torch.device("meta"):
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
initialize_single_dummy_weight(layer.w2_weight)
fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight)
w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight)
self._setup_kernel(
layer,
w13,
w2,
w13_scale,
w2_scale,
layer.w13_input_scale,
layer.w2_input_scale,
)
layer._already_called_process_weights_after_loading = True

View File

@@ -305,81 +305,6 @@ def align_fp8_moe_weights_for_fi(
return padded_w13, padded_w2, padded_intermediate
def _shuffle_mxfp8_moe_weights(
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
is_gated: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess MXFP8 weights and scales for the FlashInfer TRT-LLM kernel.
Following flashinfer/tests/moe/test_trtllm_gen_fused_moe.py:
1. reorder_rows_for_gated_act_gemm (interleave gate/up rows)
2. shuffle_matrix_a (weight data layout shuffle)
3. shuffle_matrix_sf_a (scale factor layout shuffle)
"""
from flashinfer import (
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
epilogue_tile_m = 128
num_experts = w13.shape[0]
intermediate_size = w13.shape[1] // 2
hidden_size = w13.shape[2]
w13_interleaved: list[torch.Tensor] = []
w13_scale_interleaved: list[torch.Tensor] = []
for i in range(num_experts):
if is_gated:
w13_interleaved.append(
reorder_rows_for_gated_act_gemm(
w13[i].reshape(2 * intermediate_size, -1)
)
)
w13_scale_interleaved.append(
reorder_rows_for_gated_act_gemm(
w13_scale[i].reshape(2 * intermediate_size, -1)
)
)
else:
w13_interleaved.append(w13[i])
w13_scale_interleaved.append(w13_scale[i])
w13_shuffled: list[torch.Tensor] = []
w2_shuffled: list[torch.Tensor] = []
w13_scale_shuffled: list[torch.Tensor] = []
w2_scale_shuffled: list[torch.Tensor] = []
for i in range(num_experts):
w13_shuffled.append(
shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m)
)
w2_shuffled.append(shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m))
w13_scale_shuffled.append(
shuffle_matrix_sf_a(
w13_scale_interleaved[i]
.view(torch.uint8)
.reshape(2 * intermediate_size, -1),
epilogue_tile_m,
)
)
w2_scale_shuffled.append(
shuffle_matrix_sf_a(
w2_scale[i].view(torch.uint8).reshape(hidden_size, -1),
epilogue_tile_m,
)
)
w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
w13_scale_out = torch.stack(w13_scale_shuffled).reshape(w13_scale.shape)
w2_scale_out = torch.stack(w2_scale_shuffled).reshape(w2_scale.shape)
return w13_out, w2_out, w13_scale_out, w2_scale_out
def prepare_fp8_moe_layer_for_fi(
layer: torch.nn.Module,
w13: torch.Tensor,
@@ -389,7 +314,7 @@ def prepare_fp8_moe_layer_for_fi(
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor | None,
is_trtllm: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Convert Fp8 MoE weights to flashinfer kernel format
@@ -404,33 +329,10 @@ def prepare_fp8_moe_layer_for_fi(
block_quant = (
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
)
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
is_gated = layer.activation.is_gated
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
if is_mxfp8 and is_trtllm:
# FlashInfer TRT-LLM SwiGLU expects [up; gate] but vLLM stores
# [gate; up]. Swap both weights and scales before interleaving.
if layer.moe_config.is_act_and_mul:
w13 = swap_w13_to_w31(w13)
# Scales may be 2D [E, flat] from _quantize_mxfp8_moe_weight;
# reshape to 3D so swap_w13_to_w31 can flip the two halves,
# then flatten back.
if w13_scale.ndim == 2:
num_rows = w13.shape[1] # 2 * intermediate_size
w13_scale = w13_scale.reshape(w13_scale.shape[0], num_rows, -1)
w13_scale = swap_w13_to_w31(w13_scale)
w13_scale = w13_scale.reshape(w13_scale.shape[0], -1)
else:
w13_scale = swap_w13_to_w31(w13_scale)
w13, w2, w13_scale, w2_scale = _shuffle_mxfp8_moe_weights(
w13, w2, w13_scale, w2_scale, is_gated
)
return w13, w2, w13_scale, w2_scale
# Some FI MoE kernels require internal alignment of 16
# for the gate-up proj. Pad the weights to respect this.
is_gated = layer.activation.is_gated
if not block_quant:
min_alignment = 16 if is_gated else 128
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
@@ -467,4 +369,4 @@ def prepare_fp8_moe_layer_for_fi(
w13_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
w2_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
return w13, w2, w13_scale, w2_scale
return w13, w2, w13_scale

View File

@@ -149,12 +149,6 @@ kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)
kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128))
kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True)
kMxfp8StaticScale = ScaleDesc(torch.uint8, True, GroupShape(1, 32))
kMxfp8Static = QuantKey(FP8_DTYPE, kMxfp8StaticScale, symmetric=True)
kMxfp8DynamicScale = ScaleDesc(torch.uint8, False, GroupShape(1, 32))
kMxfp8Dynamic = QuantKey(FP8_DTYPE, kMxfp8DynamicScale, symmetric=True)
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)

View File

@@ -313,7 +313,18 @@ class DefaultModelLoader(BaseModelLoader):
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
if not (model_config.is_moe and parallel_config.enable_expert_parallel):
if not (
model_config.is_moe
and parallel_config.enable_expert_parallel
and parallel_config.enable_ep_weight_filter
):
return
# When EPLB is enabled, redundant physical expert slots may map to
# logical experts that belong to other ranks in the default partition.
# The weight loader needs to see ALL logical expert weights so it can
# populate these redundant slots. Skip the filter entirely.
if parallel_config.enable_eplb:
return
num_experts = model_config.get_num_experts()
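
Editor's note: condensed, the loader's expert-weight filter is now gated on three conditions and bypassed entirely under EPLB. A hedged sketch of the predicate, with plain booleans in place of the config objects:

def ep_filter_active(
    is_moe: bool,
    enable_expert_parallel: bool,
    enable_ep_weight_filter: bool,
    enable_eplb: bool,
) -> bool:
    if not (is_moe and enable_expert_parallel and enable_ep_weight_filter):
        return False  # filter is opt-in via enable_ep_weight_filter
    if enable_eplb:
        return False  # EPLB must see every logical expert's weights
    return True

assert ep_filter_active(True, True, True, enable_eplb=False) is True
assert ep_filter_active(True, True, True, enable_eplb=True) is False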

View File

@@ -73,4 +73,9 @@ def should_skip_weight(
if eid is None:
# Not an expert weight (dense / shared-expert / embedding) → keep.
return False
# Only skip heavy weight tensors, never scale/metadata tensors.
# Scale tensors are tiny and some backends need them from ALL experts
# (e.g. FlashInfer NVFP4 computes a global max of activation scales).
if not weight_name.endswith(".weight"):
return False
return eid not in local_expert_ids
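
Editor's note: with the added suffix check, only the heavy ".weight" tensors of non-local experts are filtered; everything else, including per-expert scales, keeps loading. A small illustration with hypothetical weight names and a hypothetical wrapper around the same checks:

local_expert_ids = {0, 1}  # experts owned by this rank (illustrative)

def is_skipped(weight_name: str, eid: int | None) -> bool:
    if eid is None:
        return False                       # dense / shared-expert / embedding: keep
    if not weight_name.endswith(".weight"):
        return False                       # scale / metadata tensors: keep for all experts
    return eid not in local_expert_ids     # heavy expert weights: filter non-local ones

assert is_skipped("experts.7.down_proj.weight", 7) is True
assert is_skipped("experts.7.down_proj.weight_scale", 7) is False
assert is_skipped("experts.1.down_proj.weight", 1) is False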

View File

@@ -74,7 +74,6 @@ class EagleMistralLarge3Model(DeepseekV2Model):
prefix=maybe_prefix(prefix, "fc"),
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.aux_hidden_state_layers: tuple[int, ...] = ()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)

View File

@@ -246,6 +246,7 @@ class CpuPlatform(Platform):
"size_asserts": False,
"nan_asserts": False,
"epilogue_fusion": True,
"cpp.dynamic_threads": True,
}
)
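
Editor's note: the extra inductor knob can be exercised outside vLLM as well; a minimal repro, assuming the dictionary key above maps onto torch._inductor.config.cpp.dynamic_threads:

import torch

# Assumption: this is the inductor option the "cpp.dynamic_threads" key addresses.
torch._inductor.config.cpp.dynamic_threads = True

@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1

f(torch.randn(8))  # CPU compile; the OpenMP thread count is not baked into the generated code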

View File

@@ -15,15 +15,8 @@ from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
Tokenizer,
)
from mistral_common.tokens.tokenizers.instruct import (
InstructTokenizerBase,
InstructTokenizerV13,
)
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as MistralCommonTokenizer,
)
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
@@ -33,20 +26,21 @@ from pydantic import ValidationError
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.logger import init_logger
from vllm.tokenizers.protocol import TokenizerLike
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)
from .protocol import TokenizerLike
if TYPE_CHECKING:
from transformers import BatchEncoding
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)
logger = init_logger(__name__)
@@ -241,6 +235,15 @@ class MistralTokenizer(TokenizerLike):
download_dir: str | None = None,
**kwargs,
) -> "MistralTokenizer":
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)
tokenizer = MistralCommonBackend.from_pretrained(
path_or_repo_id,
*args,
@@ -252,13 +255,13 @@ class MistralTokenizer(TokenizerLike):
return cls(tokenizer)
def __init__(self, tokenizer: MistralCommonBackend) -> None:
def __init__(self, tokenizer: "MistralCommonBackend") -> None:
super().__init__()
self.transformers_tokenizer: MistralCommonBackend = tokenizer
self.mistral: MistralCommonTokenizer = tokenizer.tokenizer
self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer
self.tokenizer: Tokenizer = self.instruct.tokenizer
self.transformers_tokenizer = tokenizer
self.mistral = tokenizer.tokenizer
self.instruct = self.mistral.instruct_tokenizer
self.tokenizer = self.instruct.tokenizer
mode = self.mistral._chat_completion_request_validator._mode
if mode != ValidationMode.test:
@@ -480,11 +483,7 @@ class MistralTokenizer(TokenizerLike):
return self.transformers_tokenizer.convert_tokens_to_ids(tokens)
def convert_tokens_to_string(self, tokens: list[str]) -> str:
to_decode_special_tokens = {
SpecialTokens.tool_calls,
SpecialTokens.begin_think,
SpecialTokens.end_think,
}
to_decode_special_tokens = {SpecialTokens.tool_calls}
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
tokens = [

View File

@@ -241,10 +241,7 @@ class MistralToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
has_bot_token = (
self.bot_token_id in current_token_ids or self.bot_token in current_text
)
if not has_bot_token:
if self.bot_token_id not in current_token_ids:
# if the tool call token is not in the tokens generated so far,
# append output to contents since it's not a tool
return DeltaMessage(content=delta_text)
@@ -278,8 +275,7 @@ class MistralToolParser(ToolParser):
additional_content: str = ""
if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
# this is the first tool call
if self.bot_token not in delta_text:
return DeltaMessage(content=delta_text)
assert self.bot_token_id in delta_token_ids
if not delta_text.startswith(self.bot_token):
additional_content += delta_text.split(self.bot_token)[0]
delta_text = self.bot_token + "".join(
@@ -415,7 +411,7 @@ class MistralToolParser(ToolParser):
index=self.current_tool_id, type="function"
)
current_tool_call_modified = False
if self.bot_token_id in delta_token_ids or self.bot_token in delta_text:
if self.bot_token_id in delta_token_ids:
# this is the first tool call
if not delta_text.startswith(self.bot_token):
content = delta_text.split(self.bot_token)[0]
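
Editor's note: in the streaming path above, content that precedes the tool-call token in a delta is peeled off before parsing. A tiny illustration with an assumed bot-token string (illustration only, not the parser's code):

bot_token = "[TOOL_CALLS]"  # assumed token string, for illustration
delta_text = 'Sure. [TOOL_CALLS][{"name": "get_weather"}]'
content_before_tool = delta_text.split(bot_token)[0]   # 'Sure. '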

View File

@@ -96,13 +96,8 @@ def _trtllm_prefill_attn_kvfp8_dequant(
mock_kv_cache_ptr,
k_scale_ptr,
v_scale_ptr,
src_stride_page,
src_stride_kv,
src_stride_head,
DST_K_CACHE_STRIDE: tl.constexpr,
DST_KV_CACHE_STRIDE: tl.constexpr,
HEAD_STRIDE: tl.constexpr,
NUM_KV_HEADS: tl.constexpr,
K_CACHE_STRIDE: tl.constexpr,
KV_CACHE_STRIDE: tl.constexpr,
):
batch_idx = tl.program_id(0).to(tl.int64)
mock_block_table_idx = tl.program_id(1).to(tl.int64)
@@ -113,42 +108,31 @@ def _trtllm_prefill_attn_kvfp8_dequant(
return
dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
# Dequantize K
k_scale_val = tl.load(k_scale_ptr)
offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
fp8_vals = tl.load(kv_cache_ptr + offset)
dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
mock_cache_offset = (
batch_idx * block_table_stride + mock_block_table_idx + 1
) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
dequantized_vals = dequantized_vals.to(dequant_dtype)
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
# Dequantize V
v_scale_val = tl.load(v_scale_ptr)
mock_page_idx = batch_idx * block_table_stride + mock_block_table_idx + 1
head_offsets = tl.arange(0, HEAD_STRIDE)
for h in range(NUM_KV_HEADS):
h_off = tl.cast(h, tl.int64)
# Read K from source (supports non-contiguous page/kv/head strides)
src_k = orig_page_num * src_stride_page + h_off * src_stride_head + head_offsets
fp8_k = tl.load(kv_cache_ptr + src_k)
dequant_k = (fp8_k.to(tl.float32) * k_scale_val).to(dequant_dtype)
# Write K to contiguous mock cache
dst_k = mock_page_idx * DST_KV_CACHE_STRIDE + h * HEAD_STRIDE + head_offsets
tl.store(mock_kv_cache_ptr + dst_k, dequant_k)
# Read V from source (offset by src_stride_kv for the V half)
src_v = (
orig_page_num * src_stride_page
+ src_stride_kv
+ h_off * src_stride_head
+ head_offsets
)
fp8_v = tl.load(kv_cache_ptr + src_v)
dequant_v = (fp8_v.to(tl.float32) * v_scale_val).to(dequant_dtype)
# Write V to contiguous mock cache
dst_v = (
mock_page_idx * DST_KV_CACHE_STRIDE
+ DST_K_CACHE_STRIDE
+ h * HEAD_STRIDE
+ head_offsets
)
tl.store(mock_kv_cache_ptr + dst_v, dequant_v)
offset = (
orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
)
fp8_vals = tl.load(kv_cache_ptr + offset)
dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
mock_cache_offset = (
(batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE
+ K_CACHE_STRIDE
+ tl.arange(0, K_CACHE_STRIDE)
)
dequantized_vals = dequantized_vals.to(dequant_dtype)
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
def trtllm_prefill_attn_kvfp8_dequant(
@@ -162,18 +146,8 @@ def trtllm_prefill_attn_kvfp8_dequant(
s = kv_cache.shape
assert s[1] == 2
assert dequant_dtype in (torch.bfloat16, torch.float16)
num_kv_heads, block_size, head_size = s[2], s[3], s[4]
head_stride = block_size * head_size
k_cache_stride = num_kv_heads * head_stride
k_cache_stride = s[2] * s[3] * s[4]
kv_cache_stride = k_cache_stride * s[1]
strides = kv_cache.stride()
assert strides[3] == head_size and strides[4] == 1, (
"For kv cache layouts, (block_size, head_size) "
f"dimensions must be contiguous, got strides {strides}"
)
new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
# mock kv cache contains just the pages needed by this prefill
mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device)
@@ -192,13 +166,8 @@ def trtllm_prefill_attn_kvfp8_dequant(
mock_kv_cache,
k_scale,
v_scale,
strides[0],
strides[1],
strides[2],
k_cache_stride,
kv_cache_stride,
head_stride,
num_kv_heads,
)
return mock_kv_cache, mock_block_table
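
Editor's note: worked numbers for the contiguous-layout stride math the kept kernel path relies on, with an illustrative cache shape [num_pages, 2, num_kv_heads, block_size, head_size]:

num_kv_heads, block_size, head_size = 8, 16, 128        # illustrative sizes
k_cache_stride = num_kv_heads * block_size * head_size   # 16384 elements per K (or V) plane
kv_cache_stride = k_cache_stride * 2                     # 32768 elements per page (K + V)

page = 3
k_start = page * kv_cache_stride                         # where page 3's K values begin
v_start = page * kv_cache_stride + k_cache_stride        # the V half sits right after K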

View File

@@ -1356,10 +1356,8 @@ def _max_memory_usage_bytes_from_groups(
page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
blocks_needed = sum(
cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)
for group in kv_cache_groups
)
any_spec = kv_cache_groups[0].kv_cache_spec
blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size)
return group_size * page_size * blocks_needed
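
Editor's note: a worked example of what changed in the estimate, with made-up numbers. blocks_needed is now derived from a single representative spec rather than summed across groups (page size is uniform across groups); the group_size * page_size factor is untouched:

from math import ceil

page_size = 2 * 1024**2                     # bytes per KV block (illustrative)
bytes_per_spec = [10 * 1024**3] * 4         # per-group max KV memory, all equal (illustrative)

old_blocks = sum(ceil(b / page_size) for b in bytes_per_spec)   # summed over groups: 20480
new_blocks = ceil(bytes_per_spec[0] / page_size)                # one representative spec: 5120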