Compare commits
18 Commits
v0.18.0rc0
...
v0.17.2rc0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54a62a79f7 | ||
|
|
384dc7f77b | ||
|
|
f04d5226f8 | ||
|
|
0a0a1a198b | ||
|
|
6c1cfbad32 | ||
|
|
45f526d652 | ||
|
|
5db91f0aaf | ||
|
|
061980c36a | ||
|
|
7a49742b88 | ||
|
|
3e6a1e1686 | ||
|
|
7961486a9b | ||
|
|
4f9b14c21c | ||
|
|
31a458c091 | ||
|
|
a3a51d20e7 | ||
|
|
e5b807607c | ||
|
|
fd4d96302a | ||
|
|
c0f011918d | ||
|
|
e6ae4b1be1 |
@@ -333,15 +333,15 @@ apply_rocm_test_overrides() {
|
||||
# --- Entrypoint ignores ---
|
||||
if [[ $cmds == *" entrypoints/openai "* ]]; then
|
||||
cmds=${cmds//" entrypoints/openai "/" entrypoints/openai \
|
||||
--ignore=entrypoints/openai/test_audio.py \
|
||||
--ignore=entrypoints/openai/test_shutdown.py \
|
||||
--ignore=entrypoints/openai/chat_completion/test_audio.py \
|
||||
--ignore=entrypoints/openai/completion/test_shutdown.py \
|
||||
--ignore=entrypoints/openai/test_completion.py \
|
||||
--ignore=entrypoints/openai/test_models.py \
|
||||
--ignore=entrypoints/openai/test_lora_adapters.py \
|
||||
--ignore=entrypoints/openai/test_return_tokens_as_ids.py \
|
||||
--ignore=entrypoints/openai/test_root_path.py \
|
||||
--ignore=entrypoints/openai/chat_completion/test_root_path.py \
|
||||
--ignore=entrypoints/openai/test_tokenization.py \
|
||||
--ignore=entrypoints/openai/test_prompt_validation.py "}
|
||||
--ignore=entrypoints/openai/completion/test_prompt_validation.py "}
|
||||
fi
|
||||
|
||||
if [[ $cmds == *" entrypoints/llm "* ]]; then
|
||||
|
||||
@@ -162,7 +162,7 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Entrypoints Integration Test (API Server 2)
|
||||
@@ -674,12 +674,12 @@ steps:
|
||||
- vllm/config/model.py
|
||||
- vllm/model_executor
|
||||
- tests/model_executor
|
||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
commands:
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s model_executor
|
||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
|
||||
- label: Benchmarks # 11min
|
||||
timeout_in_minutes: 20
|
||||
@@ -1143,7 +1143,7 @@ steps:
|
||||
- pytest -v -s plugins_tests/test_scheduler_plugins.py
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s distributed/test_distributed_oot.py
|
||||
- pytest -v -s entrypoints/openai/test_oot_registration.py
|
||||
- pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py
|
||||
- pytest -v -s models/test_oot_registration.py
|
||||
- pytest -v -s plugins/lora_resolvers
|
||||
|
||||
@@ -1502,7 +1502,7 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Entrypoints Integration Test (API Server 2)
|
||||
@@ -2133,12 +2133,12 @@ steps:
|
||||
- vllm/config/model.py
|
||||
- vllm/model_executor
|
||||
- tests/model_executor
|
||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
commands:
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s model_executor
|
||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
|
||||
- label: Benchmarks # 11min
|
||||
timeout_in_minutes: 20
|
||||
@@ -2735,7 +2735,7 @@ steps:
|
||||
- pytest -v -s plugins_tests/test_scheduler_plugins.py
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s distributed/test_distributed_oot.py
|
||||
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
|
||||
|
||||
@@ -3257,7 +3257,7 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Entrypoints Integration Test (API Server 2)
|
||||
@@ -3872,12 +3872,12 @@ steps:
|
||||
- vllm/config/model.py
|
||||
- vllm/model_executor
|
||||
- tests/model_executor
|
||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
commands:
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s model_executor
|
||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
|
||||
- label: Benchmarks # 11min
|
||||
timeout_in_minutes: 20
|
||||
@@ -4508,7 +4508,7 @@ steps:
|
||||
- pytest -v -s plugins_tests/test_scheduler_plugins.py
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s distributed/test_distributed_oot.py
|
||||
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/chat_completion/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/chat_completion/test_oot_registration.py --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/tool_parsers/ --ignore=entrypoints/openai/responses
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
mirror:
|
||||
amd:
|
||||
|
||||
@@ -9,9 +9,9 @@ steps:
|
||||
- vllm/config/model.py
|
||||
- vllm/model_executor
|
||||
- tests/model_executor
|
||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
commands:
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s model_executor
|
||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
|
||||
@@ -36,6 +36,6 @@ steps:
|
||||
- pytest -v -s plugins_tests/test_scheduler_plugins.py
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s distributed/test_distributed_oot.py
|
||||
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai/chat_completion/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
|
||||
|
||||
2
.github/mergify.yml
vendored
2
.github/mergify.yml
vendored
@@ -381,7 +381,7 @@ pull_request_rules:
|
||||
- or:
|
||||
- files~=^vllm/model_executor/model_loader/tensorizer.py
|
||||
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
|
||||
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- files~=^tests/entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
- files~=^tests/model_executor/model_loader/tensorizer_loader/
|
||||
actions:
|
||||
assign:
|
||||
|
||||
@@ -47,6 +47,8 @@ from common import (
|
||||
is_mla_backend,
|
||||
)
|
||||
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
|
||||
def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
|
||||
"""Run standard attention benchmark (Flash/Triton/FlashInfer)."""
|
||||
@@ -462,7 +464,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--batch-specs",
|
||||
nargs="+",
|
||||
default=["q2k", "8q1s1k"],
|
||||
default=None,
|
||||
help="Batch specifications using extended grammar",
|
||||
)
|
||||
|
||||
@@ -478,6 +480,21 @@ def main():
|
||||
parser.add_argument("--repeats", type=int, default=1, help="Repetitions")
|
||||
parser.add_argument("--warmup-iters", type=int, default=3, help="Warmup iterations")
|
||||
parser.add_argument("--profile-memory", action="store_true", help="Profile memory")
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
default="auto",
|
||||
choices=["auto", "fp8"],
|
||||
help="KV cache dtype: auto or fp8",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cuda-graphs",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help=(
|
||||
"Launch kernels with CUDA graphs to eliminate CPU overhead"
|
||||
"in measurements (default: True)"
|
||||
),
|
||||
)
|
||||
|
||||
# Parameter sweep (use YAML config for advanced sweeps)
|
||||
parser.add_argument(
|
||||
@@ -536,21 +553,24 @@ def main():
|
||||
|
||||
# Batch specs and sizes
|
||||
# Support both explicit batch_specs and generated batch_spec_ranges
|
||||
if "batch_spec_ranges" in yaml_config:
|
||||
# Generate batch specs from ranges
|
||||
generated_specs = generate_batch_specs_from_ranges(
|
||||
yaml_config["batch_spec_ranges"]
|
||||
)
|
||||
# Combine with any explicit batch_specs
|
||||
if "batch_specs" in yaml_config:
|
||||
args.batch_specs = yaml_config["batch_specs"] + generated_specs
|
||||
else:
|
||||
args.batch_specs = generated_specs
|
||||
console.print(
|
||||
f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]"
|
||||
)
|
||||
elif "batch_specs" in yaml_config:
|
||||
args.batch_specs = yaml_config["batch_specs"]
|
||||
# CLI --batch-specs takes precedence over YAML when provided.
|
||||
cli_batch_specs_provided = args.batch_specs is not None
|
||||
if not cli_batch_specs_provided:
|
||||
if "batch_spec_ranges" in yaml_config:
|
||||
# Generate batch specs from ranges
|
||||
generated_specs = generate_batch_specs_from_ranges(
|
||||
yaml_config["batch_spec_ranges"]
|
||||
)
|
||||
# Combine with any explicit batch_specs
|
||||
if "batch_specs" in yaml_config:
|
||||
args.batch_specs = yaml_config["batch_specs"] + generated_specs
|
||||
else:
|
||||
args.batch_specs = generated_specs
|
||||
console.print(
|
||||
f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]"
|
||||
)
|
||||
elif "batch_specs" in yaml_config:
|
||||
args.batch_specs = yaml_config["batch_specs"]
|
||||
|
||||
if "batch_sizes" in yaml_config:
|
||||
args.batch_sizes = yaml_config["batch_sizes"]
|
||||
@@ -575,6 +595,10 @@ def main():
|
||||
args.warmup_iters = yaml_config["warmup_iters"]
|
||||
if "profile_memory" in yaml_config:
|
||||
args.profile_memory = yaml_config["profile_memory"]
|
||||
if "kv_cache_dtype" in yaml_config:
|
||||
args.kv_cache_dtype = yaml_config["kv_cache_dtype"]
|
||||
if "cuda_graphs" in yaml_config:
|
||||
args.cuda_graphs = yaml_config["cuda_graphs"]
|
||||
|
||||
# Parameter sweep configuration
|
||||
if "parameter_sweep" in yaml_config:
|
||||
@@ -629,12 +653,18 @@ def main():
|
||||
# Determine backends
|
||||
backends = args.backends or ([args.backend] if args.backend else ["flash"])
|
||||
prefill_backends = getattr(args, "prefill_backends", None)
|
||||
if not args.batch_specs:
|
||||
args.batch_specs = ["q2k", "8q1s1k"]
|
||||
console.print(f"Backends: {', '.join(backends)}")
|
||||
if prefill_backends:
|
||||
console.print(f"Prefill backends: {', '.join(prefill_backends)}")
|
||||
console.print(f"Batch specs: {', '.join(args.batch_specs)}")
|
||||
console.print(f"KV cache dtype: {args.kv_cache_dtype}")
|
||||
console.print(f"CUDA graphs: {args.cuda_graphs}")
|
||||
console.print()
|
||||
|
||||
init_workspace_manager(args.device)
|
||||
|
||||
# Run benchmarks
|
||||
all_results = []
|
||||
|
||||
@@ -687,6 +717,8 @@ def main():
|
||||
repeats=args.repeats,
|
||||
warmup_iters=args.warmup_iters,
|
||||
profile_memory=args.profile_memory,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
use_cuda_graphs=args.cuda_graphs,
|
||||
)
|
||||
|
||||
# Add decode pipeline config
|
||||
@@ -839,6 +871,8 @@ def main():
|
||||
"repeats": args.repeats,
|
||||
"warmup_iters": args.warmup_iters,
|
||||
"profile_memory": args.profile_memory,
|
||||
"kv_cache_dtype": args.kv_cache_dtype,
|
||||
"use_cuda_graphs": args.cuda_graphs,
|
||||
}
|
||||
all_results = run_model_parameter_sweep(
|
||||
backends,
|
||||
@@ -861,6 +895,8 @@ def main():
|
||||
"repeats": args.repeats,
|
||||
"warmup_iters": args.warmup_iters,
|
||||
"profile_memory": args.profile_memory,
|
||||
"kv_cache_dtype": args.kv_cache_dtype,
|
||||
"use_cuda_graphs": args.cuda_graphs,
|
||||
}
|
||||
all_results = run_parameter_sweep(
|
||||
backends, args.batch_specs, base_config_args, args.parameter_sweep, console
|
||||
@@ -891,6 +927,8 @@ def main():
|
||||
repeats=args.repeats,
|
||||
warmup_iters=args.warmup_iters,
|
||||
profile_memory=args.profile_memory,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
use_cuda_graphs=args.cuda_graphs,
|
||||
)
|
||||
|
||||
result = run_benchmark(config)
|
||||
|
||||
@@ -213,6 +213,9 @@ class BenchmarkConfig:
|
||||
profile_memory: bool = False
|
||||
use_cuda_graphs: bool = False
|
||||
|
||||
# "auto" or "fp8"
|
||||
kv_cache_dtype: str = "auto"
|
||||
|
||||
# MLA-specific
|
||||
prefill_backend: str | None = None
|
||||
kv_lora_rank: int | None = None
|
||||
@@ -369,6 +372,7 @@ class ResultsFormatter:
|
||||
"backend",
|
||||
"batch_spec",
|
||||
"num_layers",
|
||||
"kv_cache_dtype",
|
||||
"mean_time",
|
||||
"std_time",
|
||||
"throughput",
|
||||
@@ -382,6 +386,7 @@ class ResultsFormatter:
|
||||
"backend": r.config.backend,
|
||||
"batch_spec": r.config.batch_spec,
|
||||
"num_layers": r.config.num_layers,
|
||||
"kv_cache_dtype": r.config.kv_cache_dtype,
|
||||
"mean_time": r.mean_time,
|
||||
"std_time": r.std_time,
|
||||
"throughput": r.throughput_tokens_per_sec or 0,
|
||||
|
||||
@@ -30,9 +30,9 @@ batch_specs:
|
||||
- "2q16k_32q1s4k" # 2 very large prefill + 32 decode
|
||||
|
||||
# Context extension + decode
|
||||
- "2q1kkv2k_16q1s1k" # 2 extend + 16 decode
|
||||
- "4q2kkv4k_32q1s2k" # 4 extend + 32 decode
|
||||
- "2q1kkv8k_32q1s2k" # 2 large extend + 32 decode
|
||||
- "2q1ks2k_16q1s1k" # 2 extend + 16 decode
|
||||
- "4q2ks4k_32q1s2k" # 4 extend + 32 decode
|
||||
- "2q1ks8k_32q1s2k" # 2 large extend + 32 decode
|
||||
|
||||
# Explicitly chunked prefill
|
||||
- "q8k" # 8k prefill with chunking hint
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
# MLA decode-only benchmark configuration
|
||||
|
||||
model:
|
||||
name: "deepseek-v3"
|
||||
num_layers: 60
|
||||
num_q_heads: 128 # Base value, can be swept for TP simulation
|
||||
num_kv_heads: 1 # MLA uses single latent KV
|
||||
head_dim: 576
|
||||
kv_lora_rank: 512
|
||||
qk_nope_head_dim: 128
|
||||
qk_rope_head_dim: 64
|
||||
v_head_dim: 128
|
||||
block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128
|
||||
|
||||
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
|
||||
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
|
||||
model_parameter_sweep:
|
||||
param_name: "num_q_heads"
|
||||
values: [128, 64, 32, 16]
|
||||
label_format: "{backend}_{value}h"
|
||||
|
||||
batch_specs:
|
||||
# Small batches, varying sequence lengths
|
||||
- "16q1s512" # 16 requests, 512 KV cache
|
||||
- "16q1s1k" # 16 requests, 1k KV cache
|
||||
- "16q1s2k" # 16 requests, 2k KV cache
|
||||
- "16q1s4k" # 16 requests, 4k KV cache
|
||||
|
||||
# Medium batches
|
||||
- "32q1s1k" # 32 requests, 1k KV cache
|
||||
- "32q1s2k" # 32 requests, 2k KV cache
|
||||
- "32q1s4k" # 32 requests, 4k KV cache
|
||||
- "32q1s8k" # 32 requests, 8k KV cache
|
||||
|
||||
# Large batches
|
||||
- "64q1s1k" # 64 requests, 1k KV cache
|
||||
- "64q1s2k" # 64 requests, 2k KV cache
|
||||
- "64q1s4k" # 64 requests, 4k KV cache
|
||||
- "64q1s8k" # 64 requests, 8k KV cache
|
||||
|
||||
# Very large batches
|
||||
- "128q1s1k" # 128 requests, 1k KV cache
|
||||
- "128q1s2k" # 128 requests, 2k KV cache
|
||||
- "128q1s4k" # 128 requests, 4k KV cache
|
||||
- "128q1s8k" # 128 requests, 8k KV cache
|
||||
|
||||
# Long context
|
||||
- "32q1s16k" # 32 requests, 16k KV cache
|
||||
- "32q1s32k" # 32 requests, 32k KV cache
|
||||
|
||||
backends:
|
||||
- FLASHMLA_SPARSE
|
||||
- FLASHINFER_MLA_SPARSE
|
||||
|
||||
device: "cuda:0"
|
||||
repeats: 100
|
||||
warmup_iters: 10
|
||||
profile_memory: true
|
||||
@@ -60,9 +60,11 @@ def create_minimal_vllm_config(
|
||||
model_name: str = "deepseek-v3",
|
||||
block_size: int = 128,
|
||||
max_num_seqs: int = 256,
|
||||
max_num_batched_tokens: int = 8192,
|
||||
mla_dims: dict | None = None,
|
||||
index_topk: int | None = None,
|
||||
prefill_backend: str | None = None,
|
||||
kv_cache_dtype: str = "auto",
|
||||
) -> VllmConfig:
|
||||
"""
|
||||
Create minimal VllmConfig for MLA benchmarks.
|
||||
@@ -149,13 +151,13 @@ def create_minimal_vllm_config(
|
||||
cache_config = CacheConfig(
|
||||
block_size=block_size,
|
||||
gpu_memory_utilization=0.9,
|
||||
cache_dtype="auto",
|
||||
cache_dtype=kv_cache_dtype,
|
||||
enable_prefix_caching=False,
|
||||
)
|
||||
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=8192,
|
||||
max_num_batched_tokens=max(max_num_batched_tokens, max_num_seqs),
|
||||
max_model_len=32768,
|
||||
is_encoder_decoder=False,
|
||||
enable_chunked_prefill=True,
|
||||
@@ -535,6 +537,7 @@ def _create_backend_impl(
|
||||
device: torch.device,
|
||||
max_num_tokens: int = 8192,
|
||||
index_topk: int | None = None,
|
||||
kv_cache_dtype: str = "auto",
|
||||
):
|
||||
"""
|
||||
Create backend implementation instance.
|
||||
@@ -583,7 +586,7 @@ def _create_backend_impl(
|
||||
"num_kv_heads": mla_dims["num_kv_heads"],
|
||||
"alibi_slopes": None,
|
||||
"sliding_window": None,
|
||||
"kv_cache_dtype": "auto",
|
||||
"kv_cache_dtype": kv_cache_dtype,
|
||||
"logits_soft_cap": None,
|
||||
"attn_type": "decoder",
|
||||
"kv_sharing_target_layer_name": None,
|
||||
@@ -701,6 +704,7 @@ def _run_single_benchmark(
|
||||
mla_dims: dict,
|
||||
device: torch.device,
|
||||
indexer=None,
|
||||
kv_cache_dtype: str | None = None,
|
||||
) -> BenchmarkResult:
|
||||
"""
|
||||
Run a single benchmark iteration.
|
||||
@@ -734,49 +738,124 @@ def _run_single_benchmark(
|
||||
)
|
||||
|
||||
# Create KV cache
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
block_size,
|
||||
mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"],
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
if kv_cache_dtype is None:
|
||||
kv_cache_dtype = getattr(config, "kv_cache_dtype", "auto")
|
||||
head_size = mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"]
|
||||
if kv_cache_dtype == "fp8_ds_mla":
|
||||
# FlashMLA sparse custom format: 656 bytes per token, stored as uint8.
|
||||
# Layout: kv_lora_rank fp8 bytes + 4 float32 tile scales
|
||||
# + 2*rope_dim bf16 bytes
|
||||
# = 512 + 16 + 128 = 656 bytes for DeepSeek dims.
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
block_size,
|
||||
656,
|
||||
device=device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
elif kv_cache_dtype == "fp8":
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Create input tensors for both decode and prefill modes
|
||||
decode_inputs, prefill_inputs = _create_input_tensors(
|
||||
total_q,
|
||||
mla_dims,
|
||||
backend_cfg["query_format"],
|
||||
device,
|
||||
torch.bfloat16,
|
||||
)
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
block_size,
|
||||
head_size,
|
||||
device=device,
|
||||
dtype=torch.uint8,
|
||||
).view(current_platform.fp8_dtype())
|
||||
else:
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
block_size,
|
||||
head_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Fill indexer with random indices for sparse backends
|
||||
is_sparse = backend_cfg.get("is_sparse", False)
|
||||
if is_sparse and indexer is not None:
|
||||
indexer.fill_random_indices(total_q, max_kv_len)
|
||||
|
||||
# Determine which forward method to use based on metadata
|
||||
if metadata.decode is not None:
|
||||
forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)
|
||||
elif metadata.prefill is not None:
|
||||
forward_fn = lambda: impl.forward_mha(
|
||||
prefill_inputs["q"],
|
||||
prefill_inputs["k_c_normed"],
|
||||
prefill_inputs["k_pe"],
|
||||
kv_cache,
|
||||
metadata,
|
||||
prefill_inputs["k_scale"],
|
||||
prefill_inputs["output"],
|
||||
)
|
||||
else:
|
||||
# Determine which forward methods to use based on metadata.
|
||||
# Sparse MLA backends always use forward_mqa
|
||||
has_decode = is_sparse or getattr(metadata, "decode", None) is not None
|
||||
has_prefill = not is_sparse and getattr(metadata, "prefill", None) is not None
|
||||
if not has_decode and not has_prefill:
|
||||
raise RuntimeError("Metadata has neither decode nor prefill metadata")
|
||||
|
||||
num_decode = (
|
||||
metadata.num_decode_tokens
|
||||
if (has_decode and has_prefill)
|
||||
else total_q
|
||||
if has_decode
|
||||
else 0
|
||||
)
|
||||
num_prefill = total_q - num_decode
|
||||
|
||||
# Some backends requires fp8 queries when using fp8 KV cache.
|
||||
is_fp8_kvcache = kv_cache_dtype.startswith("fp8")
|
||||
quantize_query = is_fp8_kvcache and getattr(
|
||||
impl, "supports_quant_query_input", False
|
||||
)
|
||||
|
||||
# quantize_query forces concat format
|
||||
query_fmt = "concat" if quantize_query else backend_cfg["query_format"]
|
||||
|
||||
# Create decode query tensors
|
||||
if has_decode:
|
||||
decode_inputs, _ = _create_input_tensors(
|
||||
num_decode, mla_dims, query_fmt, device, torch.bfloat16
|
||||
)
|
||||
# Cast decode query to fp8 if the backend supports it
|
||||
if quantize_query:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if isinstance(decode_inputs, tuple):
|
||||
decode_inputs = torch.cat(list(decode_inputs), dim=-1)
|
||||
decode_inputs = decode_inputs.to(current_platform.fp8_dtype())
|
||||
|
||||
# Create prefill input tensors
|
||||
if has_prefill:
|
||||
_, prefill_inputs = _create_input_tensors(
|
||||
num_prefill, mla_dims, query_fmt, device, torch.bfloat16
|
||||
)
|
||||
|
||||
# Build forward function
|
||||
def forward_fn():
|
||||
results = []
|
||||
if has_decode:
|
||||
results.append(impl.forward_mqa(decode_inputs, kv_cache, metadata, layer))
|
||||
if has_prefill:
|
||||
results.append(
|
||||
impl.forward_mha(
|
||||
prefill_inputs["q"],
|
||||
prefill_inputs["k_c_normed"],
|
||||
prefill_inputs["k_pe"],
|
||||
kv_cache,
|
||||
metadata,
|
||||
prefill_inputs["k_scale"],
|
||||
prefill_inputs["output"],
|
||||
)
|
||||
)
|
||||
return results[0] if len(results) == 1 else tuple(results)
|
||||
|
||||
# Warmup
|
||||
for _ in range(config.warmup_iters):
|
||||
forward_fn()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
# Optionally capture a CUDA graph after warmup.
|
||||
# Graph replay eliminates CPU launch overhead so timings reflect pure
|
||||
# kernel time.
|
||||
if config.use_cuda_graphs:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
forward_fn()
|
||||
benchmark_fn = graph.replay
|
||||
else:
|
||||
benchmark_fn = forward_fn
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(config.repeats):
|
||||
@@ -785,7 +864,7 @@ def _run_single_benchmark(
|
||||
|
||||
start.record()
|
||||
for _ in range(config.num_layers):
|
||||
forward_fn()
|
||||
benchmark_fn()
|
||||
end.record()
|
||||
|
||||
torch.accelerator.synchronize()
|
||||
@@ -852,13 +931,30 @@ def _run_mla_benchmark_batched(
|
||||
# Determine if this is a sparse backend
|
||||
is_sparse = backend_cfg.get("is_sparse", False)
|
||||
|
||||
# Extract kv_cache_dtype from the first config
|
||||
kv_cache_dtype = getattr(first_config, "kv_cache_dtype", "auto")
|
||||
|
||||
# FlashMLA sparse only supports "fp8_ds_mla" internally (not generic "fp8").
|
||||
# Remap here so the user can pass --kv-cache-dtype fp8 regardless of backend.
|
||||
if backend.upper() == "FLASHMLA_SPARSE" and kv_cache_dtype == "fp8":
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
|
||||
# Compute max total_q across all configs so the metadata builder buffer
|
||||
# and scheduler config are large enough for all batch specs.
|
||||
max_total_q = max(
|
||||
sum(r.q_len for r in parse_batch_spec(cfg.batch_spec))
|
||||
for cfg, *_ in configs_with_params
|
||||
)
|
||||
|
||||
# Create and set vLLM config for MLA (reused across all benchmarks)
|
||||
vllm_config = create_minimal_vllm_config(
|
||||
model_name="deepseek-v3", # Used only for model path
|
||||
block_size=block_size,
|
||||
max_num_batched_tokens=max_total_q,
|
||||
mla_dims=mla_dims, # Use custom dims from config or default
|
||||
index_topk=index_topk if is_sparse else None,
|
||||
prefill_backend=prefill_backend,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
results = []
|
||||
@@ -883,7 +979,9 @@ def _run_mla_benchmark_batched(
|
||||
mla_dims,
|
||||
vllm_config,
|
||||
device,
|
||||
max_num_tokens=max_total_q,
|
||||
index_topk=index_topk if is_sparse else None,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
# Verify the actual prefill backend matches what was requested
|
||||
@@ -942,6 +1040,7 @@ def _run_mla_benchmark_batched(
|
||||
mla_dims,
|
||||
device,
|
||||
indexer=indexer,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ def _create_vllm_config(
|
||||
|
||||
cache_config = CacheConfig(
|
||||
block_size=config.block_size,
|
||||
cache_dtype="auto",
|
||||
cache_dtype=config.kv_cache_dtype,
|
||||
)
|
||||
cache_config.num_gpu_blocks = max_num_blocks
|
||||
cache_config.num_cpu_blocks = 0
|
||||
@@ -215,7 +215,7 @@ def _create_backend_impl(
|
||||
num_kv_heads=config.num_kv_heads,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype="auto",
|
||||
kv_cache_dtype=config.kv_cache_dtype,
|
||||
)
|
||||
|
||||
kv_cache_spec = FullAttentionSpec(
|
||||
@@ -288,12 +288,22 @@ def _create_input_tensors(
|
||||
total_q: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
quantize_query: bool = False,
|
||||
) -> tuple:
|
||||
"""Create Q, K, V input tensors for all layers."""
|
||||
"""Create Q, K, V input tensors for all layers.
|
||||
|
||||
When quantize_query is True, queries are cast to fp8 to match backends
|
||||
that require query/key/value dtype consistency.
|
||||
"""
|
||||
q_dtype = dtype
|
||||
if quantize_query:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
q_dtype = current_platform.fp8_dtype()
|
||||
q_list = [
|
||||
torch.randn(
|
||||
total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype
|
||||
)
|
||||
).to(q_dtype)
|
||||
for _ in range(config.num_layers)
|
||||
]
|
||||
k_list = [
|
||||
@@ -344,10 +354,17 @@ def _create_kv_cache(
|
||||
# Compute inverse permutation to get back to logical view
|
||||
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
|
||||
|
||||
# Use fp8 dtype for cache when requested.
|
||||
cache_dtype = dtype
|
||||
if config.kv_cache_dtype == "fp8":
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
cache_dtype = current_platform.fp8_dtype()
|
||||
|
||||
cache_list = []
|
||||
for _ in range(config.num_layers):
|
||||
# Allocate in physical layout order (contiguous in memory)
|
||||
cache = torch.zeros(*physical_shape, device=device, dtype=dtype)
|
||||
cache = torch.zeros(*physical_shape, device=device, dtype=cache_dtype)
|
||||
# Permute to logical view
|
||||
cache = cache.permute(*inv_order)
|
||||
cache_list.append(cache)
|
||||
@@ -392,6 +409,37 @@ def _run_single_benchmark(
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
# Optionally capture a CUDA graph after warmup.
|
||||
# Graph replay eliminates CPU launch overhead so timings reflect pure
|
||||
# kernel time.
|
||||
if config.use_cuda_graphs:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for i in range(config.num_layers):
|
||||
impl.forward(
|
||||
layer,
|
||||
q_list[i],
|
||||
k_list[i],
|
||||
v_list[i],
|
||||
cache_list[i],
|
||||
attn_metadata,
|
||||
output=out,
|
||||
)
|
||||
benchmark_fn = graph.replay
|
||||
else:
|
||||
|
||||
def benchmark_fn():
|
||||
for i in range(config.num_layers):
|
||||
impl.forward(
|
||||
layer,
|
||||
q_list[i],
|
||||
k_list[i],
|
||||
v_list[i],
|
||||
cache_list[i],
|
||||
attn_metadata,
|
||||
output=out,
|
||||
)
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(config.repeats):
|
||||
@@ -399,16 +447,7 @@ def _run_single_benchmark(
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start.record()
|
||||
for i in range(config.num_layers):
|
||||
impl.forward(
|
||||
layer,
|
||||
q_list[i],
|
||||
k_list[i],
|
||||
v_list[i],
|
||||
cache_list[i],
|
||||
attn_metadata,
|
||||
output=out,
|
||||
)
|
||||
benchmark_fn()
|
||||
end.record()
|
||||
|
||||
torch.accelerator.synchronize()
|
||||
@@ -502,8 +541,12 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
|
||||
common_attn_metadata=common_metadata,
|
||||
)
|
||||
|
||||
# Only quantize queries when the impl supports it
|
||||
quantize_query = config.kv_cache_dtype.startswith("fp8") and getattr(
|
||||
impl, "supports_quant_query_input", False
|
||||
)
|
||||
q_list, k_list, v_list = _create_input_tensors(
|
||||
config, total_q, device, dtype
|
||||
config, total_q, device, dtype, quantize_query=quantize_query
|
||||
)
|
||||
|
||||
cache_list = _create_kv_cache(
|
||||
|
||||
12
csrc/ops.h
12
csrc/ops.h
@@ -295,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
|
||||
|
||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout);
|
||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||
torch::Tensor const& input, torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout);
|
||||
|
||||
void scaled_fp4_quant_out(torch::Tensor const& input,
|
||||
torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout, torch::Tensor& output,
|
||||
torch::Tensor& output_scale);
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
@@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
torch::Tensor& output_sf, torch::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout) {
|
||||
void scaled_fp4_quant_out(torch::Tensor const& input,
|
||||
torch::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout, torch::Tensor& output,
|
||||
torch::Tensor& output_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
||||
@@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||
torch::Tensor const& input, torch::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout) {
|
||||
int64_t n = input.size(-1);
|
||||
int64_t m = input.numel() / n;
|
||||
auto device = input.device();
|
||||
|
||||
// Two fp4 values packed into a uint8
|
||||
auto output = torch::empty(
|
||||
{m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
|
||||
|
||||
torch::Tensor output_sf;
|
||||
if (is_sf_swizzled_layout) {
|
||||
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
|
||||
output_sf = torch::empty(
|
||||
{sf_m, sf_n},
|
||||
torch::TensorOptions().device(device).dtype(torch::kInt32));
|
||||
} else {
|
||||
output_sf = torch::empty(
|
||||
{m, n / CVT_FP4_SF_VEC_SIZE},
|
||||
torch::TensorOptions().device(device).dtype(torch::kUInt8));
|
||||
}
|
||||
|
||||
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
|
||||
output_sf);
|
||||
return {output, output_sf};
|
||||
}
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <utility>
|
||||
|
||||
#include "../../cuda_vec_utils.cuh"
|
||||
|
||||
@@ -54,6 +55,18 @@ inline int computeEffectiveRows(int m) {
|
||||
return round_up(m, ROW_TILE);
|
||||
}
|
||||
|
||||
// Compute the shape of the swizzled SF output tensor.
|
||||
// Returns (rounded_m, rounded_n / 4) where:
|
||||
// rounded_m = round_up(m, 128)
|
||||
// rounded_n = round_up(n / CVT_FP4_SF_VEC_SIZE, 4)
|
||||
inline std::pair<int64_t, int64_t> computeSwizzledSFShape(int64_t m,
|
||||
int64_t n) {
|
||||
int64_t rounded_m = round_up(m, static_cast<int64_t>(128));
|
||||
int64_t scale_n = n / CVT_FP4_SF_VEC_SIZE;
|
||||
int64_t rounded_n = round_up(scale_n, static_cast<int64_t>(4));
|
||||
return {rounded_m, rounded_n / 4};
|
||||
}
|
||||
|
||||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
|
||||
uint32_t val;
|
||||
|
||||
@@ -286,6 +286,15 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
"Outer scale stride must be 1 when scales are not transposed");
|
||||
}
|
||||
|
||||
int64_t hidden_size = input.size(-1);
|
||||
TORCH_CHECK(hidden_size > 0 && hidden_size % group_size == 0,
|
||||
"hidden_size must be a positive multiple of group_size");
|
||||
int64_t num_tokens = input.numel() / hidden_size;
|
||||
int64_t num_groups = hidden_size / group_size;
|
||||
TORCH_CHECK(scales.numel() >= num_tokens * num_groups,
|
||||
"scales buffer too small: need ", num_tokens * num_groups,
|
||||
" elements, got ", scales.numel());
|
||||
|
||||
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
|
||||
var_epsilon, scale_ub, residual,
|
||||
is_scale_transposed);
|
||||
|
||||
@@ -564,10 +564,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
||||
" Tensor! output_scale, Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout) -> ()");
|
||||
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
||||
"scaled_fp4_quant(Tensor input,"
|
||||
" Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
|
||||
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
|
||||
|
||||
// Out variant
|
||||
// TODO: Add {at::Tag::out_variant} tag and update all call sites
|
||||
// to use the functional variant once vLLM upgrades PyTorch.
|
||||
// See pytorch/pytorch#176117.
|
||||
ops.def(
|
||||
"scaled_fp4_quant.out(Tensor input,"
|
||||
" Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
|
||||
"-> ()");
|
||||
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
|
||||
|
||||
// Compute NVFP4 experts quantization.
|
||||
ops.def(
|
||||
|
||||
@@ -107,6 +107,27 @@ vLLM supports the `tool_choice='none'` option in the chat completion API. When t
|
||||
!!! note
|
||||
When tools are specified in the request, vLLM includes tool definitions in the prompt by default, regardless of the `tool_choice` setting. To exclude tool definitions when `tool_choice='none'`, use the `--exclude-tools-when-tool-choice-none` option.
|
||||
|
||||
## Constrained Decoding Behavior
|
||||
|
||||
Whether vLLM enforces the tool parameter schema during generation depends on the `tool_choice` mode:
|
||||
|
||||
| `tool_choice` value | Schema-constrained decoding | Behavior |
|
||||
| --- | --- | --- |
|
||||
| Named function | Yes (via structured outputs backend) | Arguments are guaranteed to be valid JSON conforming to the function's parameter schema. |
|
||||
| `"required"` | Yes (via structured outputs backend) | Same as named function. The model must produce at least one tool call. |
|
||||
| `"auto"` | No | The model generates freely. A tool-call parser extracts tool calls from the raw text. Arguments may be malformed or not match the schema. |
|
||||
| `"none"` | N/A | No tool calls are produced. |
|
||||
|
||||
When schema conformance matters, prefer `tool_choice="required"` or named function calling over `"auto"`.
|
||||
|
||||
### Strict Mode (`strict` parameter)
|
||||
|
||||
The [OpenAI API](https://platform.openai.com/docs/guides/function-calling#strict-mode) supports a `strict` field on function definitions. When set to `true`, OpenAI uses constrained decoding to guarantee that tool-call arguments match the function schema, even in `tool_choice="auto"` mode.
|
||||
|
||||
vLLM **does not implement** `strict` mode today. The `strict` field is accepted in requests (to avoid breaking clients that set it), but it has no effect on decoding behavior. In auto mode, argument validity depends entirely on the model's output quality and the parser's extraction logic.
|
||||
|
||||
Tracking issues: [#15526](https://github.com/vllm-project/vllm/issues/15526), [#16313](https://github.com/vllm-project/vllm/issues/16313).
|
||||
|
||||
## Automatic Function Calling
|
||||
|
||||
To enable this feature, you should set the following flags:
|
||||
@@ -124,6 +145,9 @@ from HuggingFace; and you can find an example of this in a `tokenizer_config.jso
|
||||
|
||||
If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!
|
||||
|
||||
!!! note
|
||||
With `tool_choice="auto"`, tool-call arguments are extracted from the model's raw text output by the selected parser. No schema-level constraint is applied during decoding, so arguments may occasionally be malformed or violate the function's parameter schema. See [Constrained Decoding Behavior](#constrained-decoding-behavior) for details.
|
||||
|
||||
### Hermes Models (`hermes`)
|
||||
|
||||
All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported.
|
||||
|
||||
@@ -72,6 +72,9 @@ In addition, we have the following custom APIs:
|
||||
- Only applicable to [classification models](../models/pooling_models.md).
|
||||
- [Score API](#score-api) (`/score`)
|
||||
- Applicable to [embedding models and cross-encoder models](../models/pooling_models.md).
|
||||
- [Cohere Embed API](#cohere-embed-api) (`/v2/embed`)
|
||||
- Compatible with [Cohere's Embed API](https://docs.cohere.com/reference/embed)
|
||||
- Works with any [embedding model](../models/pooling_models.md), including multimodal models.
|
||||
- [Re-rank API](#re-rank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
|
||||
- Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
|
||||
- Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
|
||||
@@ -429,6 +432,137 @@ these extra parameters are supported instead:
|
||||
--8<-- "vllm/entrypoints/pooling/base/protocol.py:embed-extra-params"
|
||||
```
|
||||
|
||||
### Cohere Embed API
|
||||
|
||||
Our API is also compatible with [Cohere's Embed v2 API](https://docs.cohere.com/reference/embed) which adds support for some modern embedding feature such as truncation, output dimensions, embedding types, and input types. This endpoint works with any embedding model (including multimodal models).
|
||||
|
||||
#### Cohere Embed API request parameters
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `model` | string | Yes | Model name |
|
||||
| `input_type` | string | No | Prompt prefix key (model-dependent, see below) |
|
||||
| `texts` | list[string] | No | Text inputs (use one of `texts`, `images`, or `inputs`) |
|
||||
| `images` | list[string] | No | Base64 data URI images |
|
||||
| `inputs` | list[object] | No | Mixed text and image content objects |
|
||||
| `embedding_types` | list[string] | No | Output types (default: `["float"]`) |
|
||||
| `output_dimension` | int | No | Truncate embeddings to this dimension (Matryoshka) |
|
||||
| `truncate` | string | No | `END`, `START`, or `NONE` (default: `END`) |
|
||||
|
||||
#### Text embedding
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/v2/embed" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "Snowflake/snowflake-arctic-embed-m-v1.5",
|
||||
"input_type": "query",
|
||||
"texts": ["Hello world", "How are you?"],
|
||||
"embedding_types": ["float"]
|
||||
}'
|
||||
```
|
||||
|
||||
??? console "Response"
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "embd-...",
|
||||
"embeddings": {
|
||||
"float": [
|
||||
[0.012, -0.034, ...],
|
||||
[0.056, 0.078, ...]
|
||||
]
|
||||
},
|
||||
"texts": ["Hello world", "How are you?"],
|
||||
"meta": {
|
||||
"api_version": {"version": "2"},
|
||||
"billed_units": {"input_tokens": 12}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Mixed text and image inputs
|
||||
|
||||
For multimodal models, you can embed images by passing base64 data URIs. The `inputs` field accepts a list of objects with mixed text and image content:
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/v2/embed" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "google/siglip-so400m-patch14-384",
|
||||
"inputs": [
|
||||
{
|
||||
"content": [
|
||||
{"type": "text", "text": "A photo of a cat"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}
|
||||
]
|
||||
}
|
||||
],
|
||||
"embedding_types": ["float"]
|
||||
}'
|
||||
```
|
||||
|
||||
#### Embedding types
|
||||
|
||||
The `embedding_types` parameter controls the output format. Multiple types can be requested in a single call:
|
||||
|
||||
| Type | Description |
|
||||
| ---- | ----------- |
|
||||
| `float` | Raw float32 embeddings (default) |
|
||||
| `binary` | Bit-packed signed binary |
|
||||
| `ubinary` | Bit-packed unsigned binary |
|
||||
| `base64` | Little-endian float32 encoded as base64 |
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/v2/embed" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "Snowflake/snowflake-arctic-embed-m-v1.5",
|
||||
"input_type": "query",
|
||||
"texts": ["What is machine learning?"],
|
||||
"embedding_types": ["float", "binary"]
|
||||
}'
|
||||
```
|
||||
|
||||
??? console "Response"
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "embd-...",
|
||||
"embeddings": {
|
||||
"float": [[0.012, -0.034, ...]],
|
||||
"binary": [[42, -117, ...]]
|
||||
},
|
||||
"texts": ["What is machine learning?"],
|
||||
"meta": {
|
||||
"api_version": {"version": "2"},
|
||||
"billed_units": {"input_tokens": 8}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Truncation
|
||||
|
||||
The `truncate` parameter controls how inputs exceeding the model's maximum sequence length are handled:
|
||||
|
||||
| Value | Behavior |
|
||||
| ----- | --------- |
|
||||
| `END` (default) | Keep the first tokens, drop the end |
|
||||
| `START` | Keep the last tokens, drop the beginning |
|
||||
| `NONE` | Return an error if the input is too long |
|
||||
|
||||
#### Input type and prompt prefixes
|
||||
|
||||
The `input_type` field selects a prompt prefix to prepend to each text input. The available values
|
||||
depend on the model:
|
||||
|
||||
- **Models with `task_instructions` in `config.json`**: The keys from the `task_instructions` dict are
|
||||
the valid `input_type` values and the corresponding value is prepended to each text.
|
||||
- **Models with `config_sentence_transformers.json` prompts**: The keys from the `prompts` dict are
|
||||
the valid `input_type` values. For example, `Snowflake/snowflake-arctic-embed-xs` defines `"query"`,
|
||||
so setting `input_type: "query"` prepends `"Represent this sentence for searching relevant passages: "`.
|
||||
- **Other models**: `input_type` is not accepted and will raise a validation error if passed.
|
||||
|
||||
### Transcriptions API
|
||||
|
||||
Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
|
||||
|
||||
@@ -50,7 +50,7 @@ av==16.1.0
|
||||
blobfile==3.0.0
|
||||
# Multi-Modal Models Test
|
||||
decord==0.6.0
|
||||
# video processing, required by entrypoints/openai/test_video.py
|
||||
# video processing, required by entrypoints/openai/chat_completion/test_video.py
|
||||
rapidfuzz==3.12.1
|
||||
|
||||
# OpenAI compatibility and testing
|
||||
|
||||
@@ -179,7 +179,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
torch.ops._C.scaled_fp4_quant.default,
|
||||
torch.ops._C.scaled_fp4_quant.out,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server
|
||||
from tests.entrypoints.openai.chat_completion.test_oot_registration import (
|
||||
run_and_test_dummy_opt_api_server,
|
||||
)
|
||||
|
||||
|
||||
def test_distributed_oot(dummy_opt_path: str):
|
||||
|
||||
@@ -4,12 +4,11 @@ import weakref
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.chat_completion.test_vision import TEST_IMAGE_ASSETS
|
||||
from vllm import LLM
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
from ..openai.test_vision import TEST_IMAGE_ASSETS
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def text_llm():
|
||||
|
||||
@@ -6,13 +6,12 @@ import logging
|
||||
import pytest
|
||||
import regex as re
|
||||
|
||||
from tests.entrypoints.openai.chat_completion.test_vision import TEST_IMAGE_ASSETS
|
||||
from vllm import LLM
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.v1.metrics import loggers as stat_loggers
|
||||
from vllm.v1.metrics.reader import Counter, Metric
|
||||
|
||||
from ..openai.test_vision import TEST_IMAGE_ASSETS
|
||||
|
||||
|
||||
def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]:
|
||||
return [
|
||||
|
||||
@@ -7,11 +7,10 @@ import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.multimodal.utils import encode_audio_base64, encode_audio_url, fetch_audio
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
|
||||
TEST_AUDIO_URLS = [
|
||||
AudioAsset("winning_call").url,
|
||||
@@ -8,8 +8,8 @@ import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ...conftest import VideoTestAssets
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from tests.conftest import VideoTestAssets
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-Omni-3B"
|
||||
|
||||
@@ -8,8 +8,8 @@ import pytest
|
||||
import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from ...conftest import AudioTestAssets
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from tests.conftest import AudioTestAssets
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
# NOTE - the tests in this module are currently analogous to test_chat, but are
|
||||
# separated to avoid OOM killing due to module-scoped servers, since we
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from ...utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from tests.utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
|
||||
assert chatml_jinja_path.exists()
|
||||
@@ -8,7 +8,7 @@ from typing import Any, NamedTuple
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
# # any model with a chat template should work here
|
||||
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
|
||||
@@ -7,11 +7,10 @@ import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.multimodal.utils import encode_video_url, fetch_video
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
|
||||
MAXIMUM_VIDEOS = 3
|
||||
|
||||
@@ -8,12 +8,11 @@ import pytest
|
||||
import pytest_asyncio
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from tests.utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
|
||||
from vllm.multimodal.media import MediaWithBytes
|
||||
from vllm.multimodal.utils import encode_image_url, fetch_image
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "microsoft/Phi-3.5-vision-instruct"
|
||||
MAXIMUM_IMAGES = 2
|
||||
|
||||
@@ -8,10 +8,9 @@ import pytest
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
|
||||
0
tests/entrypoints/openai/completion/__init__.py
Normal file
0
tests/entrypoints/openai/completion/__init__.py
Normal file
@@ -14,7 +14,7 @@ import torch
|
||||
from openai import BadRequestError
|
||||
from transformers import AutoConfig
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
@@ -11,11 +11,10 @@ import pytest
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.renderers.embed_utils import safe_load_prompt_embeds
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt():
|
||||
@@ -9,6 +9,7 @@ import pytest
|
||||
import pytest_asyncio
|
||||
import torch.cuda
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig,
|
||||
@@ -17,8 +18,6 @@ from vllm.model_executor.model_loader.tensorizer import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "unsloth/llama-3.2-1b-Instruct"
|
||||
LORA_PATH = "davzoku/finqa_adapter_1b"
|
||||
|
||||
@@ -6,11 +6,10 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
MODEL_PATH = os.path.join(tempfile.gettempdir(), "qwen3_06b")
|
||||
|
||||
@@ -5,7 +5,7 @@ import json
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
|
||||
310
tests/entrypoints/pooling/embed/test_cohere_online.py
Normal file
310
tests/entrypoints/pooling/embed/test_cohere_online.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the Cohere /v2/embed API with generic (non-Cohere) models.
|
||||
|
||||
Validates that the Cohere v2 embed endpoint works correctly with standard
|
||||
embedding models, covering text embedding, embedding type conversions,
|
||||
response structure, batching, normalisation, and semantic similarity.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
DTYPE = "bfloat16"
|
||||
|
||||
MODELS: list[tuple[str, list[str]]] = [
|
||||
("intfloat/multilingual-e5-small", []),
|
||||
(
|
||||
"Snowflake/snowflake-arctic-embed-m-v1.5",
|
||||
[
|
||||
"--trust_remote_code",
|
||||
"--hf_overrides",
|
||||
'{"matryoshka_dimensions":[256]}',
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=MODELS, ids=lambda m: m[0])
|
||||
def model_config(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def model_name(model_config):
|
||||
return model_config[0]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(model_config):
|
||||
name, extra_args = model_config
|
||||
args = [
|
||||
"--runner",
|
||||
"pooling",
|
||||
"--dtype",
|
||||
DTYPE,
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"512",
|
||||
"--gpu-memory-utilization",
|
||||
"0.02",
|
||||
] + extra_args
|
||||
with RemoteOpenAIServer(name, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
def _cohere_embed(
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str,
|
||||
texts: list[str] | None = None,
|
||||
images: list[str] | None = None,
|
||||
input_type: str | None = None,
|
||||
embedding_types: list[str] | None = None,
|
||||
) -> dict:
|
||||
body: dict = {"model": model_name}
|
||||
if input_type is not None:
|
||||
body["input_type"] = input_type
|
||||
if texts is not None:
|
||||
body["texts"] = texts
|
||||
if images is not None:
|
||||
body["images"] = images
|
||||
if embedding_types is not None:
|
||||
body["embedding_types"] = embedding_types
|
||||
resp = requests.post(server.url_for("/v2/embed"), json=body)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _openai_embed(
|
||||
server: RemoteOpenAIServer, model_name: str, texts: list[str]
|
||||
) -> dict:
|
||||
body = {"model": model_name, "input": texts, "encoding_format": "float"}
|
||||
resp = requests.post(server.url_for("/v1/embeddings"), json=body)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _cosine_sim(a: list[float], b: list[float]) -> float:
|
||||
va, vb = np.array(a), np.array(b)
|
||||
return float(np.dot(va, vb) / (np.linalg.norm(va) * np.linalg.norm(vb)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Text embedding tests
|
||||
# -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_basic_embed(server: RemoteOpenAIServer, model_name: str):
|
||||
r = _cohere_embed(
|
||||
server, model_name, texts=["hello world"], embedding_types=["float"]
|
||||
)
|
||||
assert "embeddings" in r
|
||||
assert len(r["embeddings"]["float"]) == 1
|
||||
assert len(r["embeddings"]["float"][0]) > 0
|
||||
|
||||
|
||||
def test_unsupported_input_type_rejected(server: RemoteOpenAIServer, model_name: str):
|
||||
"""An input_type not defined in the model's prompt config should be
|
||||
rejected with a 400 error."""
|
||||
body = {
|
||||
"model": model_name,
|
||||
"input_type": "nonexistent_type",
|
||||
"texts": ["hello world"],
|
||||
"embedding_types": ["float"],
|
||||
}
|
||||
resp = requests.post(server.url_for("/v2/embed"), json=body)
|
||||
assert resp.status_code == 400
|
||||
assert "Unsupported input_type" in resp.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_omitted_input_type_accepted(server: RemoteOpenAIServer, model_name: str):
|
||||
"""Omitting input_type should always work (no prompt prefix applied)."""
|
||||
body = {
|
||||
"model": model_name,
|
||||
"texts": ["hello world"],
|
||||
"embedding_types": ["float"],
|
||||
}
|
||||
resp = requests.post(server.url_for("/v2/embed"), json=body)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["embeddings"]["float"]) == 1
|
||||
|
||||
|
||||
def test_v1_v2_parity(server: RemoteOpenAIServer, model_name: str):
|
||||
"""v1 (OpenAI) and v2 (Cohere) endpoints should produce the same
|
||||
float embeddings for a generic model."""
|
||||
texts = ["hello world"]
|
||||
v2 = _cohere_embed(server, model_name, texts=texts, embedding_types=["float"])
|
||||
v1 = _openai_embed(server, model_name, texts)
|
||||
cos = _cosine_sim(v2["embeddings"]["float"][0], v1["data"][0]["embedding"])
|
||||
assert cos > 0.9999, f"v1/v2 parity failed, cosine={cos}"
|
||||
|
||||
|
||||
def test_embedding_types(server: RemoteOpenAIServer, model_name: str):
|
||||
r = _cohere_embed(
|
||||
server,
|
||||
model_name,
|
||||
texts=["test"],
|
||||
embedding_types=["float", "binary", "ubinary"],
|
||||
)
|
||||
dim = len(r["embeddings"]["float"][0])
|
||||
assert len(r["embeddings"]["binary"][0]) == dim // 8
|
||||
assert len(r["embeddings"]["ubinary"][0]) == dim // 8
|
||||
|
||||
|
||||
def test_response_structure(server: RemoteOpenAIServer, model_name: str):
|
||||
r = _cohere_embed(server, model_name, texts=["test"], embedding_types=["float"])
|
||||
assert "id" in r
|
||||
assert "embeddings" in r
|
||||
assert "texts" in r
|
||||
assert r["texts"] == ["test"]
|
||||
assert "meta" in r
|
||||
assert r["meta"]["api_version"]["version"] == "2"
|
||||
assert "billed_units" in r["meta"]
|
||||
assert r["meta"]["billed_units"]["input_tokens"] > 0
|
||||
assert r["meta"]["billed_units"]["image_tokens"] == 0
|
||||
|
||||
|
||||
def test_batch(server: RemoteOpenAIServer, model_name: str):
|
||||
texts = ["apple", "banana", "cherry"]
|
||||
r = _cohere_embed(server, model_name, texts=texts, embedding_types=["float"])
|
||||
assert len(r["embeddings"]["float"]) == 3
|
||||
dim = len(r["embeddings"]["float"][0])
|
||||
for emb in r["embeddings"]["float"]:
|
||||
assert len(emb) == dim
|
||||
|
||||
|
||||
def test_l2_normalized(server: RemoteOpenAIServer, model_name: str):
|
||||
r = _cohere_embed(
|
||||
server, model_name, texts=["hello world"], embedding_types=["float"]
|
||||
)
|
||||
emb = np.array(r["embeddings"]["float"][0])
|
||||
assert abs(float(np.linalg.norm(emb)) - 1.0) < 0.01
|
||||
|
||||
|
||||
def test_semantic_similarity(server: RemoteOpenAIServer, model_name: str):
|
||||
r = _cohere_embed(
|
||||
server,
|
||||
model_name,
|
||||
texts=["machine learning", "deep learning", "chocolate cake recipe"],
|
||||
embedding_types=["float"],
|
||||
)
|
||||
embs = r["embeddings"]["float"]
|
||||
cos_related = _cosine_sim(embs[0], embs[1])
|
||||
cos_unrelated = _cosine_sim(embs[0], embs[2])
|
||||
assert cos_related > cos_unrelated
|
||||
|
||||
|
||||
def test_missing_input_returns_error(server: RemoteOpenAIServer, model_name: str):
|
||||
body = {"model": model_name}
|
||||
resp = requests.post(server.url_for("/v2/embed"), json=body)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_base64_embedding_type(server: RemoteOpenAIServer, model_name: str):
|
||||
r = _cohere_embed(
|
||||
server,
|
||||
model_name,
|
||||
texts=["test encoding"],
|
||||
embedding_types=["float", "base64"],
|
||||
)
|
||||
float_emb = r["embeddings"]["float"][0]
|
||||
b64_str = r["embeddings"]["base64"][0]
|
||||
decoded = struct.unpack(f"<{len(float_emb)}f", base64.b64decode(b64_str))
|
||||
np.testing.assert_allclose(float_emb, decoded, rtol=1e-5)
|
||||
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Truncation tests
|
||||
# -----------------------------------------------------------
|
||||
|
||||
|
||||
def _cohere_embed_raw(
|
||||
server: RemoteOpenAIServer,
|
||||
body: dict,
|
||||
) -> requests.Response:
|
||||
return requests.post(server.url_for("/v2/embed"), json=body)
|
||||
|
||||
|
||||
def test_truncate_end_succeeds(server: RemoteOpenAIServer, model_name: str):
|
||||
"""truncate=END should silently truncate long input."""
|
||||
long_text = " ".join(["word"] * 2000)
|
||||
body = {
|
||||
"model": model_name,
|
||||
"texts": [long_text],
|
||||
"embedding_types": ["float"],
|
||||
"truncate": "END",
|
||||
}
|
||||
resp = _cohere_embed_raw(server, body)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["embeddings"]["float"]) == 1
|
||||
|
||||
|
||||
def test_truncate_start_succeeds(server: RemoteOpenAIServer, model_name: str):
|
||||
"""truncate=START should silently truncate long input from the start."""
|
||||
long_text = " ".join(["word"] * 2000)
|
||||
body = {
|
||||
"model": model_name,
|
||||
"texts": [long_text],
|
||||
"embedding_types": ["float"],
|
||||
"truncate": "START",
|
||||
}
|
||||
resp = _cohere_embed_raw(server, body)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["embeddings"]["float"]) == 1
|
||||
|
||||
|
||||
def test_truncate_none_rejects_long_input(server: RemoteOpenAIServer, model_name: str):
|
||||
"""truncate=NONE should error when input exceeds model context."""
|
||||
long_text = " ".join(["word"] * 2000)
|
||||
body = {
|
||||
"model": model_name,
|
||||
"texts": [long_text],
|
||||
"embedding_types": ["float"],
|
||||
"truncate": "NONE",
|
||||
}
|
||||
resp = _cohere_embed_raw(server, body)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_truncate_start_vs_end_differ(server: RemoteOpenAIServer, model_name: str):
|
||||
"""START and END truncation should produce different embeddings
|
||||
when the input is long enough to actually be truncated.
|
||||
|
||||
We construct input with distinct tokens at the start vs end
|
||||
so that keeping different halves produces different embeddings.
|
||||
"""
|
||||
start_words = " ".join([f"alpha{i}" for i in range(300)])
|
||||
end_words = " ".join([f"omega{i}" for i in range(300)])
|
||||
long_text = start_words + " " + end_words
|
||||
|
||||
body_end = {
|
||||
"model": model_name,
|
||||
"texts": [long_text],
|
||||
"embedding_types": ["float"],
|
||||
"truncate": "END",
|
||||
}
|
||||
body_start = {
|
||||
"model": model_name,
|
||||
"texts": [long_text],
|
||||
"embedding_types": ["float"],
|
||||
"truncate": "START",
|
||||
}
|
||||
r_end = _cohere_embed_raw(server, body_end).json()
|
||||
r_start = _cohere_embed_raw(server, body_start).json()
|
||||
|
||||
emb_end = r_end["embeddings"]["float"][0]
|
||||
emb_start = r_start["embeddings"]["float"][0]
|
||||
cos = _cosine_sim(emb_end, emb_start)
|
||||
assert cos < 0.99, (
|
||||
f"START and END truncation should produce different embeddings "
|
||||
f"for long input, but cosine similarity was {cos}"
|
||||
)
|
||||
135
tests/entrypoints/pooling/embed/test_cohere_online_vision.py
Normal file
135
tests/entrypoints/pooling/embed/test_cohere_online_vision.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the Cohere /v2/embed API with a multimodal model (SigLIP).
|
||||
|
||||
Validates image embedding, batching, normalisation, and embedding type
|
||||
conversions through the /v2/embed endpoint.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import struct
|
||||
import zlib
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "google/siglip-so400m-patch14-384"
|
||||
DTYPE = "bfloat16"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--runner",
|
||||
"pooling",
|
||||
"--dtype",
|
||||
DTYPE,
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"64",
|
||||
"--gpu-memory-utilization",
|
||||
"0.3",
|
||||
]
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
def _make_tiny_png(r: int, g: int, b: int, w: int = 2, h: int = 2) -> str:
|
||||
raw = b""
|
||||
for _ in range(h):
|
||||
raw += b"\x00" + bytes([r, g, b]) * w
|
||||
compressed = zlib.compress(raw)
|
||||
|
||||
def chunk(ctype: bytes, cdata: bytes) -> bytes:
|
||||
c = ctype + cdata
|
||||
return (
|
||||
struct.pack(">I", len(cdata))
|
||||
+ c
|
||||
+ struct.pack(">I", zlib.crc32(c) & 0xFFFFFFFF)
|
||||
)
|
||||
|
||||
ihdr = struct.pack(">IIBBBBB", w, h, 8, 2, 0, 0, 0)
|
||||
png = (
|
||||
b"\x89PNG\r\n\x1a\n"
|
||||
+ chunk(b"IHDR", ihdr)
|
||||
+ chunk(b"IDAT", compressed)
|
||||
+ chunk(b"IEND", b"")
|
||||
)
|
||||
return "data:image/png;base64," + base64.b64encode(png).decode()
|
||||
|
||||
|
||||
def _cohere_embed(
|
||||
server: RemoteOpenAIServer,
|
||||
texts: list[str] | None = None,
|
||||
images: list[str] | None = None,
|
||||
embedding_types: list[str] | None = None,
|
||||
) -> dict:
|
||||
body: dict = {"model": MODEL_NAME}
|
||||
if texts is not None:
|
||||
body["texts"] = texts
|
||||
if images is not None:
|
||||
body["images"] = images
|
||||
if embedding_types is not None:
|
||||
body["embedding_types"] = embedding_types
|
||||
resp = requests.post(server.url_for("/v2/embed"), json=body)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def test_image_embed(server: RemoteOpenAIServer):
|
||||
img_uri = _make_tiny_png(255, 0, 0)
|
||||
r = _cohere_embed(
|
||||
server,
|
||||
images=[img_uri],
|
||||
embedding_types=["float"],
|
||||
)
|
||||
assert "embeddings" in r
|
||||
assert len(r["embeddings"]["float"]) == 1
|
||||
assert len(r["embeddings"]["float"][0]) > 0
|
||||
assert r["meta"]["billed_units"]["image_tokens"] > 0
|
||||
assert r["meta"]["billed_units"]["input_tokens"] == 0
|
||||
|
||||
|
||||
def test_image_batch(server: RemoteOpenAIServer):
|
||||
red = _make_tiny_png(255, 0, 0)
|
||||
blue = _make_tiny_png(0, 0, 255)
|
||||
r = _cohere_embed(
|
||||
server,
|
||||
images=[red, blue],
|
||||
embedding_types=["float"],
|
||||
)
|
||||
assert len(r["embeddings"]["float"]) == 2
|
||||
|
||||
|
||||
def test_image_l2_normalized(server: RemoteOpenAIServer):
|
||||
img_uri = _make_tiny_png(0, 255, 0)
|
||||
r = _cohere_embed(
|
||||
server,
|
||||
images=[img_uri],
|
||||
embedding_types=["float"],
|
||||
)
|
||||
emb = np.array(r["embeddings"]["float"][0])
|
||||
assert abs(float(np.linalg.norm(emb)) - 1.0) < 0.01
|
||||
|
||||
|
||||
def test_image_embedding_types(server: RemoteOpenAIServer):
|
||||
img_uri = _make_tiny_png(128, 128, 128)
|
||||
r = _cohere_embed(
|
||||
server,
|
||||
images=[img_uri],
|
||||
embedding_types=["float", "binary", "ubinary"],
|
||||
)
|
||||
dim = len(r["embeddings"]["float"][0])
|
||||
assert len(r["embeddings"]["binary"][0]) == dim // 8
|
||||
assert len(r["embeddings"]["ubinary"][0]) == dim // 8
|
||||
|
||||
|
||||
def test_text_embed_on_multimodal(server: RemoteOpenAIServer):
|
||||
"""SigLIP also supports text-only embedding via /v2/embed."""
|
||||
r = _cohere_embed(server, texts=["hello world"], embedding_types=["float"])
|
||||
assert "embeddings" in r
|
||||
assert len(r["embeddings"]["float"]) == 1
|
||||
assert len(r["embeddings"]["float"][0]) > 0
|
||||
102
tests/entrypoints/pooling/embed/test_cohere_openai_parity.py
Normal file
102
tests/entrypoints/pooling/embed/test_cohere_openai_parity.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Parity test between Cohere /v2/embed and OpenAI /v1/embeddings.
|
||||
|
||||
Verifies that both endpoints produce identical float embeddings when
|
||||
no prompt prefix is applied (input_type omitted for Cohere /v2/embed).
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "BAAI/bge-base-en-v1.5"
|
||||
DTYPE = "bfloat16"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--runner",
|
||||
"pooling",
|
||||
"--dtype",
|
||||
DTYPE,
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"512",
|
||||
"--gpu-memory-utilization",
|
||||
"0.02",
|
||||
]
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
def _cohere_embed(
|
||||
server: RemoteOpenAIServer,
|
||||
texts: list[str],
|
||||
) -> list[list[float]]:
|
||||
body = {
|
||||
"model": MODEL_NAME,
|
||||
"texts": texts,
|
||||
"embedding_types": ["float"],
|
||||
}
|
||||
resp = requests.post(server.url_for("/v2/embed"), json=body)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["embeddings"]["float"]
|
||||
|
||||
|
||||
def _openai_embed(
|
||||
server: RemoteOpenAIServer,
|
||||
texts: list[str],
|
||||
) -> list[list[float]]:
|
||||
body = {"model": MODEL_NAME, "input": texts, "encoding_format": "float"}
|
||||
resp = requests.post(server.url_for("/v1/embeddings"), json=body)
|
||||
resp.raise_for_status()
|
||||
return [item["embedding"] for item in resp.json()["data"]]
|
||||
|
||||
|
||||
def test_single_text_parity(server: RemoteOpenAIServer):
|
||||
"""A single text should produce identical embeddings via both APIs."""
|
||||
texts = ["the quick brown fox jumps over the lazy dog"]
|
||||
v2 = _cohere_embed(server, texts)
|
||||
v1 = _openai_embed(server, texts)
|
||||
np.testing.assert_allclose(v2[0], v1[0], rtol=1e-5)
|
||||
|
||||
|
||||
def test_batch_parity(server: RemoteOpenAIServer):
|
||||
"""A batch of texts should produce identical embeddings via both APIs,
|
||||
in the same order."""
|
||||
texts = [
|
||||
"machine learning",
|
||||
"deep learning",
|
||||
"natural language processing",
|
||||
]
|
||||
v2 = _cohere_embed(server, texts)
|
||||
v1 = _openai_embed(server, texts)
|
||||
assert len(v2) == len(v1) == 3
|
||||
for i in range(3):
|
||||
np.testing.assert_allclose(v2[i], v1[i], rtol=1e-5, err_msg=f"index {i}")
|
||||
|
||||
|
||||
def test_token_count_parity(server: RemoteOpenAIServer):
|
||||
"""Both APIs should report the same prompt token count."""
|
||||
texts = ["hello world"]
|
||||
v2_resp = requests.post(
|
||||
server.url_for("/v2/embed"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"texts": texts,
|
||||
"embedding_types": ["float"],
|
||||
},
|
||||
)
|
||||
v1_resp = requests.post(
|
||||
server.url_for("/v1/embeddings"),
|
||||
json={"model": MODEL_NAME, "input": texts, "encoding_format": "float"},
|
||||
)
|
||||
v2_resp.raise_for_status()
|
||||
v1_resp.raise_for_status()
|
||||
v2_tokens = v2_resp.json()["meta"]["billed_units"]["input_tokens"]
|
||||
v1_tokens = v1_resp.json()["usage"]["prompt_tokens"]
|
||||
assert v2_tokens == v1_tokens
|
||||
208
tests/entrypoints/pooling/embed/test_io_processor.py
Normal file
208
tests/entrypoints/pooling/embed/test_io_processor.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for EmbedIOProcessor."""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
CohereEmbedRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestResolveTruncation:
|
||||
"""Unit tests for EmbedIOProcessor._resolve_cohere_truncation."""
|
||||
|
||||
@staticmethod
|
||||
def _make_request(**kwargs) -> CohereEmbedRequest:
|
||||
defaults = {
|
||||
"model": "test",
|
||||
"input_type": "search_document",
|
||||
"texts": ["hello"],
|
||||
}
|
||||
return CohereEmbedRequest(**(defaults | kwargs))
|
||||
|
||||
def test_truncate_end_default(self):
|
||||
req = self._make_request()
|
||||
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
||||
assert tokens == -1
|
||||
assert side is None
|
||||
|
||||
def test_truncate_end_explicit(self):
|
||||
req = self._make_request(truncate="END")
|
||||
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
||||
assert tokens == -1
|
||||
assert side is None
|
||||
|
||||
def test_truncate_end_with_max_tokens(self):
|
||||
req = self._make_request(truncate="END", max_tokens=128)
|
||||
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
||||
assert tokens == 128
|
||||
assert side is None
|
||||
|
||||
def test_truncate_none(self):
|
||||
req = self._make_request(truncate="NONE")
|
||||
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
||||
assert tokens is None
|
||||
assert side is None
|
||||
|
||||
def test_truncate_none_with_max_tokens(self):
|
||||
"""truncate=NONE should NOT set truncate_prompt_tokens; the
|
||||
max_tokens limit is enforced separately via _check_max_tokens."""
|
||||
req = self._make_request(truncate="NONE", max_tokens=10)
|
||||
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
||||
assert tokens is None
|
||||
assert side is None
|
||||
|
||||
def test_truncate_start(self):
|
||||
req = self._make_request(truncate="START")
|
||||
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
||||
assert tokens == -1
|
||||
assert side == "left"
|
||||
|
||||
def test_truncate_start_with_max_tokens(self):
|
||||
req = self._make_request(truncate="START", max_tokens=64)
|
||||
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
||||
assert tokens == 64
|
||||
assert side == "left"
|
||||
|
||||
|
||||
class TestApplyStPrompt:
|
||||
"""Unit tests for EmbedIOProcessor._apply_task_instruction."""
|
||||
|
||||
@staticmethod
|
||||
def _make_handler(task_instructions: dict[str, str] | None):
|
||||
handler = object.__new__(EmbedIOProcessor)
|
||||
handler.task_instructions = task_instructions
|
||||
return handler
|
||||
|
||||
def test_no_prompts_configured(self):
|
||||
handler = self._make_handler(None)
|
||||
texts = ["hello", "world"]
|
||||
assert handler._apply_task_instruction(texts, "query") is texts
|
||||
|
||||
def test_matching_input_type(self):
|
||||
handler = self._make_handler({"query": "search_query: "})
|
||||
result = handler._apply_task_instruction(["hello"], "query")
|
||||
assert result == ["search_query: hello"]
|
||||
|
||||
def test_non_matching_input_type(self):
|
||||
handler = self._make_handler({"query": "search_query: "})
|
||||
texts = ["hello"]
|
||||
assert handler._apply_task_instruction(texts, "document") is texts
|
||||
|
||||
def test_multiple_texts(self):
|
||||
handler = self._make_handler(
|
||||
{"query": "Represent this sentence for searching: "}
|
||||
)
|
||||
result = handler._apply_task_instruction(["a", "b", "c"], "query")
|
||||
assert result == [
|
||||
"Represent this sentence for searching: a",
|
||||
"Represent this sentence for searching: b",
|
||||
"Represent this sentence for searching: c",
|
||||
]
|
||||
|
||||
def test_empty_prefix_returns_unchanged(self):
|
||||
handler = self._make_handler({"passage": ""})
|
||||
texts = ["hello"]
|
||||
assert handler._apply_task_instruction(texts, "passage") is texts
|
||||
|
||||
|
||||
class TestLoadTaskInstructions:
|
||||
"""Unit tests for EmbedIOProcessor._load_task_instructions."""
|
||||
|
||||
def test_no_attribute(self):
|
||||
class FakeConfig:
|
||||
pass
|
||||
|
||||
assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None
|
||||
|
||||
def test_with_task_instructions(self):
|
||||
class FakeConfig:
|
||||
task_instructions = {
|
||||
"retrieval.query": "Represent the query: ",
|
||||
"retrieval.passage": "",
|
||||
}
|
||||
|
||||
result = EmbedIOProcessor._load_task_instructions(FakeConfig())
|
||||
assert result == {
|
||||
"retrieval.query": "Represent the query: ",
|
||||
"retrieval.passage": "",
|
||||
}
|
||||
|
||||
def test_empty_dict(self):
|
||||
class FakeConfig:
|
||||
task_instructions = {}
|
||||
|
||||
assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None
|
||||
|
||||
def test_non_dict(self):
|
||||
class FakeConfig:
|
||||
task_instructions = "not a dict"
|
||||
|
||||
assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None
|
||||
|
||||
|
||||
class TestCheckMaxTokens:
|
||||
"""Unit tests for EmbedIOProcessor._check_cohere_max_tokens."""
|
||||
|
||||
@staticmethod
|
||||
def _fake_output(n_tokens: int):
|
||||
class _Out:
|
||||
def __init__(self, n: int):
|
||||
self.prompt_token_ids = list(range(n))
|
||||
|
||||
return _Out(n_tokens)
|
||||
|
||||
def test_none_check_is_noop(self):
|
||||
outs = [self._fake_output(100)]
|
||||
EmbedIOProcessor._check_cohere_max_tokens(outs, None)
|
||||
|
||||
def test_within_limit(self):
|
||||
outs = [self._fake_output(5), self._fake_output(3)]
|
||||
EmbedIOProcessor._check_cohere_max_tokens(outs, 5)
|
||||
|
||||
def test_exceeds_limit(self):
|
||||
outs = [self._fake_output(3), self._fake_output(10)]
|
||||
with pytest.raises(ValueError, match="exceeds max_tokens=5"):
|
||||
EmbedIOProcessor._check_cohere_max_tokens(outs, 5)
|
||||
|
||||
def test_exact_limit(self):
|
||||
outs = [self._fake_output(5)]
|
||||
EmbedIOProcessor._check_cohere_max_tokens(outs, 5)
|
||||
|
||||
|
||||
class TestValidateInputType:
|
||||
"""Unit tests for EmbedIOProcessor._validate_input_type."""
|
||||
|
||||
@staticmethod
|
||||
def _make_handler(task_instructions: dict[str, str] | None):
|
||||
handler = object.__new__(EmbedIOProcessor)
|
||||
handler.task_instructions = task_instructions
|
||||
return handler
|
||||
|
||||
def test_none_input_type_always_accepted(self):
|
||||
handler = self._make_handler(None)
|
||||
handler._validate_input_type(None)
|
||||
handler_with = self._make_handler({"query": "q: "})
|
||||
handler_with._validate_input_type(None)
|
||||
|
||||
def test_no_prompts_rejects(self):
|
||||
handler = self._make_handler(None)
|
||||
with pytest.raises(ValueError, match="does not define any input_type"):
|
||||
handler._validate_input_type("anything")
|
||||
|
||||
def test_known_type_accepted(self):
|
||||
handler = self._make_handler({"query": "q: ", "document": "d: "})
|
||||
handler._validate_input_type("query")
|
||||
handler._validate_input_type("document")
|
||||
|
||||
def test_unknown_type_rejected(self):
|
||||
handler = self._make_handler({"query": "q: ", "document": "d: "})
|
||||
with pytest.raises(ValueError, match="Unsupported input_type 'other'"):
|
||||
handler._validate_input_type("other")
|
||||
|
||||
def test_error_lists_supported(self):
|
||||
handler = self._make_handler({"a": "", "b": ""})
|
||||
with pytest.raises(ValueError, match="Supported values: a, b"):
|
||||
handler._validate_input_type("z")
|
||||
129
tests/entrypoints/pooling/embed/test_protocol.py
Normal file
129
tests/entrypoints/pooling/embed/test_protocol.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for Cohere embed protocol: build_typed_embeddings and its
|
||||
underlying packing helpers, plus Cohere-specific serving helpers."""
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
build_typed_embeddings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embeddings() -> list[list[float]]:
|
||||
return [
|
||||
[0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8],
|
||||
[-0.05, 0.15, -0.25, 0.35, -0.45, 0.55, -0.65, 0.75],
|
||||
]
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsFloat:
|
||||
def test_float_passthrough(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(sample_embeddings, ["float"])
|
||||
assert result.float == sample_embeddings
|
||||
assert result.binary is None
|
||||
|
||||
def test_empty_input(self):
|
||||
result = build_typed_embeddings([], ["float"])
|
||||
assert result.float == []
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsBinary:
|
||||
def test_binary_packing(self):
|
||||
# 8 values: positive->1, negative->0 => bits: 10101010 = 0xAA = 170
|
||||
# signed: 170 - 128 = 42
|
||||
embs = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]]
|
||||
result = build_typed_embeddings(embs, ["binary"])
|
||||
assert result.binary is not None
|
||||
assert result.binary[0] == [42]
|
||||
|
||||
def test_ubinary_packing(self):
|
||||
embs = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]]
|
||||
result = build_typed_embeddings(embs, ["ubinary"])
|
||||
assert result.ubinary is not None
|
||||
assert result.ubinary[0] == [170] # 0b10101010
|
||||
|
||||
def test_binary_all_positive(self):
|
||||
embs = [[0.1] * 8]
|
||||
result = build_typed_embeddings(embs, ["binary"])
|
||||
assert result.binary is not None
|
||||
# all bits = 1 => 0xFF = 255, signed: 255 - 128 = 127
|
||||
assert result.binary[0] == [127]
|
||||
|
||||
def test_binary_all_negative(self):
|
||||
embs = [[-0.1] * 8]
|
||||
result = build_typed_embeddings(embs, ["binary"])
|
||||
assert result.binary is not None
|
||||
# all bits = 0, signed: 0 - 128 = -128
|
||||
assert result.binary[0] == [-128]
|
||||
|
||||
def test_binary_dimension_is_eighth(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(sample_embeddings, ["binary"])
|
||||
assert result.binary is not None
|
||||
for orig, packed in zip(sample_embeddings, result.binary):
|
||||
assert len(packed) == len(orig) // 8
|
||||
|
||||
def test_zero_treated_as_positive(self):
|
||||
embs = [[0.0] * 8]
|
||||
result = build_typed_embeddings(embs, ["binary"])
|
||||
assert result.binary is not None
|
||||
# 0.0 >= 0 is True, so bit=1 for all => 127 (signed)
|
||||
assert result.binary[0] == [127]
|
||||
|
||||
def test_non_multiple_of_8_raises(self):
|
||||
embs = [[0.1] * 7]
|
||||
with pytest.raises(ValueError, match="multiple of 8"):
|
||||
build_typed_embeddings(embs, ["binary"])
|
||||
|
||||
def test_ubinary_non_multiple_of_8_raises(self):
|
||||
embs = [[0.1] * 10]
|
||||
with pytest.raises(ValueError, match="multiple of 8"):
|
||||
build_typed_embeddings(embs, ["ubinary"])
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsBase64:
|
||||
def test_base64_roundtrip(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(sample_embeddings, ["base64"])
|
||||
assert result.base64 is not None
|
||||
assert len(result.base64) == 2
|
||||
|
||||
for orig, b64_str in zip(sample_embeddings, result.base64):
|
||||
decoded = base64.b64decode(b64_str)
|
||||
n = len(orig)
|
||||
values = struct.unpack(f"<{n}f", decoded)
|
||||
np.testing.assert_allclose(orig, values, rtol=1e-5)
|
||||
|
||||
def test_base64_byte_length(self):
|
||||
embs = [[0.1, 0.2, 0.3]]
|
||||
result = build_typed_embeddings(embs, ["base64"])
|
||||
assert result.base64 is not None
|
||||
raw = base64.b64decode(result.base64[0])
|
||||
assert len(raw) == 3 * 4 # 3 floats * 4 bytes each
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsMultiple:
|
||||
def test_all_types_at_once(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(
|
||||
sample_embeddings,
|
||||
["float", "binary", "ubinary", "base64"],
|
||||
)
|
||||
assert result.float is not None
|
||||
assert result.binary is not None
|
||||
assert result.ubinary is not None
|
||||
assert result.base64 is not None
|
||||
|
||||
def test_subset_types(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(sample_embeddings, ["float", "binary"])
|
||||
assert result.float is not None
|
||||
assert result.binary is not None
|
||||
assert result.ubinary is None
|
||||
assert result.base64 is None
|
||||
|
||||
def test_unknown_type_ignored(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(sample_embeddings, ["float", "unknown_type"])
|
||||
assert result.float is not None
|
||||
434
tests/kernels/attention/test_trtllm_kvfp8_dequant.py
Normal file
434
tests/kernels/attention/test_trtllm_kvfp8_dequant.py
Normal file
@@ -0,0 +1,434 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Standalone unit tests for trtllm_prefill_attn_kvfp8_dequant.
|
||||
|
||||
Tests both contiguous and non-contiguous (cross-layer unified) KV cache
|
||||
layouts against a pure-PyTorch reference implementation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
NUM_BLOCKS = 128
|
||||
|
||||
|
||||
def to_float8(x, dtype=None):
|
||||
if dtype is None:
|
||||
dtype = FP8_DTYPE
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax * 0.1
|
||||
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
def make_contiguous_kv_cache(num_blocks, num_kv_heads, block_size, head_size):
|
||||
"""Create a standard contiguous fp8 KV cache (HND layout)."""
|
||||
raw = torch.randn(
|
||||
num_blocks,
|
||||
2,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
kv_cache, scale = to_float8(raw)
|
||||
return kv_cache, scale
|
||||
|
||||
|
||||
def make_cross_layer_kv_cache(
|
||||
num_blocks,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
num_layers=4,
|
||||
):
|
||||
"""
|
||||
Create a non-contiguous per-layer view mimicking cross-layer allocation.
|
||||
|
||||
Physical layout: (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
|
||||
Returned view: (num_blocks, 2, num_kv_heads, block_size, head_size)
|
||||
with non-contiguous strides on dims 0, 1, 2 (they skip over num_layers).
|
||||
"""
|
||||
raw = torch.randn(
|
||||
num_blocks,
|
||||
2,
|
||||
num_kv_heads,
|
||||
num_layers,
|
||||
block_size,
|
||||
head_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
fp8_full, scale = to_float8(raw)
|
||||
layer_view = fp8_full[:, :, :, 0, :, :]
|
||||
assert not layer_view.is_contiguous(), (
|
||||
f"Expected non-contiguous view, got strides {layer_view.stride()}"
|
||||
)
|
||||
return layer_view, scale
|
||||
|
||||
|
||||
def ref_dequant(kv_cache, block_tables, k_scale, v_scale, dequant_dtype):
|
||||
"""Pure PyTorch reference: gather pages and dequantize fp8 -> dequant_dtype."""
|
||||
batch_size, num_pages_per_seq = block_tables.shape
|
||||
s = kv_cache.shape
|
||||
out = torch.zeros(
|
||||
batch_size * num_pages_per_seq + 1,
|
||||
s[1],
|
||||
s[2],
|
||||
s[3],
|
||||
s[4],
|
||||
dtype=dequant_dtype,
|
||||
device=kv_cache.device,
|
||||
)
|
||||
for b in range(batch_size):
|
||||
for p in range(num_pages_per_seq):
|
||||
page_idx = block_tables[b, p].item()
|
||||
if page_idx <= 0:
|
||||
continue
|
||||
mock_idx = b * num_pages_per_seq + p + 1
|
||||
out[mock_idx, 0] = (kv_cache[page_idx, 0].float() * k_scale.item()).to(
|
||||
dequant_dtype
|
||||
)
|
||||
out[mock_idx, 1] = (kv_cache[page_idx, 1].float() * v_scale.item()).to(
|
||||
dequant_dtype
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_kv_heads", [1, 8])
|
||||
@pytest.mark.parametrize("head_size", [64, 128])
|
||||
@pytest.mark.parametrize("block_size", [16, 32])
|
||||
@pytest.mark.parametrize("batch_size", [1, 4])
|
||||
@pytest.mark.parametrize("num_pages_per_seq", [3, 8])
|
||||
@pytest.mark.parametrize("contiguous", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_trtllm_kvfp8_dequant(
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
batch_size: int,
|
||||
num_pages_per_seq: int,
|
||||
contiguous: bool,
|
||||
):
|
||||
from vllm.v1.attention.backends.flashinfer import (
|
||||
trtllm_prefill_attn_kvfp8_dequant,
|
||||
)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
if contiguous:
|
||||
kv_cache, scale = make_contiguous_kv_cache(
|
||||
NUM_BLOCKS,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
)
|
||||
else:
|
||||
kv_cache, scale = make_cross_layer_kv_cache(
|
||||
NUM_BLOCKS,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
)
|
||||
|
||||
k_scale = scale.clone()
|
||||
v_scale = scale.clone()
|
||||
|
||||
block_tables = torch.randint(
|
||||
1,
|
||||
NUM_BLOCKS,
|
||||
(batch_size, num_pages_per_seq),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache,
|
||||
block_tables,
|
||||
k_scale,
|
||||
v_scale,
|
||||
torch.bfloat16,
|
||||
)
|
||||
|
||||
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
|
||||
|
||||
expected_bt = torch.arange(
|
||||
1,
|
||||
batch_size * num_pages_per_seq + 1,
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
).reshape(batch_size, num_pages_per_seq)
|
||||
torch.testing.assert_close(mock_block_table, expected_bt)
|
||||
|
||||
# Page 0 is padding (never written), compare only pages 1+
|
||||
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_block_tables_with_zero_pages():
|
||||
"""Pages with index <= 0 must be skipped (early return in kernel)."""
|
||||
from vllm.v1.attention.backends.flashinfer import (
|
||||
trtllm_prefill_attn_kvfp8_dequant,
|
||||
)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
num_kv_heads, block_size, head_size = 8, 16, 64
|
||||
|
||||
kv_cache, scale = make_contiguous_kv_cache(
|
||||
NUM_BLOCKS,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
)
|
||||
k_scale = v_scale = scale.clone()
|
||||
|
||||
# Mix of valid pages and zeros (padding)
|
||||
block_tables = torch.tensor(
|
||||
[[5, 0, 10], [0, 0, 0], [3, 7, 0]],
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache,
|
||||
block_tables,
|
||||
k_scale,
|
||||
v_scale,
|
||||
torch.bfloat16,
|
||||
)
|
||||
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
|
||||
|
||||
# Only compare pages that were actually written (non-zero page indices)
|
||||
for b in range(block_tables.shape[0]):
|
||||
for p in range(block_tables.shape[1]):
|
||||
if block_tables[b, p].item() > 0:
|
||||
idx = b * block_tables.shape[1] + p + 1
|
||||
torch.testing.assert_close(
|
||||
mock_kv_cache[idx],
|
||||
ref[idx],
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_all_zero_block_tables():
|
||||
"""All-zero block_tables: kernel should write nothing."""
|
||||
from vllm.v1.attention.backends.flashinfer import (
|
||||
trtllm_prefill_attn_kvfp8_dequant,
|
||||
)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
num_kv_heads, block_size, head_size = 4, 16, 64
|
||||
|
||||
kv_cache, scale = make_contiguous_kv_cache(
|
||||
NUM_BLOCKS,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
)
|
||||
k_scale = v_scale = scale.clone()
|
||||
|
||||
block_tables = torch.zeros(2, 4, dtype=torch.int32, device="cuda")
|
||||
|
||||
# Should not crash even though no pages are valid
|
||||
mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache,
|
||||
block_tables,
|
||||
k_scale,
|
||||
v_scale,
|
||||
torch.bfloat16,
|
||||
)
|
||||
assert mock_kv_cache.shape[0] == 2 * 4 + 1
|
||||
assert mock_block_table.shape == (2, 4)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_different_k_v_scales():
|
||||
"""Verify K and V are dequantized with independent scales."""
|
||||
from vllm.v1.attention.backends.flashinfer import (
|
||||
trtllm_prefill_attn_kvfp8_dequant,
|
||||
)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
num_kv_heads, block_size, head_size = 8, 16, 64
|
||||
|
||||
kv_cache, _ = make_contiguous_kv_cache(
|
||||
NUM_BLOCKS,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
)
|
||||
k_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
v_scale = torch.tensor([2.0], dtype=torch.float32, device="cuda")
|
||||
|
||||
block_tables = torch.tensor([[1, 2]], dtype=torch.int32, device="cuda")
|
||||
|
||||
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache,
|
||||
block_tables,
|
||||
k_scale,
|
||||
v_scale,
|
||||
torch.bfloat16,
|
||||
)
|
||||
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_single_page_per_seq():
|
||||
"""Minimum grid dim 1 = 1 page per sequence."""
|
||||
from vllm.v1.attention.backends.flashinfer import (
|
||||
trtllm_prefill_attn_kvfp8_dequant,
|
||||
)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
num_kv_heads, block_size, head_size = 8, 16, 128
|
||||
|
||||
kv_cache, scale = make_contiguous_kv_cache(
|
||||
NUM_BLOCKS,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
)
|
||||
k_scale = v_scale = scale.clone()
|
||||
|
||||
block_tables = torch.tensor([[5], [10], [20]], dtype=torch.int32, device="cuda")
|
||||
|
||||
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache,
|
||||
block_tables,
|
||||
k_scale,
|
||||
v_scale,
|
||||
torch.bfloat16,
|
||||
)
|
||||
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_large_page_indices():
|
||||
"""Page indices near the top of the buffer stress offset arithmetic."""
|
||||
from vllm.v1.attention.backends.flashinfer import (
|
||||
trtllm_prefill_attn_kvfp8_dequant,
|
||||
)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
num_kv_heads, block_size, head_size = 8, 16, 128
|
||||
large_num_blocks = 32768
|
||||
|
||||
kv_cache, scale = make_contiguous_kv_cache(
|
||||
large_num_blocks,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
)
|
||||
k_scale = v_scale = scale.clone()
|
||||
|
||||
# Use page indices near the top of the buffer
|
||||
block_tables = torch.tensor(
|
||||
[[large_num_blocks - 1, large_num_blocks - 2, 1]],
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache,
|
||||
block_tables,
|
||||
k_scale,
|
||||
v_scale,
|
||||
torch.bfloat16,
|
||||
)
|
||||
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_large_block_size():
|
||||
"""block_size=64 -> HEAD_STRIDE=8192, large tl.arange per thread block."""
|
||||
from vllm.v1.attention.backends.flashinfer import (
|
||||
trtllm_prefill_attn_kvfp8_dequant,
|
||||
)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
num_kv_heads, block_size, head_size = 4, 64, 128
|
||||
|
||||
kv_cache, scale = make_contiguous_kv_cache(
|
||||
NUM_BLOCKS,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
)
|
||||
k_scale = v_scale = scale.clone()
|
||||
|
||||
block_tables = torch.randint(
|
||||
1,
|
||||
NUM_BLOCKS,
|
||||
(2, 4),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache,
|
||||
block_tables,
|
||||
k_scale,
|
||||
v_scale,
|
||||
torch.bfloat16,
|
||||
)
|
||||
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_cross_layer_many_layers():
|
||||
"""
|
||||
Non-contiguous with 36 layers -- matches real gpt-oss-120b.
|
||||
Strides are far from contiguous (factor of 36 in the gaps).
|
||||
"""
|
||||
from vllm.v1.attention.backends.flashinfer import (
|
||||
trtllm_prefill_attn_kvfp8_dequant,
|
||||
)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
num_kv_heads, block_size, head_size = 8, 16, 64
|
||||
num_layers = 36
|
||||
|
||||
kv_cache, scale = make_cross_layer_kv_cache(
|
||||
NUM_BLOCKS,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
head_size,
|
||||
num_layers=num_layers,
|
||||
)
|
||||
k_scale = v_scale = scale.clone()
|
||||
|
||||
block_tables = torch.randint(
|
||||
1,
|
||||
NUM_BLOCKS,
|
||||
(4, 6),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache,
|
||||
block_tables,
|
||||
k_scale,
|
||||
v_scale,
|
||||
torch.bfloat16,
|
||||
)
|
||||
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
|
||||
@@ -280,21 +280,22 @@ def test_rms_norm(
|
||||
assert torch.allclose(ref_residual, ops_residual)
|
||||
|
||||
output = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||
scales = torch.empty(
|
||||
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
if group_size is None:
|
||||
scales = torch.empty(
|
||||
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
||||
)
|
||||
opcheck(
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant,
|
||||
(output, x, layer.weight, scales, 1e-5, scale_ub, residual),
|
||||
)
|
||||
else:
|
||||
# TODO(luka/eliza) opcheck is broken?
|
||||
# Somehow the cloned args are getting mutated in-place,
|
||||
# which causes the opcheck to fail.
|
||||
# https://github.com/vllm-project/vllm/issues/36688
|
||||
return
|
||||
assert hidden_size % group_size[1] == 0
|
||||
num_groups = hidden_size // group_size[1]
|
||||
scales = torch.empty(
|
||||
(num_groups, num_tokens),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
).transpose(0, 1)
|
||||
opcheck(
|
||||
torch.ops._C.rms_norm_per_block_quant,
|
||||
(
|
||||
|
||||
@@ -159,6 +159,52 @@ def test_quantize_to_fp4(
|
||||
torch.testing.assert_close(scale_ans, scale_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[(32, 4096), (128, 4096), (1, 64), (127, 1024), (256, 16384)],
|
||||
)
|
||||
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_python_util_matches_cpp_allocation(
|
||||
shape: tuple[int, int],
|
||||
is_sf_swizzled_layout: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that the Python utility (create_fp4_output_tensors) allocates
|
||||
tensors with the same shapes and dtypes as the C++ functional variant
|
||||
(scaled_fp4_quant_func).
|
||||
"""
|
||||
from vllm._custom_ops import create_fp4_output_tensors
|
||||
|
||||
torch.set_default_device("cuda:0")
|
||||
m, n = shape
|
||||
input_tensor = torch.randn((m, n), dtype=torch.bfloat16)
|
||||
input_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda:0")
|
||||
|
||||
# C++ functional variant allocates internally
|
||||
cpp_out, cpp_scale = torch.ops._C.scaled_fp4_quant(
|
||||
input_tensor, input_scale, is_sf_swizzled_layout
|
||||
)
|
||||
|
||||
# Python utility
|
||||
py_out, py_scale = create_fp4_output_tensors(
|
||||
m, n, torch.device("cuda:0"), is_sf_swizzled_layout
|
||||
)
|
||||
|
||||
assert py_out.shape == cpp_out.shape, (
|
||||
f"Output shape mismatch: Python {py_out.shape} vs C++ {cpp_out.shape}"
|
||||
)
|
||||
assert py_out.dtype == cpp_out.dtype, (
|
||||
f"Output dtype mismatch: Python {py_out.dtype} vs C++ {cpp_out.dtype}"
|
||||
)
|
||||
assert py_scale.shape == cpp_scale.shape, (
|
||||
f"Scale shape mismatch: Python {py_scale.shape} vs C++ {cpp_scale.shape}"
|
||||
)
|
||||
assert py_scale.dtype == cpp_scale.dtype, (
|
||||
f"Scale dtype mismatch: Python {py_scale.dtype} vs C++ {cpp_scale.dtype}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
|
||||
@torch.inference_mode()
|
||||
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
|
||||
|
||||
104
tests/models/quantization/test_mxfp8.py
Normal file
104
tests/models/quantization/test_mxfp8.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""E2E tests for online MXFP8 quantization.
|
||||
|
||||
Loads a BF16 model with ``--quantization mxfp8`` (online quantization) and
|
||||
compares log-probabilities against the same model served in BF16 without
|
||||
quantization. This exercises the full pipeline: config parsing,
|
||||
``Mxfp8OnlineLinearMethod``, ``Mxfp8OnlineMoEMethod``, weight loading,
|
||||
online quantization / shuffling, and inference through ``apply_monolithic``.
|
||||
|
||||
Layer skipping (``modules_to_not_convert``) is configured in the model's
|
||||
``config.json`` under ``quantization_config`` and is not tested here.
|
||||
|
||||
``example_prompts`` is a pytest fixture (from conftest.py) that loads 8
|
||||
diverse prompts from ``tests/prompts/example.txt``.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
from ..utils import check_logprobs_close
|
||||
|
||||
# A small MoE model that fits on a single GPU and has both linear + MoE layers.
|
||||
MOE_MODEL = "Qwen/Qwen3-30B-A3B"
|
||||
# A small dense model (no MoE) to validate the linear-only path.
|
||||
DENSE_MODEL = "Qwen/Qwen3-0.6B"
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
MAX_TOKENS = 4
|
||||
NUM_LOG_PROBS = 8
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("mxfp8"),
|
||||
reason="mxfp8 is not supported on this GPU type (requires sm_100+).",
|
||||
)
|
||||
@pytest.mark.quant_model
|
||||
@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"])
|
||||
def test_mxfp8_logprobs(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Compare BF16 baseline logprobs against online MXFP8-quantized model.
|
||||
|
||||
Runs the same model twice -- once in BF16 (baseline) and once with
|
||||
online MXFP8 quantization -- then checks that the top log-probabilities
|
||||
are close. Only 4 tokens are generated to keep the test fast while
|
||||
still catching numerical divergence.
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("TOKENIZERS_PARALLELISM", "true")
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
enforce_eager=True,
|
||||
) as vllm_model:
|
||||
baseline_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, MAX_TOKENS, NUM_LOG_PROBS
|
||||
)
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
enforce_eager=True,
|
||||
quantization="mxfp8",
|
||||
) as vllm_model:
|
||||
test_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, MAX_TOKENS, NUM_LOG_PROBS
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=baseline_outputs,
|
||||
outputs_1_lst=test_outputs,
|
||||
name_0="bf16",
|
||||
name_1="mxfp8",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("mxfp8"),
|
||||
reason="mxfp8 is not supported on this GPU type (requires sm_100+).",
|
||||
)
|
||||
@pytest.mark.quant_model
|
||||
@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"])
|
||||
def test_mxfp8_generation(vllm_runner, model: str) -> None:
|
||||
"""Smoke test: verify online MXFP8 model generates coherent text."""
|
||||
prompt = "1 2 3 4 5"
|
||||
with vllm_runner(
|
||||
model,
|
||||
enforce_eager=True,
|
||||
quantization="mxfp8",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
) as vllm_model:
|
||||
output = vllm_model.generate_greedy([prompt], max_tokens=5)
|
||||
|
||||
generated = output[0][1]
|
||||
assert len(generated) > len(prompt), (
|
||||
f"MXFP8 model produced no new tokens. Output: {generated!r}"
|
||||
)
|
||||
378
tests/tool_parsers/common_tests.py
Normal file
378
tests/tool_parsers/common_tests.py
Normal file
@@ -0,0 +1,378 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from types import NoneType
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.utils import run_tool_extraction
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParserManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolParserTestConfig:
|
||||
"""Configuration for a tool parser's common tests.
|
||||
|
||||
This dataclass contains all the test data and expected results needed
|
||||
to run the common test suite for a parser. Each parser test file
|
||||
creates one instance of this config with parser-specific values.
|
||||
|
||||
Attributes:
|
||||
parser_name: Name used with ToolParserManager (e.g., "mistral")
|
||||
|
||||
Test data (model outputs):
|
||||
no_tool_calls_output: Plain text without any tool syntax
|
||||
single_tool_call_output: One tool call with simple arguments
|
||||
parallel_tool_calls_output: Multiple tool calls in one response
|
||||
various_data_types_output: Tool with various data types
|
||||
empty_arguments_output: Tool call with no parameters
|
||||
surrounding_text_output: Tool call mixed with regular text
|
||||
escaped_strings_output: Tool call with escaped chars
|
||||
malformed_input_outputs: List of invalid inputs
|
||||
|
||||
Expected results:
|
||||
single_tool_call_expected_name: Expected function name
|
||||
single_tool_call_expected_args: Expected arguments dict
|
||||
parallel_tool_calls_count: Number of tools in parallel test
|
||||
parallel_tool_calls_names: Function names in order
|
||||
single_tool_call_expected_content: Content field when tool called
|
||||
parallel_tool_calls_expected_content: Content for parallel test
|
||||
|
||||
xfail markers:
|
||||
xfail_streaming: Mapping test name to xfail reason (streaming only)
|
||||
xfail_nonstreaming: Mapping test name to xfail reason (non-streaming)
|
||||
|
||||
Special flags:
|
||||
allow_empty_or_json_empty_args: True if "" or "{}" both valid for empty args
|
||||
supports_typed_arguments: True if the parser supports typed function arguments
|
||||
"""
|
||||
|
||||
# Parser identification
|
||||
parser_name: str
|
||||
|
||||
# Test data - model outputs for each common test
|
||||
no_tool_calls_output: str
|
||||
single_tool_call_output: str
|
||||
parallel_tool_calls_output: str
|
||||
various_data_types_output: str
|
||||
empty_arguments_output: str
|
||||
surrounding_text_output: str
|
||||
escaped_strings_output: str
|
||||
malformed_input_outputs: list[str]
|
||||
|
||||
# Expected results for specific tests (optional overrides)
|
||||
single_tool_call_expected_name: str = "get_weather"
|
||||
single_tool_call_expected_args: dict[str, Any] = field(
|
||||
default_factory=lambda: {"city": "Tokyo"}
|
||||
)
|
||||
parallel_tool_calls_count: int = 2
|
||||
parallel_tool_calls_names: list[str] = field(
|
||||
default_factory=lambda: ["get_weather", "get_time"]
|
||||
)
|
||||
|
||||
# xfail configuration - maps test name to xfail reason
|
||||
xfail_streaming: dict[str, str] = field(default_factory=dict)
|
||||
xfail_nonstreaming: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# Content expectations (some parsers strip content, others don't)
|
||||
single_tool_call_expected_content: str | None = None
|
||||
parallel_tool_calls_expected_content: str | None = None
|
||||
|
||||
# Special assertions for edge cases
|
||||
allow_empty_or_json_empty_args: bool = True # "{}" or "" for empty args
|
||||
supports_typed_arguments: bool = True
|
||||
|
||||
|
||||
class ToolParserTests:
|
||||
"""Mixin class providing common test suite for tool parsers.
|
||||
|
||||
To use this mixin in a parser test file:
|
||||
|
||||
1. Create a test_config fixture that returns a ToolParserTestConfig instance
|
||||
2. Inherit from this class
|
||||
3. Add parser-specific tests as additional methods
|
||||
|
||||
Example:
|
||||
class TestMistralToolParser(ToolParserTests):
|
||||
@pytest.fixture
|
||||
def test_config(self) -> ToolParserTestConfig:
|
||||
return ToolParserTestConfig(
|
||||
parser_name="mistral",
|
||||
no_tool_calls_output="Plain text...",
|
||||
# ... other config ...
|
||||
)
|
||||
|
||||
# Parser-specific tests
|
||||
def test_mistral_specific_feature(self, tool_parser):
|
||||
# Custom test logic
|
||||
pass
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def test_config(self) -> ToolParserTestConfig:
|
||||
"""Override this to provide parser-specific configuration."""
|
||||
raise NotImplementedError(
|
||||
"Subclass must provide test_config fixture returning ToolParserTestConfig"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
|
||||
"""Override this to provide parser-specific tokenizer."""
|
||||
return default_tokenizer
|
||||
|
||||
@pytest.fixture
|
||||
def tool_parser(self, test_config: ToolParserTestConfig, tokenizer: TokenizerLike):
|
||||
return ToolParserManager.get_tool_parser(test_config.parser_name)(tokenizer)
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def streaming(self, request: pytest.FixtureRequest) -> bool:
|
||||
return request.param
|
||||
|
||||
def test_no_tool_calls(
|
||||
self,
|
||||
request: pytest.FixtureRequest,
|
||||
tool_parser: Any,
|
||||
test_config: ToolParserTestConfig,
|
||||
streaming: bool,
|
||||
):
|
||||
"""Verify parser handles plain text without tool syntax."""
|
||||
# Apply xfail markers if configured
|
||||
test_name = "test_no_tool_calls"
|
||||
self.apply_xfail_mark(request, test_config, test_name, streaming)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, test_config.no_tool_calls_output, streaming=streaming
|
||||
)
|
||||
assert content == test_config.no_tool_calls_output, (
|
||||
f"Expected content to match input, got {content}"
|
||||
)
|
||||
assert len(tool_calls) == 0, f"Expected no tool calls, got {len(tool_calls)}"
|
||||
|
||||
def test_single_tool_call_simple_args(
|
||||
self,
|
||||
request: pytest.FixtureRequest,
|
||||
tool_parser: Any,
|
||||
test_config: ToolParserTestConfig,
|
||||
streaming: bool,
|
||||
):
|
||||
"""Verify parser extracts one tool with simple arguments."""
|
||||
# Apply xfail markers if configured
|
||||
test_name = "test_single_tool_call_simple_args"
|
||||
self.apply_xfail_mark(request, test_config, test_name, streaming)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, test_config.single_tool_call_output, streaming=streaming
|
||||
)
|
||||
|
||||
# Content check (some parsers strip it)
|
||||
if test_config.single_tool_call_expected_content is not None:
|
||||
assert content == test_config.single_tool_call_expected_content
|
||||
|
||||
assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
|
||||
assert tool_calls[0].type == "function"
|
||||
assert tool_calls[0].function.name == test_config.single_tool_call_expected_name
|
||||
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
for key, value in test_config.single_tool_call_expected_args.items():
|
||||
assert args.get(key) == value, (
|
||||
f"Expected {key}={value}, got {args.get(key)}"
|
||||
)
|
||||
|
||||
def test_parallel_tool_calls(
|
||||
self,
|
||||
request: pytest.FixtureRequest,
|
||||
tool_parser: Any,
|
||||
test_config: ToolParserTestConfig,
|
||||
streaming: bool,
|
||||
):
|
||||
"""Verify parser handles multiple tools in one response."""
|
||||
# Apply xfail markers if configured
|
||||
test_name = "test_parallel_tool_calls"
|
||||
self.apply_xfail_mark(request, test_config, test_name, streaming)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser,
|
||||
test_config.parallel_tool_calls_output,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
assert len(tool_calls) == test_config.parallel_tool_calls_count, (
|
||||
f"Expected {test_config.parallel_tool_calls_count} "
|
||||
f"tool calls, got {len(tool_calls)}"
|
||||
)
|
||||
|
||||
# Verify tool names match expected
|
||||
for i, expected_name in enumerate(test_config.parallel_tool_calls_names):
|
||||
assert tool_calls[i].type == "function"
|
||||
assert tool_calls[i].function.name == expected_name
|
||||
|
||||
# Verify unique IDs
|
||||
ids = [tc.id for tc in tool_calls]
|
||||
assert len(ids) == len(set(ids)), "Tool call IDs should be unique"
|
||||
|
||||
def test_various_data_types(
|
||||
self,
|
||||
request: pytest.FixtureRequest,
|
||||
tool_parser: Any,
|
||||
test_config: ToolParserTestConfig,
|
||||
streaming: bool,
|
||||
):
|
||||
"""Verify parser handles all JSON types in arguments."""
|
||||
# Apply xfail markers if configured
|
||||
test_name = "test_various_data_types"
|
||||
self.apply_xfail_mark(request, test_config, test_name, streaming)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser,
|
||||
test_config.various_data_types_output,
|
||||
streaming=streaming,
|
||||
)
|
||||
assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
|
||||
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
# Verify all expected fields present
|
||||
required_fields_types = {
|
||||
"string_field": str,
|
||||
"int_field": int,
|
||||
"float_field": float,
|
||||
"bool_field": bool,
|
||||
"null_field": NoneType,
|
||||
"array_field": list,
|
||||
"object_field": dict,
|
||||
}
|
||||
for required_field, expected_type in required_fields_types.items():
|
||||
assert required_field in args, (
|
||||
f"Expected field '{required_field}' in arguments"
|
||||
)
|
||||
if test_config.supports_typed_arguments:
|
||||
found_type = type(args[required_field])
|
||||
assert found_type is expected_type, (
|
||||
f"Expected field '{required_field}' to have type {expected_type}, "
|
||||
f"got {found_type}"
|
||||
)
|
||||
|
||||
def test_empty_arguments(
|
||||
self,
|
||||
request: pytest.FixtureRequest,
|
||||
tool_parser: Any,
|
||||
test_config: ToolParserTestConfig,
|
||||
streaming: bool,
|
||||
):
|
||||
"""Verify parser handles parameterless tool calls."""
|
||||
# Apply xfail markers if configured
|
||||
test_name = "test_empty_arguments"
|
||||
self.apply_xfail_mark(request, test_config, test_name, streaming)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, test_config.empty_arguments_output, streaming=streaming
|
||||
)
|
||||
assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
|
||||
|
||||
args = tool_calls[0].function.arguments
|
||||
if test_config.allow_empty_or_json_empty_args:
|
||||
assert args in ["{}", ""], f"Expected empty args, got {args}"
|
||||
else:
|
||||
assert args == "{}", f"Expected {{}}, got {args}"
|
||||
|
||||
def test_surrounding_text(
|
||||
self,
|
||||
request: pytest.FixtureRequest,
|
||||
tool_parser: Any,
|
||||
test_config: ToolParserTestConfig,
|
||||
streaming: bool,
|
||||
):
|
||||
"""Verify parser extracts tools from mixed content."""
|
||||
# Apply xfail markers if configured
|
||||
test_name = "test_surrounding_text"
|
||||
self.apply_xfail_mark(request, test_config, test_name, streaming)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, test_config.surrounding_text_output, streaming=streaming
|
||||
)
|
||||
assert len(tool_calls) >= 1, (
|
||||
f"Expected at least 1 tool call, got {len(tool_calls)}"
|
||||
)
|
||||
|
||||
def test_escaped_strings(
|
||||
self,
|
||||
request: pytest.FixtureRequest,
|
||||
tool_parser: Any,
|
||||
test_config: ToolParserTestConfig,
|
||||
streaming: bool,
|
||||
):
|
||||
"""Verify parser handles escaped characters in arguments."""
|
||||
# Apply xfail markers if configured
|
||||
test_name = "test_escaped_strings"
|
||||
self.apply_xfail_mark(request, test_config, test_name, streaming)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, test_config.escaped_strings_output, streaming=streaming
|
||||
)
|
||||
assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
|
||||
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
# At minimum, verify we can parse and have expected fields
|
||||
# Exact escaping behavior varies by parser
|
||||
assert len(args) > 0, "Expected some arguments with escaped strings"
|
||||
|
||||
def test_malformed_input(
|
||||
self,
|
||||
request: pytest.FixtureRequest,
|
||||
tool_parser: Any,
|
||||
test_config: ToolParserTestConfig,
|
||||
streaming: bool,
|
||||
):
|
||||
"""Verify parser gracefully handles invalid syntax."""
|
||||
# Apply xfail markers if configured
|
||||
test_name = "test_malformed_input"
|
||||
self.apply_xfail_mark(request, test_config, test_name, streaming)
|
||||
|
||||
for malformed_input in test_config.malformed_input_outputs:
|
||||
# Should not raise exception
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, malformed_input, streaming=streaming
|
||||
)
|
||||
# Parser should handle gracefully (exact behavior varies)
|
||||
|
||||
def test_streaming_reconstruction(
|
||||
self,
|
||||
request: pytest.FixtureRequest,
|
||||
tool_parser: Any,
|
||||
test_config: ToolParserTestConfig,
|
||||
):
|
||||
"""Verify streaming produces same result as non-streaming."""
|
||||
test_name = "test_streaming_reconstruction"
|
||||
self.apply_xfail_mark(request, test_config, test_name, True)
|
||||
|
||||
test_output = test_config.single_tool_call_output
|
||||
|
||||
# Non-streaming result
|
||||
content_non, tools_non = run_tool_extraction(
|
||||
tool_parser, test_output, streaming=False
|
||||
)
|
||||
|
||||
# Streaming result
|
||||
content_stream, tools_stream = run_tool_extraction(
|
||||
tool_parser, test_output, streaming=True
|
||||
)
|
||||
|
||||
# Compare results
|
||||
assert content_non == content_stream, "Content should match between modes"
|
||||
assert len(tools_non) == len(tools_stream), "Tool count should match"
|
||||
if len(tools_non) > 0:
|
||||
assert tools_non[0].function.name == tools_stream[0].function.name
|
||||
assert tools_non[0].function.arguments == tools_stream[0].function.arguments
|
||||
|
||||
def apply_xfail_mark(self, request, test_config, test_name, streaming):
|
||||
reason = None
|
||||
if streaming and test_name in test_config.xfail_streaming:
|
||||
reason = test_config.xfail_streaming[test_name]
|
||||
elif not streaming and test_name in test_config.xfail_nonstreaming:
|
||||
reason = test_config.xfail_nonstreaming[test_name]
|
||||
if reason is not None:
|
||||
mark = pytest.mark.xfail(reason=reason, strict=True)
|
||||
request.node.add_marker(mark)
|
||||
12
tests/tool_parsers/conftest.py
Normal file
12
tests/tool_parsers/conftest.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_tokenizer() -> TokenizerLike:
|
||||
return AutoTokenizer.from_pretrained("gpt2")
|
||||
92
tests/tool_parsers/test_deepseekv3_tool_parser.py
Normal file
92
tests/tool_parsers/test_deepseekv3_tool_parser.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.common_tests import (
|
||||
ToolParserTestConfig,
|
||||
ToolParserTests,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
|
||||
|
||||
class TestDeepSeekV3ToolParser(ToolParserTests):
|
||||
@pytest.fixture(scope="class")
|
||||
def tokenizer(self) -> TokenizerLike:
|
||||
return get_tokenizer("deepseek-ai/DeepSeek-V3")
|
||||
|
||||
@pytest.fixture
|
||||
def test_config(self) -> ToolParserTestConfig:
|
||||
return ToolParserTestConfig(
|
||||
parser_name="deepseek_v3",
|
||||
# Test data
|
||||
no_tool_calls_output=(
|
||||
"How can I help you today? I can check weather for you."
|
||||
),
|
||||
single_tool_call_output="""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"city": "Tokyo", "unit": "celsius"}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>""",
|
||||
parallel_tool_calls_output="""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"city": "Tokyo", "unit": "celsius"}
|
||||
```<|tool▁call▁end|><|tool▁call▁begin|>function<|tool▁sep|>search_hotels
|
||||
```json
|
||||
{"location": "Tokyo", "check_in": "2025-01-15"}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>""",
|
||||
various_data_types_output=(
|
||||
"""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test_function
|
||||
```json
|
||||
"""
|
||||
"""{"string_field": "hello", "int_field": 42, "float_field": 3.14, """
|
||||
""""bool_field": true, "null_field": null, """
|
||||
""""array_field": ["a", "b", "c"], """
|
||||
""""object_field": {"nested": "value"}, """
|
||||
""""empty_array": [], "empty_object": {}}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>"""
|
||||
),
|
||||
empty_arguments_output="""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_time
|
||||
```json
|
||||
{}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>""",
|
||||
surrounding_text_output=(
|
||||
"""Let me check the weather for you."""
|
||||
"""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"city": "Paris"}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>"""
|
||||
),
|
||||
escaped_strings_output=(
|
||||
"""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>send_message
|
||||
```json
|
||||
"""
|
||||
"""{"text": "He said \\"hello\\"", "path": "C:\\\\Users\\\\file", """
|
||||
""""newline": "line1\\nline2"}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>"""
|
||||
),
|
||||
malformed_input_outputs=[
|
||||
"""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"city": "Tokyo"
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>""",
|
||||
"""<|tool▁calls▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"city": "Tokyo"}
|
||||
```<|tool▁calls▁end|>""",
|
||||
],
|
||||
# Expected results
|
||||
single_tool_call_expected_name="get_weather",
|
||||
single_tool_call_expected_args={"city": "Tokyo", "unit": "celsius"},
|
||||
single_tool_call_expected_content=None,
|
||||
parallel_tool_calls_count=2,
|
||||
parallel_tool_calls_names=["get_weather", "search_hotels"],
|
||||
# xfail markers
|
||||
xfail_streaming={},
|
||||
xfail_nonstreaming={
|
||||
"test_malformed_input": (
|
||||
"Parser sets tools_called=True even when tool_calls is "
|
||||
"empty (detects start token but fails to parse)"
|
||||
),
|
||||
},
|
||||
)
|
||||
76
tests/tool_parsers/test_granite_20b_fc_tool_parser.py
Normal file
76
tests/tool_parsers/test_granite_20b_fc_tool_parser.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.common_tests import (
|
||||
ToolParserTestConfig,
|
||||
ToolParserTests,
|
||||
)
|
||||
|
||||
|
||||
class TestGranite20bFcToolParser(ToolParserTests):
|
||||
@pytest.fixture
|
||||
def test_config(self) -> ToolParserTestConfig:
|
||||
return ToolParserTestConfig(
|
||||
parser_name="granite-20b-fc",
|
||||
# Test data
|
||||
no_tool_calls_output="This is a regular response without any tool calls.",
|
||||
single_tool_call_output=(
|
||||
'<function_call> {"name": "get_weather", '
|
||||
'"arguments": {"city": "Tokyo"}}'
|
||||
),
|
||||
parallel_tool_calls_output=(
|
||||
'<function_call> {"name": "get_weather", '
|
||||
'"arguments": {"city": "Tokyo"}}\n'
|
||||
'<function_call> {"name": "get_time", '
|
||||
'"arguments": {"timezone": "Asia/Tokyo"}}'
|
||||
),
|
||||
various_data_types_output="""<function_call> {
|
||||
"name": "test_function",
|
||||
"arguments": {
|
||||
"string_field": "hello",
|
||||
"int_field": 42,
|
||||
"float_field": 3.14,
|
||||
"bool_field": true,
|
||||
"null_field": null,
|
||||
"array_field": ["a", "b", "c"],
|
||||
"object_field": {"nested": "value"},
|
||||
"empty_array": [],
|
||||
"empty_object": {}
|
||||
}
|
||||
}""",
|
||||
empty_arguments_output=(
|
||||
'<function_call> {"name": "refresh", "arguments": {}}'
|
||||
),
|
||||
surrounding_text_output="""Let me check the weather for you.
|
||||
<function_call> {"name": "get_weather", "arguments": {"city": "Tokyo"}}""",
|
||||
escaped_strings_output="""<function_call> {
|
||||
"name": "test_function",
|
||||
"arguments": {
|
||||
"quoted": "He said \\"hello\\"",
|
||||
"path": "C:\\\\Users\\\\file.txt",
|
||||
"newline": "line1\\nline2",
|
||||
"unicode": "emoji: 🎉"
|
||||
}
|
||||
}""",
|
||||
malformed_input_outputs=[
|
||||
'<function_call> {"name": "func", "arguments": {',
|
||||
'<function_call> [{"name": "func", "arguments": {}}]',
|
||||
'{"name": "func", "arguments": {}}',
|
||||
'<function_call> {"name": 123}',
|
||||
],
|
||||
# Expected results
|
||||
single_tool_call_expected_name="get_weather",
|
||||
single_tool_call_expected_args={"city": "Tokyo"},
|
||||
single_tool_call_expected_content=None,
|
||||
parallel_tool_calls_count=2,
|
||||
parallel_tool_calls_names=["get_weather", "get_time"],
|
||||
# xfail markers
|
||||
xfail_streaming={
|
||||
"test_surrounding_text": (
|
||||
"Granite 20B FC streaming requires <function_call> at start"
|
||||
),
|
||||
},
|
||||
xfail_nonstreaming={},
|
||||
)
|
||||
118
tests/tool_parsers/test_granite_tool_parser.py
Normal file
118
tests/tool_parsers/test_granite_tool_parser.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.common_tests import (
|
||||
ToolParserTestConfig,
|
||||
ToolParserTests,
|
||||
)
|
||||
from tests.tool_parsers.utils import run_tool_extraction
|
||||
|
||||
|
||||
class TestGraniteToolParser(ToolParserTests):
|
||||
@pytest.fixture
|
||||
def test_config(self) -> ToolParserTestConfig:
|
||||
return ToolParserTestConfig(
|
||||
parser_name="granite",
|
||||
# Test data
|
||||
no_tool_calls_output="This is a regular response without any tool calls.",
|
||||
single_tool_call_output=(
|
||||
'<|tool_call|> [{"name": "get_weather", '
|
||||
'"arguments": {"city": "Tokyo"}}]'
|
||||
),
|
||||
parallel_tool_calls_output="""<|tool_call|> [
|
||||
{"name": "get_weather", "arguments": {"city": "Tokyo"}},
|
||||
{"name": "get_time", "arguments": {"timezone": "Asia/Tokyo"}}
|
||||
]""",
|
||||
various_data_types_output="""<tool_call> [{
|
||||
"name": "test_function",
|
||||
"arguments": {
|
||||
"string_field": "hello",
|
||||
"int_field": 42,
|
||||
"float_field": 3.14,
|
||||
"bool_field": true,
|
||||
"null_field": null,
|
||||
"array_field": ["a", "b", "c"],
|
||||
"object_field": {"nested": "value"},
|
||||
"empty_array": [],
|
||||
"empty_object": {}
|
||||
}
|
||||
}]""",
|
||||
empty_arguments_output=(
|
||||
'<|tool_call|> [{"name": "refresh", "arguments": {}}]'
|
||||
),
|
||||
surrounding_text_output="""Let me check the weather for you.
|
||||
<|tool_call|> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]
|
||||
I'll get that information.""",
|
||||
escaped_strings_output="""<tool_call> [{
|
||||
"name": "test_function",
|
||||
"arguments": {
|
||||
"quoted": "He said \\"hello\\"",
|
||||
"path": "C:\\\\Users\\\\file.txt",
|
||||
"newline": "line1\\nline2",
|
||||
"unicode": "emoji: 🎉"
|
||||
}
|
||||
}]""",
|
||||
malformed_input_outputs=[
|
||||
'<|tool_call|> [{"name": "func", "arguments": {',
|
||||
'<|tool_call|> {"name": "func", "arguments": {}}', # Not an array
|
||||
'[{"name": "func", "arguments": "not a dict"}]',
|
||||
'Some text [{"name": "func"}]', # JSON but not tool call format
|
||||
],
|
||||
# Expected results
|
||||
single_tool_call_expected_name="get_weather",
|
||||
single_tool_call_expected_args={"city": "Tokyo"},
|
||||
# Granite strips content when tool calls present
|
||||
single_tool_call_expected_content=None,
|
||||
parallel_tool_calls_count=2,
|
||||
parallel_tool_calls_names=["get_weather", "get_time"],
|
||||
# xfail markers
|
||||
xfail_streaming={
|
||||
"test_malformed_input": (
|
||||
"Streaming mode incorrectly creates tool call from malformed JSON"
|
||||
),
|
||||
"test_surrounding_text": (
|
||||
"Parser doesn't handle surrounding text correctly in streaming"
|
||||
),
|
||||
"test_streaming_reconstruction": (
|
||||
"Streaming mode doesn't strip <|tool_call|> marker from content"
|
||||
),
|
||||
},
|
||||
xfail_nonstreaming={
|
||||
"test_surrounding_text": (
|
||||
"Parser doesn't handle surrounding text correctly in non-streaming"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
# Granite-Specific Tests
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_granite_token_prefix_format(self, tool_parser, streaming):
|
||||
"""Verify parser handles Granite 3.0 <|tool_call|> token format."""
|
||||
single_tool_call_token = (
|
||||
'<|tool_call|> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
|
||||
)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, single_tool_call_token, streaming=streaming
|
||||
)
|
||||
assert len(tool_calls) == 1, (
|
||||
f"Expected 1 tool call from token format, got {len(tool_calls)}"
|
||||
)
|
||||
assert tool_calls[0].function.name == "get_weather"
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_granite_string_prefix_format(self, tool_parser, streaming):
|
||||
"""Verify parser handles Granite 3.1 <tool_call> string format."""
|
||||
single_tool_call_string = (
|
||||
'<tool_call> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
|
||||
)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, single_tool_call_string, streaming=streaming
|
||||
)
|
||||
assert len(tool_calls) == 1, (
|
||||
f"Expected 1 tool call from string format, got {len(tool_calls)}"
|
||||
)
|
||||
assert tool_calls[0].function.name == "get_weather"
|
||||
122
tests/tool_parsers/test_internlm2_tool_parser.py
Normal file
122
tests/tool_parsers/test_internlm2_tool_parser.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.common_tests import (
|
||||
ToolParserTestConfig,
|
||||
ToolParserTests,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
class TestInternLM2ToolParser(ToolParserTests):
|
||||
@pytest.fixture
|
||||
def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
|
||||
"""Add some internlm2 specific tokens to the default vocab."""
|
||||
|
||||
tokenizer_vocab = default_tokenizer.get_vocab()
|
||||
default_tokenizer.get_vocab = MagicMock()
|
||||
tokenizer_vocab.update(
|
||||
{
|
||||
"<|action_start|>": 92540,
|
||||
"<|plugin|>": 92541,
|
||||
"<|action_end|>": 92542,
|
||||
}
|
||||
)
|
||||
default_tokenizer.get_vocab.return_value = tokenizer_vocab
|
||||
return default_tokenizer
|
||||
|
||||
@pytest.fixture
|
||||
def test_config(self) -> ToolParserTestConfig:
|
||||
return ToolParserTestConfig(
|
||||
parser_name="internlm",
|
||||
# Test data
|
||||
no_tool_calls_output="This is a regular response without any tool calls.",
|
||||
single_tool_call_output=(
|
||||
'<|action_start|><|plugin|>{"name": "get_weather", '
|
||||
'"parameters": {"city": "Tokyo"}}<|action_end|>'
|
||||
),
|
||||
# InternLM2 doesn't support parallel calls
|
||||
parallel_tool_calls_output=(
|
||||
'<|action_start|><|plugin|>{"name": "get_weather", '
|
||||
'"parameters": {"city": "Tokyo"}}<|action_end|>'
|
||||
),
|
||||
various_data_types_output="""<|action_start|><|plugin|>{
|
||||
"name": "test_function",
|
||||
"parameters": {
|
||||
"string_field": "hello",
|
||||
"int_field": 42,
|
||||
"float_field": 3.14,
|
||||
"bool_field": true,
|
||||
"null_field": null,
|
||||
"array_field": ["a", "b", "c"],
|
||||
"object_field": {"nested": "value"},
|
||||
"empty_array": [],
|
||||
"empty_object": {}
|
||||
}
|
||||
}<|action_end|>""",
|
||||
empty_arguments_output=(
|
||||
'<|action_start|><|plugin|>{"name": "refresh", '
|
||||
'"parameters": {}}<|action_end|>'
|
||||
),
|
||||
surrounding_text_output=(
|
||||
"Let me check the weather for you. "
|
||||
'<|action_start|><|plugin|>{"name": "get_weather", '
|
||||
'"parameters": {"city": "Tokyo"}}<|action_end|>'
|
||||
),
|
||||
escaped_strings_output="""<|action_start|><|plugin|>{
|
||||
"name": "test_function",
|
||||
"parameters": {
|
||||
"quoted": "He said \\"hello\\"",
|
||||
"path": "C:\\\\Users\\\\file.txt",
|
||||
"newline": "line1\\nline2",
|
||||
"unicode": "emoji: 🎉"
|
||||
}
|
||||
}<|action_end|>""",
|
||||
malformed_input_outputs=[
|
||||
'<|action_start|><|plugin|>{"name": "func", "parameters": {',
|
||||
(
|
||||
'<|action_start|><|plugin|>{"name": "func", '
|
||||
'"parameters": "not a dict"}<|action_end|>'
|
||||
),
|
||||
"<|action_start|><|plugin|>not json<|action_end|>",
|
||||
"<|action_start|><|plugin|>",
|
||||
'<|action_start|>{"name": "func"}',
|
||||
],
|
||||
# Expected results
|
||||
single_tool_call_expected_name="get_weather",
|
||||
single_tool_call_expected_args={"city": "Tokyo"},
|
||||
single_tool_call_expected_content=None,
|
||||
parallel_tool_calls_count=1, # InternLM2 only supports single tool calls
|
||||
parallel_tool_calls_names=["get_weather"],
|
||||
# Parser-specific settings
|
||||
allow_empty_or_json_empty_args=True,
|
||||
# xfail markers
|
||||
xfail_streaming={
|
||||
"test_single_tool_call_simple_args": (
|
||||
"InternLM2 streaming not fully implemented"
|
||||
),
|
||||
"test_parallel_tool_calls": (
|
||||
"InternLM2 streaming not fully implemented"
|
||||
),
|
||||
"test_various_data_types": (
|
||||
"InternLM2 streaming not fully implemented"
|
||||
),
|
||||
"test_empty_arguments": ("InternLM2 streaming not fully implemented"),
|
||||
"test_surrounding_text": ("InternLM2 streaming not fully implemented"),
|
||||
"test_escaped_strings": ("InternLM2 streaming not fully implemented"),
|
||||
"test_streaming_reconstruction": (
|
||||
"InternLM2 streaming parser returns '<|action_start|' as "
|
||||
"content instead of None - streaming/non-streaming inconsistency"
|
||||
),
|
||||
},
|
||||
xfail_nonstreaming={
|
||||
"test_malformed_input": (
|
||||
"InternLM2 parser raises JSONDecodeError on malformed JSON "
|
||||
"instead of gracefully handling it"
|
||||
),
|
||||
},
|
||||
)
|
||||
101
tests/tool_parsers/test_longcat_tool_parser.py
Normal file
101
tests/tool_parsers/test_longcat_tool_parser.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.common_tests import (
|
||||
ToolParserTestConfig,
|
||||
ToolParserTests,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
class TestLongCatToolParser(ToolParserTests):
|
||||
@pytest.fixture
|
||||
def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
|
||||
"""Add some longcat specific tokens to the default vocab."""
|
||||
tokenizer = default_tokenizer
|
||||
tokenizer_vocab = tokenizer.get_vocab()
|
||||
tokenizer.get_vocab = MagicMock()
|
||||
tokenizer_vocab.update(
|
||||
{
|
||||
"<longcat_tool_call>": 32000,
|
||||
"</longcat_tool_call>": 32001,
|
||||
}
|
||||
)
|
||||
tokenizer.get_vocab.return_value = tokenizer_vocab
|
||||
return tokenizer
|
||||
|
||||
@pytest.fixture
|
||||
def test_config(self) -> ToolParserTestConfig:
|
||||
return ToolParserTestConfig(
|
||||
parser_name="longcat",
|
||||
# Test data
|
||||
no_tool_calls_output="This is a regular response without any tool calls.",
|
||||
single_tool_call_output=(
|
||||
'<longcat_tool_call>{"name": "get_weather", '
|
||||
'"arguments": {"city": "Tokyo"}}</longcat_tool_call>'
|
||||
),
|
||||
parallel_tool_calls_output=(
|
||||
'<longcat_tool_call>{"name": "get_weather", '
|
||||
'"arguments": {"city": "Tokyo"}}</longcat_tool_call>\n'
|
||||
'<longcat_tool_call>{"name": "get_time", '
|
||||
'"arguments": {"timezone": "Asia/Tokyo"}}</longcat_tool_call>'
|
||||
),
|
||||
various_data_types_output="""<longcat_tool_call>{
|
||||
"name": "test_function",
|
||||
"arguments": {
|
||||
"string_field": "hello",
|
||||
"int_field": 42,
|
||||
"float_field": 3.14,
|
||||
"bool_field": true,
|
||||
"null_field": null,
|
||||
"array_field": ["a", "b", "c"],
|
||||
"object_field": {"nested": "value"},
|
||||
"empty_array": [],
|
||||
"empty_object": {}
|
||||
}
|
||||
}</longcat_tool_call>""",
|
||||
empty_arguments_output=(
|
||||
'<longcat_tool_call>{"name": "refresh", "arguments": {}}'
|
||||
"</longcat_tool_call>"
|
||||
),
|
||||
surrounding_text_output=(
|
||||
"Let me check the weather for you.\n"
|
||||
'<longcat_tool_call>{"name": "get_weather", '
|
||||
'"arguments": {"city": "Tokyo"}}</longcat_tool_call>\n'
|
||||
"Here is the result."
|
||||
),
|
||||
escaped_strings_output="""<longcat_tool_call>{
|
||||
"name": "test_function",
|
||||
"arguments": {
|
||||
"quoted": "He said \\"hello\\"",
|
||||
"path": "C:\\\\Users\\\\file.txt",
|
||||
"newline": "line1\\nline2",
|
||||
"unicode": "emoji: 🎉"
|
||||
}
|
||||
}</longcat_tool_call>""",
|
||||
malformed_input_outputs=[
|
||||
'<longcat_tool_call>{"name": "func", "arguments": {',
|
||||
(
|
||||
'<longcat_tool_call>{"name": "func", '
|
||||
'"arguments": "not a dict"}</longcat_tool_call>'
|
||||
),
|
||||
"Some text with <longcat_tool_call>invalid json",
|
||||
],
|
||||
# Expected results
|
||||
single_tool_call_expected_name="get_weather",
|
||||
single_tool_call_expected_args={"city": "Tokyo"},
|
||||
single_tool_call_expected_content=None,
|
||||
parallel_tool_calls_count=2,
|
||||
parallel_tool_calls_names=["get_weather", "get_time"],
|
||||
# xfail markers
|
||||
xfail_streaming={
|
||||
"test_malformed_input": "Streaming has complex buffering behavior",
|
||||
},
|
||||
xfail_nonstreaming={},
|
||||
# Configuration
|
||||
allow_empty_or_json_empty_args=True,
|
||||
)
|
||||
110
tests/tool_parsers/test_phi4mini_tool_parser.py
Normal file
110
tests/tool_parsers/test_phi4mini_tool_parser.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.common_tests import (
|
||||
ToolParserTestConfig,
|
||||
ToolParserTests,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
class TestPhi4MiniToolParser(ToolParserTests):
|
||||
@pytest.fixture
|
||||
def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
|
||||
"""Add some phi4mini specific tokens to the default vocab."""
|
||||
|
||||
tokenizer = default_tokenizer
|
||||
tokenizer_vocab = tokenizer.get_vocab()
|
||||
tokenizer.get_vocab = MagicMock()
|
||||
tokenizer_vocab.update(
|
||||
{
|
||||
"functools": 32000,
|
||||
}
|
||||
)
|
||||
tokenizer.get_vocab.return_value = tokenizer_vocab
|
||||
return tokenizer
|
||||
|
||||
@pytest.fixture
|
||||
def test_config(self) -> ToolParserTestConfig:
|
||||
return ToolParserTestConfig(
|
||||
parser_name="phi4_mini_json",
|
||||
# Test data
|
||||
no_tool_calls_output="This is a regular response without any tool calls.",
|
||||
single_tool_call_output=(
|
||||
'functools[{"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
|
||||
),
|
||||
parallel_tool_calls_output="""functools[
|
||||
{"name": "get_weather", "arguments": {"city": "Tokyo"}},
|
||||
{"name": "get_time", "arguments": {"timezone": "Asia/Tokyo"}}
|
||||
]""",
|
||||
various_data_types_output="""functools[{
|
||||
"name": "test_function",
|
||||
"arguments": {
|
||||
"string_field": "hello",
|
||||
"int_field": 42,
|
||||
"float_field": 3.14,
|
||||
"bool_field": true,
|
||||
"null_field": null,
|
||||
"array_field": ["a", "b", "c"],
|
||||
"object_field": {"nested": "value"},
|
||||
"empty_array": [],
|
||||
"empty_object": {}
|
||||
}
|
||||
}]""",
|
||||
empty_arguments_output='functools[{"name": "refresh", "arguments": {}}]',
|
||||
surrounding_text_output="""Let me check the weather for you.
|
||||
functools[{"name": "get_weather", "arguments": {"city": "Tokyo"}}]
|
||||
Would you like to know more?""",
|
||||
escaped_strings_output="""functools[{
|
||||
"name": "test_function",
|
||||
"arguments": {
|
||||
"quoted": "He said \\"hello\\"",
|
||||
"path": "C:\\\\Users\\\\file.txt",
|
||||
"newline": "line1\\nline2",
|
||||
"unicode": "emoji: 🎉"
|
||||
}
|
||||
}]""",
|
||||
malformed_input_outputs=[
|
||||
'functools[{"name": "func", "arguments": {',
|
||||
'functools[{"name": "func", "arguments": "not a dict"}]',
|
||||
'functools{"name": "func"}', # Missing brackets
|
||||
'functools[{"name": "func"}]', # Missing arguments/parameters
|
||||
"functools[] This is just text", # Empty functools
|
||||
"functools[ This is just text ]", # functools with invalid JSON
|
||||
],
|
||||
# Expected results
|
||||
single_tool_call_expected_name="get_weather",
|
||||
single_tool_call_expected_args={"city": "Tokyo"},
|
||||
# Phi-4 Mini strips content when tool calls present
|
||||
single_tool_call_expected_content=None,
|
||||
parallel_tool_calls_count=2,
|
||||
parallel_tool_calls_names=["get_weather", "get_time"],
|
||||
parallel_tool_calls_expected_content=None,
|
||||
# xfail markers
|
||||
xfail_streaming={
|
||||
"test_no_tool_calls": "Phi4 Mini streaming not implemented",
|
||||
"test_single_tool_call_simple_args": (
|
||||
"Phi4 Mini streaming not implemented"
|
||||
),
|
||||
"test_parallel_tool_calls": "Phi4 Mini streaming not implemented",
|
||||
"test_various_data_types": "Phi4 Mini streaming not implemented",
|
||||
"test_empty_arguments": "Phi4 Mini streaming not implemented",
|
||||
"test_surrounding_text": "Phi4 Mini streaming not implemented",
|
||||
"test_escaped_strings": "Phi4 Mini streaming not implemented",
|
||||
"test_streaming_reconstruction": "Phi4 Mini streaming not implemented",
|
||||
},
|
||||
xfail_nonstreaming={
|
||||
"test_various_data_types": (
|
||||
"Phi4MiniJsonToolParser regex has nesting limitations "
|
||||
"with nested objects"
|
||||
),
|
||||
"test_malformed_input": (
|
||||
"Phi4MiniJsonToolParser incorrectly sets "
|
||||
"tools_called=True on empty array"
|
||||
),
|
||||
},
|
||||
)
|
||||
75
tests/tool_parsers/test_qwen3xml_tool_parser.py
Normal file
75
tests/tool_parsers/test_qwen3xml_tool_parser.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.common_tests import (
|
||||
ToolParserTestConfig,
|
||||
ToolParserTests,
|
||||
)
|
||||
|
||||
|
||||
class TestQwen3xmlToolParser(ToolParserTests):
|
||||
@pytest.fixture
|
||||
def test_config(self) -> ToolParserTestConfig:
|
||||
return ToolParserTestConfig(
|
||||
parser_name="qwen3_xml",
|
||||
# Test data
|
||||
no_tool_calls_output="This is a regular response without any tool calls.",
|
||||
single_tool_call_output="<tool_call>\n<function=get_weather>\n<parameter=city>Tokyo</parameter>\n</function>\n</tool_call>",
|
||||
parallel_tool_calls_output="<tool_call>\n<function=get_weather>\n<parameter=city>Tokyo</parameter>\n</function>\n</tool_call><tool_call>\n<function=get_time>\n<parameter=timezone>Asia/Tokyo</parameter>\n</function>\n</tool_call>",
|
||||
various_data_types_output=(
|
||||
"<tool_call>\n<function=test_function>\n"
|
||||
"<parameter=string_field>hello</parameter>\n"
|
||||
"<parameter=int_field>42</parameter>\n"
|
||||
"<parameter=float_field>3.14</parameter>\n"
|
||||
"<parameter=bool_field>true</parameter>\n"
|
||||
"<parameter=null_field>null</parameter>\n"
|
||||
'<parameter=array_field>["a", "b", "c"]</parameter>\n'
|
||||
'<parameter=object_field>{"nested": "value"}</parameter>\n'
|
||||
"</function>\n</tool_call>"
|
||||
),
|
||||
empty_arguments_output="<tool_call>\n<function=refresh>\n</function>\n</tool_call>",
|
||||
surrounding_text_output=(
|
||||
"Let me check the weather for you.\n\n"
|
||||
"<tool_call>\n<function=get_weather>\n"
|
||||
"<parameter=city>Tokyo</parameter>\n"
|
||||
"</function>\n</tool_call>\n\n"
|
||||
"I will get that information."
|
||||
),
|
||||
escaped_strings_output=(
|
||||
"<tool_call>\n<function=test_function>\n"
|
||||
'<parameter=quoted>He said "hello"</parameter>\n'
|
||||
"<parameter=path>C:\\Users\\file.txt</parameter>\n"
|
||||
"<parameter=newline>line1\nline2</parameter>\n"
|
||||
"</function>\n</tool_call>"
|
||||
),
|
||||
malformed_input_outputs=[
|
||||
"<tool_call><function=func>",
|
||||
"<tool_call><function=></function></tool_call>",
|
||||
],
|
||||
# Expected results
|
||||
single_tool_call_expected_name="get_weather",
|
||||
single_tool_call_expected_args={"city": "Tokyo"},
|
||||
parallel_tool_calls_count=2,
|
||||
parallel_tool_calls_names=["get_weather", "get_time"],
|
||||
# xfail markers - Qwen3XML has systematic streaming issues
|
||||
xfail_streaming={
|
||||
"test_single_tool_call_simple_args": (
|
||||
"Qwen3XML streaming has systematic issues"
|
||||
),
|
||||
"test_parallel_tool_calls": "Qwen3XML streaming has systematic issues",
|
||||
"test_various_data_types": "Qwen3XML streaming has systematic issues",
|
||||
"test_empty_arguments": "Qwen3XML streaming has systematic issues",
|
||||
"test_surrounding_text": "Qwen3XML streaming has systematic issues",
|
||||
"test_escaped_strings": "Qwen3XML streaming has systematic issues",
|
||||
"test_malformed_input": (
|
||||
"Qwen3XML parser is lenient with malformed input"
|
||||
),
|
||||
"test_streaming_reconstruction": (
|
||||
"Qwen3XML streaming reconstruction has known issues"
|
||||
),
|
||||
},
|
||||
supports_typed_arguments=False,
|
||||
)
|
||||
112
tests/tool_parsers/test_step3_tool_parser.py
Normal file
112
tests/tool_parsers/test_step3_tool_parser.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.common_tests import (
|
||||
ToolParserTestConfig,
|
||||
ToolParserTests,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
|
||||
|
||||
class TestStep3ToolParser(ToolParserTests):
|
||||
@pytest.fixture(scope="class")
|
||||
def tokenizer(self) -> TokenizerLike:
|
||||
return get_tokenizer("stepfun-ai/step3")
|
||||
|
||||
@pytest.fixture
|
||||
def test_config(self) -> ToolParserTestConfig:
|
||||
return ToolParserTestConfig(
|
||||
parser_name="step3",
|
||||
# Test data
|
||||
no_tool_calls_output="This is a regular response without any tool calls.",
|
||||
single_tool_call_output=(
|
||||
"<|tool_calls_begin|><|tool_call_begin|>"
|
||||
'<steptml:invoke name="get_weather">'
|
||||
'<steptml:parameter name="city">Tokyo</steptml:parameter>'
|
||||
"</steptml:invoke><|tool_call_end|><|tool_calls_end|>"
|
||||
),
|
||||
parallel_tool_calls_output=(
|
||||
"<|tool_calls_begin|><|tool_call_begin|>"
|
||||
'<steptml:invoke name="get_weather">'
|
||||
'<steptml:parameter name="city">Tokyo</steptml:parameter>'
|
||||
"</steptml:invoke><|tool_call_end|><|tool_sep|>"
|
||||
'<|tool_call_begin|><steptml:invoke name="get_time">'
|
||||
'<steptml:parameter name="timezone">Asia/Tokyo</steptml:parameter>'
|
||||
"</steptml:invoke><|tool_call_end|><|tool_calls_end|>"
|
||||
),
|
||||
various_data_types_output=(
|
||||
"<|tool_calls_begin|><|tool_call_begin|>"
|
||||
'<steptml:invoke name="test_function">'
|
||||
'<steptml:parameter name="string_field">hello</steptml:parameter>'
|
||||
'<steptml:parameter name="int_field">42</steptml:parameter>'
|
||||
'<steptml:parameter name="float_field">3.14</steptml:parameter>'
|
||||
'<steptml:parameter name="bool_field">true</steptml:parameter>'
|
||||
'<steptml:parameter name="null_field">null</steptml:parameter>'
|
||||
'<steptml:parameter name="array_field">'
|
||||
'["a", "b", "c"]</steptml:parameter>'
|
||||
'<steptml:parameter name="object_field">'
|
||||
'{"nested": "value"}</steptml:parameter>'
|
||||
"</steptml:invoke><|tool_call_end|><|tool_calls_end|>"
|
||||
),
|
||||
empty_arguments_output=(
|
||||
"<|tool_calls_begin|><|tool_call_begin|>"
|
||||
'<steptml:invoke name="refresh"></steptml:invoke>'
|
||||
"<|tool_call_end|><|tool_calls_end|>"
|
||||
),
|
||||
surrounding_text_output=(
|
||||
"Let me check the weather for you.\n\n"
|
||||
"<|tool_calls_begin|><|tool_call_begin|>"
|
||||
'<steptml:invoke name="get_weather">'
|
||||
'<steptml:parameter name="city">Tokyo</steptml:parameter>'
|
||||
"</steptml:invoke><|tool_call_end|><|tool_calls_end|>\n\n"
|
||||
"I'll get that information."
|
||||
),
|
||||
escaped_strings_output=(
|
||||
"<|tool_calls_begin|><|tool_call_begin|>"
|
||||
'<steptml:invoke name="test_function">'
|
||||
'<steptml:parameter name="quoted">He said "hello"</steptml:parameter>'
|
||||
'<steptml:parameter name="path">C:\\Users\\file.txt</steptml:parameter>'
|
||||
'<steptml:parameter name="newline">line1\nline2</steptml:parameter>'
|
||||
"</steptml:invoke><|tool_call_end|><|tool_calls_end|>"
|
||||
),
|
||||
malformed_input_outputs=[
|
||||
(
|
||||
"<|tool_calls_begin|><|tool_call_begin|>"
|
||||
'<steptml:invoke name="func">'
|
||||
),
|
||||
(
|
||||
'<|tool_call_begin|><steptml:invoke name="func">'
|
||||
"</steptml:invoke><|tool_call_end|>"
|
||||
),
|
||||
],
|
||||
# Expected results
|
||||
single_tool_call_expected_name="get_weather",
|
||||
single_tool_call_expected_args={"city": "Tokyo"},
|
||||
parallel_tool_calls_count=2,
|
||||
parallel_tool_calls_names=["get_weather", "get_time"],
|
||||
# xfail markers
|
||||
xfail_nonstreaming={
|
||||
"test_single_tool_call_simple_args": (
|
||||
"Step3 parser non-streaming has bugs"
|
||||
),
|
||||
"test_parallel_tool_calls": ("Step3 parser non-streaming has bugs"),
|
||||
"test_various_data_types": "Step3 parser non-streaming has bugs",
|
||||
"test_empty_arguments": "Step3 parser non-streaming has bugs",
|
||||
"test_surrounding_text": "Step3 parser non-streaming has bugs",
|
||||
"test_escaped_strings": "Step3 parser non-streaming has bugs",
|
||||
},
|
||||
xfail_streaming={
|
||||
"test_parallel_tool_calls": (
|
||||
"Step3 parser has significant bugs in both streaming "
|
||||
"and non-streaming"
|
||||
),
|
||||
"test_streaming_reconstruction": (
|
||||
"Step3 parser non-streaming has bugs, so streaming "
|
||||
"doesn't match non-streaming"
|
||||
),
|
||||
},
|
||||
supports_typed_arguments=False,
|
||||
)
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
|
||||
from .utils import (
|
||||
MESSAGES_WITHOUT_TOOLS,
|
||||
SEED,
|
||||
WEATHER_TOOL,
|
||||
ServerConfig,
|
||||
ensure_system_prompt,
|
||||
@@ -27,6 +28,7 @@ async def test_chat_completion_without_tools(
|
||||
max_completion_tokens=150,
|
||||
model=model_name,
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
)
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
@@ -47,6 +49,7 @@ async def test_chat_completion_without_tools(
|
||||
max_completion_tokens=150,
|
||||
model=model_name,
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
stream=True,
|
||||
)
|
||||
chunks: list[str] = []
|
||||
@@ -97,6 +100,7 @@ async def test_chat_completion_with_tools(
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
)
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
@@ -118,6 +122,7 @@ async def test_chat_completion_with_tools(
|
||||
model=model_name,
|
||||
logprobs=False,
|
||||
tools=[WEATHER_TOOL],
|
||||
seed=SEED,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from .utils import (
|
||||
MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
|
||||
SEARCH_TOOL,
|
||||
SEED,
|
||||
WEATHER_TOOL,
|
||||
ServerConfig,
|
||||
)
|
||||
@@ -39,6 +40,7 @@ async def test_parallel_tool_calls(
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
@@ -76,6 +78,7 @@ async def test_parallel_tool_calls(
|
||||
max_completion_tokens=200,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@@ -166,6 +169,7 @@ async def test_parallel_tool_calls_with_results(
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
@@ -184,6 +188,7 @@ async def test_parallel_tool_calls_with_results(
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@@ -229,6 +234,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
parallel_tool_calls=False,
|
||||
)
|
||||
|
||||
@@ -247,6 +253,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
|
||||
max_completion_tokens=200,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
parallel_tool_calls=False,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from .utils import (
|
||||
MESSAGES_ASKING_FOR_TOOLS,
|
||||
MESSAGES_WITH_TOOL_RESPONSE,
|
||||
SEARCH_TOOL,
|
||||
SEED,
|
||||
WEATHER_TOOL,
|
||||
)
|
||||
|
||||
@@ -27,6 +28,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
@@ -71,6 +73,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
max_completion_tokens=100,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@@ -154,6 +157,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
@@ -171,6 +175,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
seed=SEED,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -42,6 +42,8 @@ def ensure_system_prompt(
|
||||
|
||||
# universal args for all models go here. also good if you need to test locally
|
||||
# and change type or KV cache quantization or something.
|
||||
SEED = 42
|
||||
|
||||
ARGS: list[str] = [
|
||||
"--enable-auto-tool-choice",
|
||||
"--max-model-len",
|
||||
|
||||
@@ -43,6 +43,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
KVCacheGroupSpec,
|
||||
KVCacheSpec,
|
||||
KVCacheTensor,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
@@ -157,6 +158,24 @@ def new_chunked_local_attention_spec(
|
||||
)
|
||||
|
||||
|
||||
def new_mamba_spec(
|
||||
block_size=16,
|
||||
shapes=((2, 512), (3, 32, 32)),
|
||||
dtypes=(torch.float32, torch.float32),
|
||||
num_speculative_blocks=2,
|
||||
mamba_cache_mode="none",
|
||||
page_size_padded=None,
|
||||
):
|
||||
return MambaSpec(
|
||||
block_size=block_size,
|
||||
shapes=shapes,
|
||||
dtypes=dtypes,
|
||||
page_size_padded=page_size_padded,
|
||||
mamba_cache_mode=mamba_cache_mode,
|
||||
num_speculative_blocks=num_speculative_blocks,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_none_hash(monkeypatch, hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
@@ -2010,6 +2029,28 @@ def test_auto_fit_max_model_len():
|
||||
assert vllm_config.model_config.max_model_len > 0
|
||||
|
||||
|
||||
def test_auto_fit_max_model_len_with_hybrid():
|
||||
"""Test that auto-fit works with hybrid KV cache specs."""
|
||||
# Create config with original_max_model_len=-1 to trigger auto-fit
|
||||
model_config = ModelConfig(max_model_len=8192)
|
||||
# Simulate the user passing -1 by setting original_max_model_len
|
||||
model_config.original_max_model_len = -1
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
|
||||
mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 # 16KB per block per layer
|
||||
gamma = 2
|
||||
kv_cache_specs = {
|
||||
"layer_1": new_mamba_spec(num_speculative_blocks=gamma),
|
||||
"layer_2": new_kv_cache_spec(),
|
||||
}
|
||||
|
||||
available_memory = mem_per_block_per_layer * (1024 // 16 + 1 + gamma)
|
||||
_kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs], [available_memory]
|
||||
)
|
||||
assert vllm_config.model_config.max_model_len == 1024
|
||||
|
||||
|
||||
def test_auto_fit_max_model_len_not_triggered():
|
||||
"""Test that auto-fit is not triggered when original_max_model_len is not -1."""
|
||||
model_config = ModelConfig(max_model_len=16)
|
||||
|
||||
@@ -12,7 +12,7 @@ import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.utils import ROCM_ENV_OVERRIDES, RemoteOpenAIServer
|
||||
from tests.v1.utils import check_request_balancing
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -27,6 +27,84 @@ TP_SIZE = int(os.getenv("TP_SIZE", "1"))
|
||||
NUM_NODES = 2
|
||||
|
||||
|
||||
async def _make_completion_request(
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
) -> openai.types.Completion:
|
||||
"""Make a single completion request and validate the response.
|
||||
|
||||
Uses temperature=1.0 to ensure diverse outputs across concurrent
|
||||
requests for realistic load balancer testing.
|
||||
"""
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=1.0,
|
||||
)
|
||||
|
||||
assert completion.id is not None, (
|
||||
f"Expected non-None completion id. usage={completion.usage!r}"
|
||||
)
|
||||
assert completion.choices is not None and len(completion.choices) == 1, (
|
||||
f"Expected 1 choice, got "
|
||||
f"{len(completion.choices) if completion.choices else 'None'}"
|
||||
)
|
||||
|
||||
choice = completion.choices[0]
|
||||
# With temperature=1.0, the model may emit a stop token immediately,
|
||||
# producing empty text with finish_reason='stop'. This is valid
|
||||
# model behavior - the test's purpose is load balancing, not output
|
||||
# quality.
|
||||
assert choice.finish_reason in ("length", "stop"), (
|
||||
f"Expected finish_reason 'length' or 'stop', "
|
||||
f"got {choice.finish_reason!r}. text={choice.text!r}"
|
||||
)
|
||||
if choice.finish_reason == "length":
|
||||
assert len(choice.text) >= 1, (
|
||||
f"Expected non-empty text with finish_reason='length', got {choice.text!r}"
|
||||
)
|
||||
|
||||
assert completion.usage.prompt_tokens > 0, (
|
||||
f"Expected positive prompt_tokens, got {completion.usage.prompt_tokens}"
|
||||
)
|
||||
assert completion.usage.total_tokens > 0, (
|
||||
f"Expected positive total_tokens, got {completion.usage.total_tokens}"
|
||||
)
|
||||
return completion
|
||||
|
||||
|
||||
async def _run_request_bursts(
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
num_requests: int = 200,
|
||||
num_bursts: int = 2,
|
||||
):
|
||||
"""Send multiple bursts of completion requests and validate all succeed."""
|
||||
for burst in range(num_bursts):
|
||||
all_tasks = []
|
||||
for _ in range(num_requests):
|
||||
all_tasks.append(
|
||||
asyncio.create_task(_make_completion_request(client, model_name))
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
results = await asyncio.gather(*all_tasks, return_exceptions=True)
|
||||
assert len(results) == num_requests, (
|
||||
f"Burst {burst}: expected {num_requests} results, got {len(results)}"
|
||||
)
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, BaseException):
|
||||
raise result
|
||||
|
||||
assert all(completion is not None for completion in results), (
|
||||
f"Burst {burst}: some completions were None"
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
class MultinodeInternalLBServerManager:
|
||||
"""Manages multi-node data parallel vLLM server instances for internal
|
||||
load balancer testing using --headless mode."""
|
||||
@@ -108,6 +186,7 @@ class MultinodeInternalLBServerManager:
|
||||
auto_port=False,
|
||||
env_dict={
|
||||
"VLLM_SERVER_DEV_MODE": "1",
|
||||
**ROCM_ENV_OVERRIDES,
|
||||
current_platform.device_control_env_var: ",".join(
|
||||
str(current_platform.device_id_to_physical_device_id(i))
|
||||
for i in range(r, r + gpus_per_node)
|
||||
@@ -229,6 +308,7 @@ class APIOnlyServerManager:
|
||||
auto_port=False,
|
||||
env_dict={
|
||||
"VLLM_SERVER_DEV_MODE": "1",
|
||||
**ROCM_ENV_OVERRIDES,
|
||||
# No GPUs needed for API-only server
|
||||
},
|
||||
)
|
||||
@@ -249,10 +329,11 @@ class APIOnlyServerManager:
|
||||
engines_server_args,
|
||||
auto_port=False,
|
||||
env_dict={
|
||||
**ROCM_ENV_OVERRIDES,
|
||||
current_platform.device_control_env_var: ",".join(
|
||||
str(current_platform.device_id_to_physical_device_id(i))
|
||||
for i in range(self.dp_size * self.tp_size)
|
||||
)
|
||||
),
|
||||
},
|
||||
)
|
||||
server.__enter__()
|
||||
@@ -395,58 +476,15 @@ async def test_multinode_dp_completion(
|
||||
servers: list[tuple[RemoteOpenAIServer, list[str]]],
|
||||
model_name: str,
|
||||
) -> None:
|
||||
async def make_request():
|
||||
completion = await client.completions.create(
|
||||
model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0
|
||||
)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
|
||||
choice = completion.choices[0]
|
||||
# The exact number of tokens can vary slightly with temperature=1.0,
|
||||
# so we check for a reasonable minimum length.
|
||||
assert len(choice.text) >= 1
|
||||
# Finish reason might not always be 'length' if the model finishes early
|
||||
# or due to other reasons, especially with high temperature.
|
||||
# So, we'll accept 'length' or 'stop'.
|
||||
assert choice.finish_reason in ("length", "stop")
|
||||
|
||||
# Token counts can also vary, so we check they are positive.
|
||||
assert completion.usage.completion_tokens > 0
|
||||
assert completion.usage.prompt_tokens > 0
|
||||
assert completion.usage.total_tokens > 0
|
||||
return completion
|
||||
|
||||
# Test single request
|
||||
result = await make_request()
|
||||
result = await _make_completion_request(client, model_name)
|
||||
assert result is not None
|
||||
print("Multi-node internal LB handled single completion request successfully")
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Send multiple requests - internal LB should distribute across DP ranks
|
||||
num_requests = 200
|
||||
all_tasks = []
|
||||
for _ in range(num_requests):
|
||||
all_tasks.append(asyncio.create_task(make_request()))
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests
|
||||
assert all(completion is not None for completion in results)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Second burst of requests
|
||||
all_tasks = []
|
||||
for _ in range(num_requests):
|
||||
all_tasks.append(asyncio.create_task(make_request()))
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests
|
||||
assert all(completion is not None for completion in results)
|
||||
# Send multiple bursts - internal LB should distribute across DP ranks
|
||||
await _run_request_bursts(client, model_name)
|
||||
|
||||
_, server_args = servers[0]
|
||||
api_server_count = (
|
||||
@@ -570,59 +608,16 @@ async def test_api_only_multinode_dp_completion(
|
||||
) -> None:
|
||||
"""Test API-only server with all engines on separate headless server."""
|
||||
|
||||
async def make_request():
|
||||
completion = await api_only_client.completions.create(
|
||||
model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0
|
||||
)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
|
||||
choice = completion.choices[0]
|
||||
# The exact number of tokens can vary slightly with temperature=1.0,
|
||||
# so we check for a reasonable minimum length.
|
||||
assert len(choice.text) >= 1
|
||||
# Finish reason might not always be 'length' if the model finishes
|
||||
# early or due to other reasons, especially with high temperature.
|
||||
# So, we'll accept 'length' or 'stop'.
|
||||
assert choice.finish_reason in ("length", "stop")
|
||||
|
||||
# Token counts can also vary, so we check they are positive.
|
||||
assert completion.usage.completion_tokens > 0
|
||||
assert completion.usage.prompt_tokens > 0
|
||||
assert completion.usage.total_tokens > 0
|
||||
return completion
|
||||
|
||||
# Test single request
|
||||
result = await make_request()
|
||||
result = await _make_completion_request(api_only_client, model_name)
|
||||
assert result is not None
|
||||
print("API-only server handled single completion request successfully")
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Send multiple requests - should be distributed across engines on
|
||||
# Send multiple bursts - should be distributed across engines on
|
||||
# headless server
|
||||
num_requests = 200
|
||||
all_tasks = []
|
||||
for _ in range(num_requests):
|
||||
all_tasks.append(asyncio.create_task(make_request()))
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests
|
||||
assert all(completion is not None for completion in results)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Second burst of requests
|
||||
all_tasks = []
|
||||
for _ in range(num_requests):
|
||||
all_tasks.append(asyncio.create_task(make_request()))
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests
|
||||
assert all(completion is not None for completion in results)
|
||||
await _run_request_bursts(api_only_client, model_name)
|
||||
|
||||
api_server, api_server_args = api_only_servers[0]
|
||||
api_server_count = (
|
||||
|
||||
@@ -29,6 +29,81 @@ else:
|
||||
from torch.library import impl_abstract as register_fake
|
||||
|
||||
|
||||
# scaled_fp4_quant functional + out variant for torch.compile buffer management
|
||||
|
||||
|
||||
def create_fp4_scale_tensor(
|
||||
m: int,
|
||||
n: int,
|
||||
device: torch.device,
|
||||
is_sf_swizzled_layout: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Allocate the output scale tensor for scaled_fp4_quant.
|
||||
|
||||
When is_sf_swizzled_layout=True, we use rounded values to store the
|
||||
swizzled scales. Due to the requirement of the Tensor Core, the minimum
|
||||
tile is 128x4 for the scales. So, we first pad the scales to multiples
|
||||
of 128 (rows) and 4 (cols). Then, the scales (in float8_e4m3fn) are
|
||||
packed into an int32 for every 4 values. More:
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/
|
||||
#tcgen05-mma-scale-factor-b-layout-4x
|
||||
"""
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
block_size = 16
|
||||
if is_sf_swizzled_layout:
|
||||
rounded_m = round_up(m, 128)
|
||||
scale_n = n // block_size
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
return torch.empty(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
return torch.empty((m, n // block_size), device=device, dtype=torch.uint8)
|
||||
|
||||
|
||||
def create_fp4_output_tensors(
|
||||
m: int,
|
||||
n: int,
|
||||
device: torch.device,
|
||||
is_sf_swizzled_layout: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Allocate both output tensors for scaled_fp4_quant:
|
||||
(quantized_output, output_scale).
|
||||
|
||||
Must match the C++ scaled_fp4_quant_func allocation exactly.
|
||||
"""
|
||||
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
||||
output_scale = create_fp4_scale_tensor(m, n, device, is_sf_swizzled_layout)
|
||||
return output, output_scale
|
||||
|
||||
|
||||
if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
|
||||
@register_fake("_C::scaled_fp4_quant")
|
||||
def _scaled_fp4_quant_fake(
|
||||
input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
is_sf_swizzled_layout: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
n = input.shape[-1]
|
||||
m = input.numel() // n
|
||||
return create_fp4_output_tensors(m, n, input.device, is_sf_swizzled_layout)
|
||||
|
||||
@register_fake("_C::scaled_fp4_quant.out")
|
||||
def _scaled_fp4_quant_out_fake(
|
||||
input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
is_sf_swizzled_layout: bool,
|
||||
*,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# page attention ops
|
||||
def paged_attention_v1(
|
||||
out: torch.Tensor,
|
||||
@@ -1644,7 +1719,6 @@ def scaled_fp4_quant(
|
||||
input = input.reshape(other_dims, input.shape[-1])
|
||||
m, n = input.shape
|
||||
block_size = 16
|
||||
device = input.device
|
||||
|
||||
assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
|
||||
assert input.dtype in (torch.float16, torch.bfloat16), (
|
||||
@@ -1658,26 +1732,16 @@ def scaled_fp4_quant(
|
||||
input, input_global_scale
|
||||
)
|
||||
else:
|
||||
# Two fp4 values will be packed into an uint8.
|
||||
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
||||
if is_sf_swizzled_layout:
|
||||
# We use the rounded values to store the swizzled values. Due to the
|
||||
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
|
||||
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
|
||||
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(m, 128)
|
||||
scale_n = n // block_size
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
output_scale = torch.empty(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8)
|
||||
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
output, input, output_scale, input_global_scale, is_sf_swizzled_layout
|
||||
# Pre-allocate and call .out variant (same behavior as old in-place API)
|
||||
output, output_scale = create_fp4_output_tensors(
|
||||
m, n, input.device, is_sf_swizzled_layout
|
||||
)
|
||||
torch.ops._C.scaled_fp4_quant.out(
|
||||
input,
|
||||
input_global_scale,
|
||||
is_sf_swizzled_layout,
|
||||
output=output,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
output_scale = output_scale.view(torch.float8_e4m3fn)
|
||||
|
||||
@@ -307,13 +307,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
num_submods = len(submod_names)
|
||||
num_artifacts = standalone_compile_artifacts.num_artifacts()
|
||||
|
||||
logger.info(
|
||||
"reconstructing serializable fn from standalone compile "
|
||||
"artifacts. num_artifacts=%d num_submods=%d",
|
||||
num_artifacts,
|
||||
num_submods,
|
||||
)
|
||||
|
||||
with functorch_ctx:
|
||||
fn = reconstruct_serializable_fn_from_mega_artifact(
|
||||
state=state,
|
||||
@@ -324,7 +317,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"reconstructed serializable fn from standalone compile artifacts"
|
||||
"reconstructed serializable fn from standalone compile "
|
||||
"artifacts. num_artifacts=%d num_submods=%d",
|
||||
num_artifacts,
|
||||
num_submods,
|
||||
)
|
||||
|
||||
return fn
|
||||
|
||||
@@ -148,11 +148,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
result_silu_mul = self.silu_and_mul_matcher(input)
|
||||
at = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=result,
|
||||
input=result_silu_mul,
|
||||
output_scale=output_scale,
|
||||
input_scale=scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=result,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ if find_spec("flashinfer"):
|
||||
pass
|
||||
|
||||
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.out
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer fused allreduce
|
||||
@@ -562,11 +562,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=quant_result,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
@@ -660,11 +660,11 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=quant_result,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
|
||||
@@ -250,11 +250,11 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=output_quant,
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=output_quant,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||
return at2[1], output_scale_view
|
||||
|
||||
@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
}
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
|
||||
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
@@ -112,7 +112,12 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
entry.guard_type == "SHAPE_ENV" for entry in x
|
||||
]
|
||||
else:
|
||||
options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe
|
||||
if hasattr(torch.compiler, "skip_all_guards_unsafe"):
|
||||
# Torch 2.10+ provides skip_all_guards_unsafe
|
||||
options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe
|
||||
else:
|
||||
# Equivalent fallback for older PyTorch: skip all guards
|
||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
||||
|
||||
compiled_ptr: Any = self.forward
|
||||
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
|
||||
|
||||
@@ -310,11 +310,14 @@ class OpenAIServingChat(OpenAIServing):
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
else:
|
||||
reasoning_ended = (
|
||||
reasoning_parser.is_reasoning_end(prompt_token_ids or [])
|
||||
if reasoning_parser
|
||||
else None
|
||||
)
|
||||
if not request.include_reasoning:
|
||||
reasoning_ended = True
|
||||
elif reasoning_parser:
|
||||
reasoning_ended = reasoning_parser.is_reasoning_end(
|
||||
prompt_token_ids or []
|
||||
)
|
||||
else:
|
||||
reasoning_ended = None
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
@@ -24,6 +24,14 @@ class PoolingBasicRequestMixin(OpenAIBaseModel):
|
||||
|
||||
# --8<-- [start:pooling-common-extra-params]
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
truncation_side: Literal["left", "right"] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Which side to truncate from when truncate_prompt_tokens is active. "
|
||||
"'right' keeps the first N tokens. "
|
||||
"'left' keeps the last N tokens."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
|
||||
@@ -32,6 +32,7 @@ class ClassificationCompletionRequest(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
@@ -54,6 +55,7 @@ class ClassificationChatRequest(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
|
||||
@@ -7,12 +7,12 @@ from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
|
||||
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
CohereEmbedRequest,
|
||||
EmbeddingRequest,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
|
||||
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -40,3 +40,24 @@ async def create_embedding(
|
||||
raise NotImplementedError("The model does not support Embeddings API")
|
||||
|
||||
return await handler(request, raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v2/embed",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_cohere_embedding(
|
||||
request: CohereEmbedRequest,
|
||||
raw_request: Request,
|
||||
):
|
||||
handler = embedding(raw_request)
|
||||
if handler is None:
|
||||
raise NotImplementedError("The model does not support Embeddings API")
|
||||
|
||||
return await handler(request, raw_request)
|
||||
|
||||
@@ -1,14 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import torch
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionMessageParam,
|
||||
CustomChatCompletionMessageParam,
|
||||
)
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
CohereEmbedInput,
|
||||
CohereEmbedRequest,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.pooling.typing import PoolingServeContext
|
||||
from vllm.inputs.data import ProcessorInputs, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.renderers import merge_kwargs
|
||||
from vllm.utils.collection_utils import chunk_list
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EmbedIOProcessor(PoolingIOProcessor):
|
||||
@@ -21,16 +44,45 @@ class EmbedIOProcessor(PoolingIOProcessor):
|
||||
self.pooler_config = self.model_config.pooler_config
|
||||
self.enable_chunked_processing = self.pooler_config.enable_chunked_processing
|
||||
|
||||
# Load task instructions from HF config or sentence-transformers config
|
||||
self.task_instructions: dict[str, str] | None = self._load_task_instructions(
|
||||
self.model_config.hf_config
|
||||
) or self._load_st_prompts(self.model_config.model, self.model_config.revision)
|
||||
if self.task_instructions:
|
||||
logger.info(
|
||||
"Loaded prompt prefixes for input_type: %s",
|
||||
list(self.task_instructions.keys()),
|
||||
)
|
||||
|
||||
def pre_process_online(self, ctx: PoolingServeContext):
|
||||
if isinstance(ctx.request, CohereEmbedRequest):
|
||||
self._pre_process_cohere_online(ctx)
|
||||
else:
|
||||
super().pre_process_online(ctx)
|
||||
|
||||
if self.enable_chunked_processing:
|
||||
self._pre_process_chunked(ctx)
|
||||
|
||||
def post_process_online(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
):
|
||||
if ctx.final_res_batch is None:
|
||||
raise ValueError("Final response batch not available")
|
||||
|
||||
if not self.enable_chunked_processing:
|
||||
self._enforce_cohere_max_tokens(ctx)
|
||||
return super().post_process_online(ctx)
|
||||
|
||||
self._post_process_chunked(ctx)
|
||||
self._enforce_cohere_max_tokens(ctx)
|
||||
|
||||
#################################################################
|
||||
# Long Text Embedding with Chunked Processing
|
||||
# PTAL: examples/pooling/embed/openai_embedding_long_text
|
||||
#################################################################
|
||||
|
||||
def pre_process_online(self, ctx: PoolingServeContext):
|
||||
super().pre_process_online(ctx)
|
||||
|
||||
if not self.enable_chunked_processing:
|
||||
return None
|
||||
|
||||
def _pre_process_chunked(self, ctx: PoolingServeContext) -> None:
|
||||
if ctx.engine_prompts is None:
|
||||
raise ValueError("Engine prompts not available")
|
||||
|
||||
@@ -61,18 +113,10 @@ class EmbedIOProcessor(PoolingIOProcessor):
|
||||
|
||||
ctx.engine_prompts = chunked_engine_prompts
|
||||
ctx.prompt_request_ids = prompt_request_ids
|
||||
|
||||
return None
|
||||
|
||||
def post_process_online(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
):
|
||||
if ctx.final_res_batch is None:
|
||||
raise ValueError("Final response batch not available")
|
||||
|
||||
if not self.enable_chunked_processing:
|
||||
return super().post_process_online(ctx)
|
||||
|
||||
def _post_process_chunked(self, ctx: PoolingServeContext) -> None:
|
||||
# Online aggregation for chunked requests to
|
||||
# minimize memory usage
|
||||
# Track aggregation state for each prompt
|
||||
@@ -195,4 +239,245 @@ class EmbedIOProcessor(PoolingIOProcessor):
|
||||
raise ValueError(f"Result not found for prompt {prompt_idx}")
|
||||
|
||||
ctx.final_res_batch = final_res_batch
|
||||
|
||||
return None
|
||||
|
||||
#################################################################
|
||||
# Cohere Request Preprocessing & Postprocessing
|
||||
#################################################################
|
||||
|
||||
@staticmethod
|
||||
def _load_task_instructions(hf_config: Any) -> dict[str, str] | None:
|
||||
"""Extract ``task_instructions`` from the HF model config."""
|
||||
ti = getattr(hf_config, "task_instructions", None)
|
||||
if not isinstance(ti, dict) or not ti:
|
||||
return None
|
||||
return {k: v for k, v in ti.items() if isinstance(v, str)}
|
||||
|
||||
@staticmethod
|
||||
def _load_st_prompts(
|
||||
model: str | Any,
|
||||
revision: str | None,
|
||||
) -> dict[str, str] | None:
|
||||
"""Load ``task_instructions`` from ``config_sentence_transformers.json``."""
|
||||
from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
|
||||
|
||||
try:
|
||||
cfg = get_hf_file_to_dict(
|
||||
"config_sentence_transformers.json", str(model), revision
|
||||
)
|
||||
except (ValueError, OSError):
|
||||
return None
|
||||
|
||||
if cfg is None:
|
||||
return None
|
||||
prompts = cfg.get("prompts")
|
||||
if not isinstance(prompts, dict) or not prompts:
|
||||
return None
|
||||
return {k: v for k, v in prompts.items() if isinstance(v, str)}
|
||||
|
||||
@staticmethod
|
||||
def _mixed_input_to_messages(
|
||||
inp: CohereEmbedInput,
|
||||
*,
|
||||
task_prefix: str | None = None,
|
||||
) -> list[ChatCompletionMessageParam]:
|
||||
"""Build chat messages from a mixed text+image input.
|
||||
|
||||
When *task_prefix* is given, it is prepended to each text part.
|
||||
"""
|
||||
parts: list[ChatCompletionContentPartParam] = []
|
||||
for item in inp.content:
|
||||
if item.type == "text" and item.text is not None:
|
||||
text = task_prefix + item.text if task_prefix else item.text
|
||||
parts.append(ChatCompletionContentPartTextParam(type="text", text=text))
|
||||
elif item.type == "image_url" and item.image_url is not None:
|
||||
parts.append(
|
||||
ChatCompletionContentPartImageParam(
|
||||
type="image_url",
|
||||
image_url=ImageURL(url=item.image_url["url"]),
|
||||
)
|
||||
)
|
||||
return [CustomChatCompletionMessageParam(role="user", content=parts)]
|
||||
|
||||
@staticmethod
|
||||
def _check_cohere_max_tokens(
|
||||
outputs: list[PoolingRequestOutput],
|
||||
max_tokens_check: int | None,
|
||||
) -> None:
|
||||
"""Raise if any output exceeds *max_tokens_check* tokens.
|
||||
|
||||
Used to enforce ``truncate=NONE`` with an explicit ``max_tokens``:
|
||||
the pipeline runs without truncation and we reject afterwards.
|
||||
"""
|
||||
if max_tokens_check is None:
|
||||
return
|
||||
for out in outputs:
|
||||
n = len(out.prompt_token_ids)
|
||||
if n > max_tokens_check:
|
||||
raise ValueError(
|
||||
f"Input of {n} tokens exceeds max_tokens={max_tokens_check} "
|
||||
"with truncate=NONE. Set truncate to END or START to "
|
||||
"allow truncation."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_cohere_truncation(
|
||||
request: CohereEmbedRequest,
|
||||
) -> tuple[int | None, Literal["left", "right"] | None]:
|
||||
"""Return ``(truncate_prompt_tokens, truncation_side)``."""
|
||||
if request.truncate == "NONE":
|
||||
return None, None
|
||||
if request.truncate == "START":
|
||||
tokens = request.max_tokens if request.max_tokens is not None else -1
|
||||
return tokens, "left"
|
||||
if request.max_tokens is not None:
|
||||
return request.max_tokens, None
|
||||
return -1, None
|
||||
|
||||
def create_pooling_params(self, request):
|
||||
if isinstance(request, CohereEmbedRequest):
|
||||
return PoolingParams(
|
||||
task="embed",
|
||||
dimensions=request.output_dimension,
|
||||
)
|
||||
return super().create_pooling_params(request)
|
||||
|
||||
def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None:
|
||||
"""Convert a ``CohereEmbedRequest`` into engine prompts.
|
||||
|
||||
For texts, a single batched completion request path is used.
|
||||
For images and mixed inputs, conversations are batch-rendered
|
||||
through the chat template in one ``render_chat`` call.
|
||||
"""
|
||||
request = ctx.request
|
||||
assert isinstance(request, CohereEmbedRequest)
|
||||
|
||||
if request.texts is None and request.images is None and request.inputs is None:
|
||||
raise ValueError("One of texts, images, or inputs must be provided")
|
||||
|
||||
truncate_prompt_tokens, truncation_side = self._resolve_cohere_truncation(
|
||||
request
|
||||
)
|
||||
input_type = request.input_type
|
||||
self._validate_input_type(input_type)
|
||||
|
||||
if request.images is not None:
|
||||
all_messages: list[list[ChatCompletionMessageParam]] = [
|
||||
[
|
||||
CustomChatCompletionMessageParam(
|
||||
role="user",
|
||||
content=[{"type": "image_url", "image_url": {"url": uri}}],
|
||||
)
|
||||
]
|
||||
for uri in request.images
|
||||
]
|
||||
ctx.engine_prompts = self._batch_render_chat(
|
||||
request, all_messages, truncate_prompt_tokens, truncation_side
|
||||
)
|
||||
|
||||
elif request.inputs is not None:
|
||||
task_prefix = self._get_task_instruction_prefix(input_type)
|
||||
all_messages = [
|
||||
self._mixed_input_to_messages(inp, task_prefix=task_prefix)
|
||||
for inp in request.inputs
|
||||
]
|
||||
ctx.engine_prompts = self._batch_render_chat(
|
||||
request, all_messages, truncate_prompt_tokens, truncation_side
|
||||
)
|
||||
|
||||
else:
|
||||
prefixed = self._apply_task_instruction(request.texts or [], input_type)
|
||||
proxy = EmbeddingCompletionRequest(
|
||||
model=request.model,
|
||||
input=prefixed,
|
||||
dimensions=request.output_dimension,
|
||||
encoding_format="float",
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
truncation_side=truncation_side,
|
||||
)
|
||||
ctx.engine_prompts = self._preprocess_completion_online(
|
||||
proxy, prompt_input=proxy.input, prompt_embeds=None
|
||||
)
|
||||
|
||||
def _batch_render_chat(
|
||||
self,
|
||||
request: CohereEmbedRequest,
|
||||
all_messages: Sequence[list[ChatCompletionMessageParam]],
|
||||
truncate_prompt_tokens: int | None,
|
||||
truncation_side: Literal["left", "right"] | None,
|
||||
) -> list[ProcessorInputs]:
|
||||
"""Batch-render multiple conversations through the chat template."""
|
||||
if not all_messages:
|
||||
return []
|
||||
|
||||
proxy = EmbeddingChatRequest(
|
||||
model=request.model,
|
||||
messages=list(all_messages[0]),
|
||||
dimensions=request.output_dimension,
|
||||
encoding_format="float",
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
truncation_side=truncation_side,
|
||||
)
|
||||
|
||||
renderer = self.renderer
|
||||
mm_config = self.model_config.multimodal_config
|
||||
|
||||
tok_params = proxy.build_tok_params(self.model_config)
|
||||
chat_params = proxy.build_chat_params(
|
||||
self.chat_template,
|
||||
self.chat_template_content_format,
|
||||
).with_defaults(
|
||||
merge_kwargs(
|
||||
None,
|
||||
dict(
|
||||
tools=None,
|
||||
tokenize=is_mistral_tokenizer(renderer.tokenizer),
|
||||
),
|
||||
),
|
||||
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
|
||||
)
|
||||
|
||||
_, engine_prompts = renderer.render_chat(all_messages, chat_params, tok_params)
|
||||
return engine_prompts
|
||||
|
||||
def _validate_input_type(self, input_type: str | None) -> None:
|
||||
"""Raise if *input_type* is not supported by this model."""
|
||||
if input_type is None:
|
||||
return
|
||||
if self.task_instructions is None:
|
||||
raise ValueError(
|
||||
f"Unsupported input_type {input_type!r}. "
|
||||
"This model does not define any input_type task instructions."
|
||||
)
|
||||
if input_type not in self.task_instructions:
|
||||
supported = ", ".join(sorted(self.task_instructions))
|
||||
raise ValueError(
|
||||
f"Unsupported input_type {input_type!r}. Supported values: {supported}"
|
||||
)
|
||||
|
||||
def _apply_task_instruction(
|
||||
self,
|
||||
texts: list[str],
|
||||
input_type: str | None,
|
||||
) -> list[str]:
|
||||
"""Prepend the task-instruction prefix for *input_type*.
|
||||
|
||||
Returns *texts* unchanged when no matching prefix is configured.
|
||||
"""
|
||||
prefix = self._get_task_instruction_prefix(input_type)
|
||||
if not prefix:
|
||||
return texts
|
||||
return [prefix + t for t in texts]
|
||||
|
||||
def _get_task_instruction_prefix(self, input_type: str | None) -> str | None:
|
||||
"""Return the task-instruction prefix for *input_type*, or ``None``."""
|
||||
if not self.task_instructions or input_type is None:
|
||||
return None
|
||||
return self.task_instructions.get(input_type) or None
|
||||
|
||||
def _enforce_cohere_max_tokens(self, ctx: PoolingServeContext) -> None:
|
||||
if isinstance(ctx.request, CohereEmbedRequest):
|
||||
request = ctx.request
|
||||
if request.truncate == "NONE" and request.max_tokens is not None:
|
||||
self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens)
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from typing import TypeAlias
|
||||
"""Embedding API protocol models for OpenAI and Cohere formats.
|
||||
|
||||
from pydantic import Field
|
||||
OpenAI: https://platform.openai.com/docs/api-reference/embeddings
|
||||
Cohere: https://docs.cohere.com/reference/embed
|
||||
"""
|
||||
|
||||
import base64
|
||||
import builtins
|
||||
import struct
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config import ModelConfig
|
||||
@@ -17,6 +27,10 @@ from vllm.entrypoints.pooling.base.protocol import (
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI /v1/embeddings — request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_max_total_output_tokens(
|
||||
model_config: ModelConfig,
|
||||
@@ -50,6 +64,7 @@ class EmbeddingCompletionRequest(
|
||||
max_total_tokens=max_total_tokens,
|
||||
max_output_tokens=max_output_tokens,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
@@ -79,6 +94,7 @@ class EmbeddingChatRequest(
|
||||
max_total_tokens=max_total_tokens,
|
||||
max_output_tokens=max_output_tokens,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
@@ -96,6 +112,11 @@ class EmbeddingChatRequest(
|
||||
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI /v1/embeddings — response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EmbeddingResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "embedding"
|
||||
@@ -106,7 +127,7 @@ class EmbeddingResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
model: str | None = None
|
||||
data: list[EmbeddingResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
@@ -115,3 +136,146 @@ class EmbeddingBytesResponse(OpenAIBaseModel):
|
||||
content: list[bytes]
|
||||
headers: dict[str, str] | None = None
|
||||
media_type: str = "application/octet-stream"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cohere /v2/embed — request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CohereEmbeddingType = Literal[
|
||||
"float",
|
||||
"binary",
|
||||
"ubinary",
|
||||
"base64",
|
||||
]
|
||||
CohereTruncate = Literal["NONE", "START", "END"]
|
||||
|
||||
|
||||
class CohereEmbedContent(BaseModel):
|
||||
type: Literal["text", "image_url"]
|
||||
text: str | None = None
|
||||
image_url: dict[str, str] | None = None
|
||||
|
||||
|
||||
class CohereEmbedInput(BaseModel):
|
||||
content: list[CohereEmbedContent]
|
||||
|
||||
|
||||
class CohereEmbedRequest(BaseModel):
|
||||
model: str | None = None
|
||||
input_type: str | None = None
|
||||
texts: list[str] | None = None
|
||||
images: list[str] | None = None
|
||||
inputs: list[CohereEmbedInput] | None = None
|
||||
output_dimension: int | None = None
|
||||
embedding_types: list[CohereEmbeddingType] | None = None
|
||||
truncate: CohereTruncate = "END"
|
||||
max_tokens: int | None = None
|
||||
priority: int = 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cohere /v2/embed — response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CohereApiVersion(BaseModel):
|
||||
version: str = "2"
|
||||
|
||||
|
||||
class CohereBilledUnits(BaseModel):
|
||||
input_tokens: int | None = None
|
||||
image_tokens: int | None = None
|
||||
|
||||
|
||||
class CohereMeta(BaseModel):
|
||||
api_version: CohereApiVersion = Field(default_factory=CohereApiVersion)
|
||||
billed_units: CohereBilledUnits | None = None
|
||||
|
||||
|
||||
class CohereEmbedByTypeEmbeddings(BaseModel):
|
||||
# The field name ``float`` shadows the builtin type, so the annotation
|
||||
# must use ``builtins.float`` to avoid a self-referential type error.
|
||||
float: list[list[builtins.float]] | None = None
|
||||
binary: list[list[int]] | None = None
|
||||
ubinary: list[list[int]] | None = None
|
||||
base64: list[str] | None = None
|
||||
|
||||
|
||||
class CohereEmbedResponse(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
embeddings: CohereEmbedByTypeEmbeddings
|
||||
texts: list[str] | None = None
|
||||
meta: CohereMeta | None = None
|
||||
response_type: Literal["embeddings_by_type"] = "embeddings_by_type"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cohere embedding type conversion helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_UNSIGNED_TO_SIGNED_DIFF = 1 << 7 # 128
|
||||
|
||||
|
||||
def _pack_binary_embeddings(
|
||||
float_embeddings: list[list[float]],
|
||||
signed: bool,
|
||||
) -> list[list[int]]:
|
||||
"""Bit-pack float embeddings: positive -> 1, negative -> 0.
|
||||
|
||||
Each bit is shifted left by ``7 - idx%8``, and every 8 bits are packed
|
||||
into one byte.
|
||||
"""
|
||||
result: list[list[int]] = []
|
||||
for embedding in float_embeddings:
|
||||
dim = len(embedding)
|
||||
if dim % 8 != 0:
|
||||
raise ValueError(
|
||||
"Embedding dimension must be a multiple of 8 for binary "
|
||||
f"embedding types, but got {dim}."
|
||||
)
|
||||
packed_len = dim // 8
|
||||
packed: list[int] = []
|
||||
byte_val = 0
|
||||
for idx, value in enumerate(embedding):
|
||||
bit = 1 if value >= 0 else 0
|
||||
byte_val += bit << (7 - idx % 8)
|
||||
if (idx + 1) % 8 == 0:
|
||||
if signed:
|
||||
byte_val -= _UNSIGNED_TO_SIGNED_DIFF
|
||||
packed.append(byte_val)
|
||||
byte_val = 0
|
||||
assert len(packed) == packed_len
|
||||
result.append(packed)
|
||||
return result
|
||||
|
||||
|
||||
def _encode_base64_embeddings(
|
||||
float_embeddings: list[list[float]],
|
||||
) -> list[str]:
|
||||
"""Encode float embeddings as base64 (little-endian float32)."""
|
||||
result: list[str] = []
|
||||
for embedding in float_embeddings:
|
||||
buf = struct.pack(f"<{len(embedding)}f", *embedding)
|
||||
result.append(base64.b64encode(buf).decode("utf-8"))
|
||||
return result
|
||||
|
||||
|
||||
def build_typed_embeddings(
|
||||
float_embeddings: list[list[float]],
|
||||
embedding_types: Sequence[str],
|
||||
) -> CohereEmbedByTypeEmbeddings:
|
||||
"""Convert float embeddings to all requested Cohere embedding types."""
|
||||
result = CohereEmbedByTypeEmbeddings()
|
||||
|
||||
for emb_type in embedding_types:
|
||||
if emb_type == "float":
|
||||
result.float = float_embeddings
|
||||
elif emb_type == "binary":
|
||||
result.binary = _pack_binary_embeddings(float_embeddings, signed=True)
|
||||
elif emb_type == "ubinary":
|
||||
result.ubinary = _pack_binary_embeddings(float_embeddings, signed=False)
|
||||
elif emb_type == "base64":
|
||||
result.base64 = _encode_base64_embeddings(float_embeddings)
|
||||
|
||||
return result
|
||||
|
||||
@@ -5,7 +5,7 @@ from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Literal, TypeAlias, cast
|
||||
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
@@ -14,10 +14,15 @@ from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
from vllm.entrypoints.pooling.base.serving import PoolingServing
|
||||
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
CohereBilledUnits,
|
||||
CohereEmbedRequest,
|
||||
CohereEmbedResponse,
|
||||
CohereMeta,
|
||||
EmbeddingBytesResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
build_typed_embeddings,
|
||||
)
|
||||
from vllm.entrypoints.pooling.typing import PoolingServeContext
|
||||
from vllm.entrypoints.pooling.utils import (
|
||||
@@ -26,24 +31,23 @@ from vllm.entrypoints.pooling.utils import (
|
||||
encode_pooling_output_float,
|
||||
get_json_response_cls,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.utils.serial_utils import EmbedDType, Endianness
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
JSONResponseCLS = get_json_response_cls()
|
||||
|
||||
EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest]
|
||||
|
||||
|
||||
class ServingEmbedding(PoolingServing):
|
||||
"""
|
||||
Embedding API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
"""Embedding API supporting both OpenAI and Cohere formats."""
|
||||
|
||||
request_id_prefix = "embd"
|
||||
io_processor: EmbedIOProcessor
|
||||
|
||||
def init_io_processor(
|
||||
self,
|
||||
@@ -58,6 +62,14 @@ class ServingEmbedding(PoolingServing):
|
||||
)
|
||||
|
||||
async def _build_response(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
) -> Response:
|
||||
if isinstance(ctx.request, CohereEmbedRequest):
|
||||
return self._build_cohere_response_from_ctx(ctx)
|
||||
return await self._build_openai_response(ctx)
|
||||
|
||||
async def _build_openai_response(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
) -> JSONResponse | StreamingResponse:
|
||||
@@ -66,7 +78,7 @@ class ServingEmbedding(PoolingServing):
|
||||
endianness = ctx.request.endianness
|
||||
|
||||
if encoding_format == "float" or encoding_format == "base64":
|
||||
return self._request_output_to_embed_json_response(
|
||||
return self._openai_json_response(
|
||||
ctx.final_res_batch,
|
||||
ctx.request_id,
|
||||
ctx.created_time,
|
||||
@@ -77,7 +89,7 @@ class ServingEmbedding(PoolingServing):
|
||||
)
|
||||
|
||||
if encoding_format == "bytes" or encoding_format == "bytes_only":
|
||||
return self._request_output_to_to_embed_bytes_response(
|
||||
return self._openai_bytes_response(
|
||||
ctx.final_res_batch,
|
||||
ctx.request_id,
|
||||
ctx.created_time,
|
||||
@@ -89,7 +101,7 @@ class ServingEmbedding(PoolingServing):
|
||||
|
||||
assert_never(encoding_format)
|
||||
|
||||
def _request_output_to_embed_json_response(
|
||||
def _openai_json_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
@@ -139,7 +151,7 @@ class ServingEmbedding(PoolingServing):
|
||||
)
|
||||
return JSONResponseCLS(content=response.model_dump())
|
||||
|
||||
def _request_output_to_to_embed_bytes_response(
|
||||
def _openai_bytes_response(
|
||||
self,
|
||||
final_res_batch: list[PoolingRequestOutput],
|
||||
request_id: str,
|
||||
@@ -177,3 +189,33 @@ class ServingEmbedding(PoolingServing):
|
||||
headers=response.headers,
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_cohere_response_from_ctx(
|
||||
ctx: PoolingServeContext,
|
||||
) -> JSONResponse:
|
||||
request = ctx.request
|
||||
assert isinstance(request, CohereEmbedRequest)
|
||||
|
||||
all_floats = [encode_pooling_output_float(out) for out in ctx.final_res_batch]
|
||||
total_tokens = sum(len(out.prompt_token_ids) for out in ctx.final_res_batch)
|
||||
|
||||
image_tokens = total_tokens if request.images is not None else 0
|
||||
texts_echo = request.texts
|
||||
|
||||
embedding_types = request.embedding_types or ["float"]
|
||||
embeddings_obj = build_typed_embeddings(all_floats, embedding_types)
|
||||
|
||||
input_tokens = total_tokens - image_tokens
|
||||
response = CohereEmbedResponse(
|
||||
id=ctx.request_id,
|
||||
embeddings=embeddings_obj,
|
||||
texts=texts_echo,
|
||||
meta=CohereMeta(
|
||||
billed_units=CohereBilledUnits(
|
||||
input_tokens=input_tokens,
|
||||
image_tokens=image_tokens,
|
||||
),
|
||||
),
|
||||
)
|
||||
return JSONResponse(content=response.model_dump(exclude_none=True))
|
||||
|
||||
@@ -36,6 +36,7 @@ class PoolingCompletionRequest(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
@@ -61,6 +62,7 @@ class PoolingChatRequest(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
@@ -88,6 +90,7 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=not model_config.is_encoder_decoder,
|
||||
max_total_tokens_param="max_model_len",
|
||||
|
||||
@@ -30,6 +30,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
@@ -105,6 +106,7 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
CohereEmbedRequest,
|
||||
EmbeddingBytesResponse,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
@@ -50,6 +51,7 @@ AnyPoolingRequest: TypeAlias = (
|
||||
| IOProcessorRequest
|
||||
| RerankRequest
|
||||
| ScoreRequest
|
||||
| CohereEmbedRequest
|
||||
)
|
||||
|
||||
AnyPoolingResponse: TypeAlias = (
|
||||
|
||||
15
vllm/envs.py
15
vllm/envs.py
@@ -296,6 +296,16 @@ def use_aot_compile() -> bool:
|
||||
)
|
||||
|
||||
|
||||
def use_mega_aot_artifact():
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
default_value = (
|
||||
"1" if is_torch_equal_or_newer("2.12.0.dev") and use_aot_compile() else "0"
|
||||
)
|
||||
|
||||
return os.environ.get("VLLM_USE_MEGA_AOT_ARTIFACT", default_value) == "1"
|
||||
|
||||
|
||||
def env_with_choices(
|
||||
env_name: str,
|
||||
default: str | None,
|
||||
@@ -616,10 +626,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Enable loading compiled models directly from cached standalone compile artifacts
|
||||
# without re-splitting graph modules. This reduces overhead during model
|
||||
# loading by using reconstruct_serializable_fn_from_mega_artifact.
|
||||
"VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get(
|
||||
"VLLM_USE_MEGA_AOT_ARTIFACT", "0"
|
||||
)
|
||||
== "1",
|
||||
"VLLM_USE_MEGA_AOT_ARTIFACT": use_mega_aot_artifact,
|
||||
# local rank of the process in the distributed setting, used to determine
|
||||
# the GPU device id
|
||||
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
|
||||
|
||||
@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed import tensor_model_parallel_all_gather
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.model_executor.custom_op import maybe_get_oot_by_class
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
@@ -155,9 +156,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
if type(source_layer) is ColumnParallelLinear:
|
||||
if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
|
||||
return True
|
||||
if type(source_layer) is MergedColumnParallelLinear:
|
||||
if type(source_layer) is maybe_get_oot_by_class(MergedColumnParallelLinear):
|
||||
if len(packed_modules_list) != 1:
|
||||
return False
|
||||
# Exclude layers with 3+ output sizes - those are handled by
|
||||
@@ -606,7 +607,7 @@ class MergedColumnParallelLinearVariableSliceWithLoRA(
|
||||
) -> bool:
|
||||
# Support MergedColumnParallelLinear with 3 or more slices
|
||||
# (2 slices are handled by MergedColumnParallelLinearWithLoRA)
|
||||
if type(source_layer) is not MergedColumnParallelLinear:
|
||||
if type(source_layer) is not maybe_get_oot_by_class(MergedColumnParallelLinear):
|
||||
return False
|
||||
|
||||
# If packed_modules_list has 3+ items, use this class
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.model_executor.custom_op import maybe_get_oot_by_class
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
|
||||
from .base_linear import BaseLinearLayerWithLoRA
|
||||
@@ -55,7 +56,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is ReplicatedLinear
|
||||
return type(source_layer) is maybe_get_oot_by_class(ReplicatedLinear)
|
||||
|
||||
def slice_lora_a(
|
||||
self, lora_a: torch.Tensor | list[torch.Tensor | None]
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.distributed import (
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.custom_op import maybe_get_oot_by_class
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -89,7 +90,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is RowParallelLinear
|
||||
return type(source_layer) is maybe_get_oot_by_class(RowParallelLinear)
|
||||
|
||||
|
||||
# The following layer is based on the tensor parallelism strategy given in
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.model_executor.custom_op import maybe_get_oot_by_class
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -132,7 +133,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is VocabParallelEmbedding
|
||||
return type(source_layer) is maybe_get_oot_by_class(VocabParallelEmbedding)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
|
||||
@@ -22,10 +22,11 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
|
||||
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
|
||||
|
||||
|
||||
def get_oot_class_by_name(class_name: str) -> type | None:
|
||||
def maybe_get_oot_by_class(class_type: type) -> type:
|
||||
class_name = class_type.__name__
|
||||
if class_name in op_registry_oot:
|
||||
return op_registry_oot[class_name]
|
||||
return None
|
||||
return class_type
|
||||
|
||||
|
||||
class PluggableLayer(nn.Module):
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
|
||||
from vllm.model_executor.custom_op import CustomOp, maybe_get_oot_by_class
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
@@ -125,7 +125,7 @@ class MMEncoderAttention(CustomOp):
|
||||
cu_seqlens: np.ndarray,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor | None:
|
||||
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
|
||||
if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
|
||||
return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined]
|
||||
|
||||
if attn_backend != AttentionBackendEnum.FLASHINFER:
|
||||
@@ -149,7 +149,7 @@ class MMEncoderAttention(CustomOp):
|
||||
tp_size: int,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
|
||||
if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
|
||||
return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined]
|
||||
attn_backend, cu_seqlens, hidden_size, tp_size, device
|
||||
)
|
||||
|
||||
@@ -659,6 +659,13 @@ def run_cutlass_moe_fp4(
|
||||
class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
"""CUTLASS FP4 fused MoE expert implementation."""
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Fuse activation scales into w_scale_2 in-place so that
|
||||
# g1/g2_alphas (which reference the same tensor) stay in sync
|
||||
# when EPLB rearranges the parameter.
|
||||
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp8Dynamic,
|
||||
kMxfp8Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -67,11 +69,54 @@ class TrtLlmFp8ExpertsBase:
|
||||
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor, Fp8 block, and MXFP8."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
(kMxfp8Static, kMxfp8Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""Supports only SiLU and RELU^2 non-gated activation."""
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) in [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kMxfp8Static, kMxfp8Dynamic),
|
||||
]:
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""Monolithic kernel so only use with naive DP/EP and TP."""
|
||||
@@ -113,9 +158,10 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 block."""
|
||||
"""Supports Fp8 block and MXFP8."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kMxfp8Static, kMxfp8Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@@ -159,6 +205,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
import flashinfer
|
||||
from flashinfer.fused_moe import Fp8QuantizationType
|
||||
|
||||
# Pack topk_ids and topk_weights into single tensor
|
||||
# Format: (expert_id << 16) | (weight_bf16.view(int16))
|
||||
@@ -175,6 +222,16 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
|
||||
assert a1q_scale is not None
|
||||
|
||||
is_mxfp8 = self.quant_config.block_shape == [1, 32]
|
||||
if is_mxfp8:
|
||||
fp8_quant_type = Fp8QuantizationType.MxFp8
|
||||
use_shuffled_weight = True
|
||||
hidden_states_scale = a1q_scale
|
||||
else:
|
||||
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
|
||||
use_shuffled_weight = False
|
||||
hidden_states_scale = a1q_scale.t().contiguous()
|
||||
|
||||
# `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
|
||||
# output tensor in-place so we need to manually copy the result to the
|
||||
# output tensor
|
||||
@@ -183,7 +240,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
topk_ids=packed_topk_ids,
|
||||
routing_bias=None,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale.t().contiguous(), # type: ignore[union-attr]
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale,
|
||||
gemm2_weights=w2,
|
||||
@@ -197,8 +254,9 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=1,
|
||||
use_shuffled_weight=False,
|
||||
use_shuffled_weight=use_shuffled_weight,
|
||||
weight_layout=0,
|
||||
fp8_quantization_type=fp8_quant_type,
|
||||
# output=output,
|
||||
)
|
||||
output.copy_(result)
|
||||
@@ -240,10 +298,11 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor and Fp8 block."""
|
||||
"""Supports Fp8 per-tensor, Fp8 block, and MXFP8."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
(kMxfp8Static, kMxfp8Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@@ -256,7 +315,10 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
if (weight_key, activation_key) in [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kMxfp8Static, kMxfp8Dynamic),
|
||||
]:
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
@@ -274,7 +336,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
def _apply_per_block(
|
||||
def _apply_block_scale(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -291,32 +353,38 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Delay import for non-CUDA.
|
||||
import flashinfer
|
||||
from flashinfer.fused_moe import Fp8QuantizationType
|
||||
|
||||
assert not apply_router_weight_on_input
|
||||
assert activation == MoEActivation.SILU
|
||||
assert self.topk <= global_num_experts
|
||||
assert self.topk <= 10
|
||||
assert global_num_experts % 4 == 0
|
||||
assert self.quant_config.block_shape in [[128, 128], [1, 32]]
|
||||
# Kernel expects #experts <= #threads 512
|
||||
assert global_num_experts <= 512
|
||||
# TODO: fuse into the quant kernel.
|
||||
assert a1q_scale is not None
|
||||
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
assert self.topk <= global_num_experts
|
||||
assert self.topk <= 10
|
||||
assert global_num_experts % 4 == 0
|
||||
assert self.quant_config.block_shape == [128, 128]
|
||||
# Routing kernel expects #experts <= #threads 512
|
||||
assert global_num_experts <= 512
|
||||
|
||||
# Kernel requires transposed hidden state scales
|
||||
# TODO: fuse into the quant kernel.
|
||||
assert a1q_scale is not None
|
||||
a1q_scale_t = a1q_scale.t().contiguous()
|
||||
is_mxfp8 = self.quant_config.block_shape == [1, 32]
|
||||
if is_mxfp8:
|
||||
fp8_quant_type = Fp8QuantizationType.MxFp8
|
||||
use_shuffled_weight = True
|
||||
hidden_states_scale = a1q_scale
|
||||
else:
|
||||
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
|
||||
use_shuffled_weight = False
|
||||
hidden_states_scale = a1q_scale.t().contiguous()
|
||||
|
||||
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale_t,
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale,
|
||||
gemm2_weights=w2,
|
||||
@@ -330,7 +398,8 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
routing_method_type=self.routing_method_type,
|
||||
use_shuffled_weight=False,
|
||||
use_shuffled_weight=use_shuffled_weight,
|
||||
fp8_quantization_type=fp8_quant_type,
|
||||
)
|
||||
|
||||
def _apply_per_tensor(
|
||||
@@ -409,7 +478,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.block_shape is not None:
|
||||
return self._apply_per_block(
|
||||
return self._apply_block_scale(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
@@ -441,6 +510,6 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only per-block and per-tensor quantization are supported in "
|
||||
f"{self.__class__.__name__}."
|
||||
"Only per-block, per-tensor, and MXFP8 quantization are "
|
||||
f"supported in {self.__class__.__name__}."
|
||||
)
|
||||
|
||||
@@ -56,10 +56,25 @@ class TrtLlmNvFp4ExpertsBase:
|
||||
# g1_scale_c = a13_scale * w13_scale_2 / a2_scale
|
||||
self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
|
||||
else:
|
||||
self.g1_scale_c = (
|
||||
torch.ones_like(self.quant_config.a1_gscale)
|
||||
* self.quant_config.a2_gscale
|
||||
)
|
||||
self.g1_scale_c = self.quant_config.a2_gscale.clone()
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||
# Recompute g1_scale_c since g1_alphas was just fused in-place.
|
||||
# Register as a layer parameter so EPLB rearranges it alongside
|
||||
# other expert weights.
|
||||
assert self.quant_config.g1_alphas is not None
|
||||
assert self.quant_config.a2_gscale is not None
|
||||
if self.moe_config.is_act_and_mul:
|
||||
g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
|
||||
else:
|
||||
g1_scale_c = self.quant_config.a2_gscale.clone()
|
||||
layer.register_parameter(
|
||||
"g1_scale_c",
|
||||
torch.nn.Parameter(g1_scale_c, requires_grad=False),
|
||||
)
|
||||
self.g1_scale_c = layer.g1_scale_c
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
|
||||
@@ -49,6 +49,10 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
)
|
||||
self.out_dtype = moe_config.in_dtype
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user